diff --git a/aeppl/transforms.py b/aeppl/transforms.py index b76b7733..85637152 100644 --- a/aeppl/transforms.py +++ b/aeppl/transforms.py @@ -10,7 +10,7 @@ from aesara.graph.fg import FunctionGraph from aesara.graph.op import Op from aesara.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter -from aesara.scalar import Add, Exp, Log, Mul, Reciprocal, TrueDiv +from aesara.scalar import Add, Exp, Log, Mul, Neg, Reciprocal, Sub, TrueDiv from aesara.tensor.elemwise import Elemwise from aesara.tensor.exceptions import NotScalarConstantError from aesara.tensor.rewriting.basic import ( @@ -292,6 +292,56 @@ def measurable_div_to_reciprocal_product(fgraph, node): return [at.mul(numerator, at.reciprocal(denominator))] +@node_rewriter([Elemwise]) +def measurable_neg_to_product(fgraph, node): + """Convert negation of `MeasurableVariable`s to product with `-1`.""" + if isinstance(node.op.scalar_op, Neg): + inp = node.inputs[0] + if not (inp.owner and isinstance(inp.owner.op, MeasurableVariable)): + return None + + rv_map_feature: Optional[PreserveRVMappings] = getattr( + fgraph, "preserve_rv_mappings", None + ) + if rv_map_feature is None: + return None # pragma: no cover + + # Only apply this rewrite if the variable is unvalued + if inp in rv_map_feature.rv_values: + return None # pragma: no cover + + return [at.mul(inp, -1.0)] + + +@node_rewriter([Elemwise]) +def measurable_sub_to_neg(fgraph, node): + """Convert subtraction involving `MeasurableVariable`s to addition with neg""" + if isinstance(node.op.scalar_op, Sub): + measurable_vars = [ + var + for var in node.inputs + if (var.owner and isinstance(var.owner.op, MeasurableVariable)) + ] + if not measurable_vars: + return None # pragma: no cover + + rv_map_feature: Optional[PreserveRVMappings] = getattr( + fgraph, "preserve_rv_mappings", None + ) + if rv_map_feature is None: + return None # pragma: no cover + + # Only apply this rewrite if there is one unvalued MeasurableVariable involved + if all( + measurable_var in rv_map_feature.rv_values + for measurable_var in measurable_vars + ): + return None # pragma: no cover + + minuend, subtrahend = node.inputs + return [at.add(minuend, at.neg(subtrahend))] + + @node_rewriter([Elemwise]) def find_measurable_transforms( fgraph: FunctionGraph, node: Node @@ -383,11 +433,26 @@ def find_measurable_transforms( measurable_ir_rewrites_db.register( "measurable_div_to_reciprocal_product", measurable_div_to_reciprocal_product, - -1, + -5, "basic", "transform", ) +measurable_ir_rewrites_db.register( + "measurable_neg_to_product", + measurable_neg_to_product, + -5, + "basic", + "transform", +) + +measurable_ir_rewrites_db.register( + "measurable_sub_to_neg", + measurable_sub_to_neg, + -5, + "basic", + "transform", +) measurable_ir_rewrites_db.register( "find_measurable_transforms", diff --git a/tests/test_transforms.py b/tests/test_transforms.py index afc78d39..fb354f3d 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -591,17 +591,20 @@ def test_log_transform_rv(): @pytest.mark.parametrize( - "rv_size, loc_type", + "rv_size, loc_type, addition", [ - (None, at.scalar), - (2, at.vector), - ((2, 1), at.col), + (None, at.scalar, True), + (2, at.vector, False), + ((2, 1), at.col, True), ], ) -def test_loc_transform_rv(rv_size, loc_type): +def test_loc_transform_rv(rv_size, loc_type, addition): loc = loc_type("loc") - y_rv = loc + at.random.normal(0, 1, size=rv_size, name="base_rv") + if addition: + y_rv = loc + at.random.normal(0, 1, size=rv_size, name="base_rv") + else: + y_rv = at.random.normal(0, 1, size=rv_size, name="base_rv") - at.neg(loc) y_rv.name = "y" y_vv = y_rv.clone() @@ -731,3 +734,24 @@ def test_reciprocal_rv_transform(numerator): x_logp_fn(x_test_val), sp.stats.invgamma(shape, scale=scale * numerator).logpdf(x_test_val), ) + + +def test_negated_rv_transform(): + x_rv = -at.random.halfnormal() + x_rv.name = "x" + + x_vv = x_rv.clone() + x_logp_fn = aesara.function([x_vv], joint_logprob({x_rv: x_vv})) + + assert np.isclose(x_logp_fn(-1.5), sp.stats.halfnorm.logpdf(1.5)) + + +def test_subtracted_rv_transform(): + # Choose base RV that is assymetric around zero + x_rv = 5.0 - at.random.normal(1.0) + x_rv.name = "x" + + x_vv = x_rv.clone() + x_logp_fn = aesara.function([x_vv], joint_logprob({x_rv: x_vv})) + + assert np.isclose(x_logp_fn(7.3), sp.stats.norm.logpdf(5.0 - 7.3, 1.0))