Skip to content

Commit

Permalink
Implement Inverse measurable transform
Browse files Browse the repository at this point in the history
  • Loading branch information
Ricardo committed Apr 29, 2022
1 parent 8d21e4a commit dcfaa14
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 2 deletions.
18 changes: 18 additions & 0 deletions aeppl/opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@
from aesara.graph.op import compute_test_value
from aesara.graph.opt import local_optimizer
from aesara.graph.optdb import EquilibriumDB, OptimizationQuery, SequenceDB
from aesara.scalar import TrueDiv
from aesara.tensor.basic_opt import (
ShapeFeature,
register_canonicalize,
register_useless,
)
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.extra_ops import BroadcastTo
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.random.opt import (
Expand Down Expand Up @@ -292,6 +294,20 @@ def naive_bcast_rv_lift(fgraph, node):
return [bcasted_node.outputs[1]]


@local_optimizer([Elemwise])
def div_to_reciprocal(fgraph, node):
"""Convert 5/3 to 5 * 1/3"""
if isinstance(node.op.scalar_op, TrueDiv):
numerator, denominator = node.inputs
# Check if numerator is 1
try:
if at.get_scalar_constant_value(numerator) == 1:
return [at.reciprocal(denominator)]
except NotScalarConstantError:
pass
return [at.mul(numerator, at.reciprocal(denominator))]


logprob_rewrites_db = SequenceDB()
logprob_rewrites_db.name = "logprob_rewrites_db"
logprob_rewrites_db.register(
Expand Down Expand Up @@ -324,6 +340,8 @@ def naive_bcast_rv_lift(fgraph, node):
"incsubtensor_lift", incsubtensor_rv_replace, -5, "basic"
)

measurable_ir_rewrites_db.register("div_to_reciprocal", div_to_reciprocal, -5, "basic")

logprob_rewrites_db.register(
"post-canonicalize", optdb.query("+canonicalize"), 10, "basic"
)
Expand Down
19 changes: 17 additions & 2 deletions aeppl/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
from aesara.graph.opt import GlobalOptimizer, in2out, local_optimizer
from aesara.scalar import Add, Exp, Log, Mul
from aesara.scalar import Add, Exp, Log, Mul, Reciprocal
from aesara.tensor.basic_opt import (
register_specialize,
register_stabilize,
Expand Down Expand Up @@ -261,7 +261,7 @@ def find_measurable_transforms(
) -> Optional[List[Node]]:
"""Find measurable transformations from Elemwise operators."""
scalar_op = node.op.scalar_op
if not isinstance(scalar_op, (Exp, Log, Add, Mul)):
if not isinstance(scalar_op, (Exp, Log, Reciprocal, Add, Mul)):
return None

# Node was already converted
Expand Down Expand Up @@ -319,6 +319,8 @@ def find_measurable_transforms(
transform = ExpTransform()
elif isinstance(scalar_op, Log):
transform = LogTransform()
elif isinstance(scalar_op, Reciprocal):
transform = InverseTransform()
elif isinstance(scalar_op, Add):
transform_inputs = (measurable_input, at.add(*other_inputs))
transform = LocTransform(
Expand Down Expand Up @@ -413,6 +415,19 @@ def log_jac_det(self, value, *inputs):
return -at.log(value)


class InverseTransform(RVTransform):
name = "inverse"

def forward(self, value, *inputs):
return at.inv(value)

def backward(self, value, *inputs):
return at.inv(value)

def log_jac_det(self, value, *inputs):
return -at.log(value**2)


class IntervalTransform(RVTransform):
name = "interval"

Expand Down
17 changes: 17 additions & 0 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,3 +703,20 @@ def test_invalid_broadcasted_transform_rv_fails():
logp = joint_logprob({y_rv: y_vv})
logp.eval({y_vv: [0, 0, 0, 0], loc: [0, 0, 0, 0]})
assert False, "Should have failed before"


@pytest.mark.parametrize("numerator", (1.0, 2.0))
def test_inverse_rv_transform(numerator):
shape = 3
scale = 5
x_rv = numerator / at.random.gamma(shape, scale)
x_rv.name = "x"

x_vv = x_rv.clone()
x_logp_fn = aesara.function([x_vv], joint_logprob({x_rv: x_vv}))

x_test_val = 1.5
assert np.isclose(
x_logp_fn(x_test_val),
sp.stats.invgamma(shape, scale=scale * numerator).logpdf(x_test_val),
)

0 comments on commit dcfaa14

Please sign in to comment.