Skip to content

Commit

Permalink
Add rewrites for measurable negation and subtraction
Browse files Browse the repository at this point in the history
  • Loading branch information
Ricardo committed May 2, 2022
1 parent ed8a342 commit 7ffdcd1
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 7 deletions.
38 changes: 37 additions & 1 deletion aeppl/opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from aesara.graph.op import compute_test_value
from aesara.graph.opt import local_optimizer
from aesara.graph.optdb import EquilibriumDB, OptimizationQuery, SequenceDB
from aesara.scalar import TrueDiv
from aesara.scalar import Neg, Sub, TrueDiv
from aesara.tensor.basic_opt import (
ShapeFeature,
register_canonicalize,
Expand Down Expand Up @@ -318,6 +318,34 @@ def measurable_div_to_reciprocal_product(fgraph, node):
return [at.mul(numerator, at.reciprocal(denominator))]


@local_optimizer([Elemwise])
def measurable_neg_to_product(fgraph, node):
"""Convert negation of `MeasurableVariable`s to product with `-1`."""
if isinstance(node.op.scalar_op, Neg):
inp = node.inputs[0]
if not (inp.owner and isinstance(inp.owner.op, MeasurableVariable)):
return None

return [at.mul(inp, -1.0)]


@local_optimizer([Elemwise])
def measurable_sub_to_neg(fgraph, node):
"""Convert subtraction involving `MeasurableVariable`s to addition with neg"""
if isinstance(node.op.scalar_op, Sub):
minuend, subtrahend = node.inputs

if not (
(minuend.owner and isinstance(minuend.owner.op, MeasurableVariable))
or (
subtrahend.owner and isinstance(subtrahend.owner.op, MeasurableVariable)
)
):
return None

return [at.add(minuend, at.neg(subtrahend))]


logprob_rewrites_db = SequenceDB()
logprob_rewrites_db.name = "logprob_rewrites_db"
logprob_rewrites_db.register(
Expand Down Expand Up @@ -357,6 +385,14 @@ def measurable_div_to_reciprocal_product(fgraph, node):
"basic",
)

measurable_ir_rewrites_db.register(
"measurable_neg_to_product", measurable_neg_to_product, -5, "basic"
)

measurable_ir_rewrites_db.register(
"measurable_sub_to_neg", measurable_sub_to_neg, -5, "basic"
)

logprob_rewrites_db.register(
"post-canonicalize", optdb.query("+canonicalize"), 10, "basic"
)
Expand Down
36 changes: 30 additions & 6 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,17 +589,20 @@ def test_log_transform_rv():


@pytest.mark.parametrize(
"rv_size, loc_type",
"rv_size, loc_type, addition",
[
(None, at.scalar),
(2, at.vector),
((2, 1), at.col),
(None, at.scalar, True),
(2, at.vector, False),
((2, 1), at.col, True),
],
)
def test_loc_transform_rv(rv_size, loc_type):
def test_loc_transform_rv(rv_size, loc_type, addition):

loc = loc_type("loc")
y_rv = loc + at.random.normal(0, 1, size=rv_size, name="base_rv")
if addition:
y_rv = loc + at.random.normal(0, 1, size=rv_size, name="base_rv")
else:
y_rv = at.random.normal(0, 1, size=rv_size, name="base_rv") - at.neg(loc)
y_rv.name = "y"
y_vv = y_rv.clone()

Expand Down Expand Up @@ -725,3 +728,24 @@ def test_inverse_rv_transform(numerator):
x_logp_fn(x_test_val),
sp.stats.invgamma(shape, scale=scale * numerator).logpdf(x_test_val),
)


def test_negated_rv_transform():
x_rv = -at.random.halfnormal()
x_rv.name = "x"

x_vv = x_rv.clone()
x_logp_fn = aesara.function([x_vv], joint_logprob({x_rv: x_vv}))

assert np.isclose(x_logp_fn(-1.5), sp.stats.halfnorm.logpdf(1.5))


def test_subtracted_rv_transform():
# Choose base RV that is assymetric around zero
x_rv = 5.0 - at.random.normal(1.0)
x_rv.name = "x"

x_vv = x_rv.clone()
x_logp_fn = aesara.function([x_vv], joint_logprob({x_rv: x_vv}))

assert np.isclose(x_logp_fn(7.3), sp.stats.norm.logpdf(5.0 - 7.3, 1.0))

0 comments on commit 7ffdcd1

Please sign in to comment.