From 7ffdcd1637b8bd20826e9fde94c9923cecb790d6 Mon Sep 17 00:00:00 2001 From: Ricardo Date: Mon, 2 May 2022 08:46:54 +0200 Subject: [PATCH] Add rewrites for measurable negation and subtraction --- aeppl/opt.py | 38 +++++++++++++++++++++++++++++++++++++- tests/test_transforms.py | 36 ++++++++++++++++++++++++++++++------ 2 files changed, 67 insertions(+), 7 deletions(-) diff --git a/aeppl/opt.py b/aeppl/opt.py index d9c851b2..9fbe3931 100644 --- a/aeppl/opt.py +++ b/aeppl/opt.py @@ -9,7 +9,7 @@ from aesara.graph.op import compute_test_value from aesara.graph.opt import local_optimizer from aesara.graph.optdb import EquilibriumDB, OptimizationQuery, SequenceDB -from aesara.scalar import TrueDiv +from aesara.scalar import Neg, Sub, TrueDiv from aesara.tensor.basic_opt import ( ShapeFeature, register_canonicalize, @@ -318,6 +318,34 @@ def measurable_div_to_reciprocal_product(fgraph, node): return [at.mul(numerator, at.reciprocal(denominator))] +@local_optimizer([Elemwise]) +def measurable_neg_to_product(fgraph, node): + """Convert negation of `MeasurableVariable`s to product with `-1`.""" + if isinstance(node.op.scalar_op, Neg): + inp = node.inputs[0] + if not (inp.owner and isinstance(inp.owner.op, MeasurableVariable)): + return None + + return [at.mul(inp, -1.0)] + + +@local_optimizer([Elemwise]) +def measurable_sub_to_neg(fgraph, node): + """Convert subtraction involving `MeasurableVariable`s to addition with neg""" + if isinstance(node.op.scalar_op, Sub): + minuend, subtrahend = node.inputs + + if not ( + (minuend.owner and isinstance(minuend.owner.op, MeasurableVariable)) + or ( + subtrahend.owner and isinstance(subtrahend.owner.op, MeasurableVariable) + ) + ): + return None + + return [at.add(minuend, at.neg(subtrahend))] + + logprob_rewrites_db = SequenceDB() logprob_rewrites_db.name = "logprob_rewrites_db" logprob_rewrites_db.register( @@ -357,6 +385,14 @@ def measurable_div_to_reciprocal_product(fgraph, node): "basic", ) +measurable_ir_rewrites_db.register( + "measurable_neg_to_product", measurable_neg_to_product, -5, "basic" +) + +measurable_ir_rewrites_db.register( + "measurable_sub_to_neg", measurable_sub_to_neg, -5, "basic" +) + logprob_rewrites_db.register( "post-canonicalize", optdb.query("+canonicalize"), 10, "basic" ) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index cc48122f..24c20193 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -589,17 +589,20 @@ def test_log_transform_rv(): @pytest.mark.parametrize( - "rv_size, loc_type", + "rv_size, loc_type, addition", [ - (None, at.scalar), - (2, at.vector), - ((2, 1), at.col), + (None, at.scalar, True), + (2, at.vector, False), + ((2, 1), at.col, True), ], ) -def test_loc_transform_rv(rv_size, loc_type): +def test_loc_transform_rv(rv_size, loc_type, addition): loc = loc_type("loc") - y_rv = loc + at.random.normal(0, 1, size=rv_size, name="base_rv") + if addition: + y_rv = loc + at.random.normal(0, 1, size=rv_size, name="base_rv") + else: + y_rv = at.random.normal(0, 1, size=rv_size, name="base_rv") - at.neg(loc) y_rv.name = "y" y_vv = y_rv.clone() @@ -725,3 +728,24 @@ def test_inverse_rv_transform(numerator): x_logp_fn(x_test_val), sp.stats.invgamma(shape, scale=scale * numerator).logpdf(x_test_val), ) + + +def test_negated_rv_transform(): + x_rv = -at.random.halfnormal() + x_rv.name = "x" + + x_vv = x_rv.clone() + x_logp_fn = aesara.function([x_vv], joint_logprob({x_rv: x_vv})) + + assert np.isclose(x_logp_fn(-1.5), sp.stats.halfnorm.logpdf(1.5)) + + +def test_subtracted_rv_transform(): + # Choose base RV that is assymetric around zero + x_rv = 5.0 - at.random.normal(1.0) + x_rv.name = "x" + + x_vv = x_rv.clone() + x_logp_fn = aesara.function([x_vv], joint_logprob({x_rv: x_vv})) + + assert np.isclose(x_logp_fn(7.3), sp.stats.norm.logpdf(5.0 - 7.3, 1.0))