Skip to content

Commit

Permalink
Add rewrites for measurable negation and subtraction
Browse files Browse the repository at this point in the history
  • Loading branch information
Ricardo committed Sep 12, 2022
1 parent 022812b commit d0d8842
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 8 deletions.
69 changes: 67 additions & 2 deletions aeppl/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
from aesara.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
from aesara.scalar import Add, Exp, Log, Mul, Reciprocal, TrueDiv
from aesara.scalar import Add, Exp, Log, Mul, Neg, Reciprocal, Sub, TrueDiv
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.rewriting.basic import (
Expand Down Expand Up @@ -292,6 +292,56 @@ def measurable_div_to_reciprocal_product(fgraph, node):
return [at.mul(numerator, at.reciprocal(denominator))]


@node_rewriter([Elemwise])
def measurable_neg_to_product(fgraph, node):
"""Convert negation of `MeasurableVariable`s to product with `-1`."""
if isinstance(node.op.scalar_op, Neg):
inp = node.inputs[0]
if not (inp.owner and isinstance(inp.owner.op, MeasurableVariable)):
return None

rv_map_feature: Optional[PreserveRVMappings] = getattr(
fgraph, "preserve_rv_mappings", None
)
if rv_map_feature is None:
return None # pragma: no cover

# Only apply this rewrite if the variable is unvalued
if inp in rv_map_feature.rv_values:
return None # pragma: no cover

return [at.mul(inp, -1.0)]


@node_rewriter([Elemwise])
def measurable_sub_to_neg(fgraph, node):
"""Convert subtraction involving `MeasurableVariable`s to addition with neg"""
if isinstance(node.op.scalar_op, Sub):
measurable_vars = [
var
for var in node.inputs
if (var.owner and isinstance(var.owner.op, MeasurableVariable))
]
if not measurable_vars:
return None # pragma: no cover

rv_map_feature: Optional[PreserveRVMappings] = getattr(
fgraph, "preserve_rv_mappings", None
)
if rv_map_feature is None:
return None # pragma: no cover

# Only apply this rewrite if there is one unvalued MeasurableVariable involved
if all(
measurable_var in rv_map_feature.rv_values
for measurable_var in measurable_vars
):
return None # pragma: no cover

minuend, subtrahend = node.inputs
return [at.add(minuend, at.neg(subtrahend))]


@node_rewriter([Elemwise])
def find_measurable_transforms(
fgraph: FunctionGraph, node: Node
Expand Down Expand Up @@ -383,11 +433,26 @@ def find_measurable_transforms(
measurable_ir_rewrites_db.register(
"measurable_div_to_reciprocal_product",
measurable_div_to_reciprocal_product,
-1,
-5,
"basic",
"transform",
)

measurable_ir_rewrites_db.register(
"measurable_neg_to_product",
measurable_neg_to_product,
-5,
"basic",
"transform",
)

measurable_ir_rewrites_db.register(
"measurable_sub_to_neg",
measurable_sub_to_neg,
-5,
"basic",
"transform",
)

measurable_ir_rewrites_db.register(
"find_measurable_transforms",
Expand Down
36 changes: 30 additions & 6 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,17 +591,20 @@ def test_log_transform_rv():


@pytest.mark.parametrize(
"rv_size, loc_type",
"rv_size, loc_type, addition",
[
(None, at.scalar),
(2, at.vector),
((2, 1), at.col),
(None, at.scalar, True),
(2, at.vector, False),
((2, 1), at.col, True),
],
)
def test_loc_transform_rv(rv_size, loc_type):
def test_loc_transform_rv(rv_size, loc_type, addition):

loc = loc_type("loc")
y_rv = loc + at.random.normal(0, 1, size=rv_size, name="base_rv")
if addition:
y_rv = loc + at.random.normal(0, 1, size=rv_size, name="base_rv")
else:
y_rv = at.random.normal(0, 1, size=rv_size, name="base_rv") - at.neg(loc)
y_rv.name = "y"
y_vv = y_rv.clone()

Expand Down Expand Up @@ -731,3 +734,24 @@ def test_reciprocal_rv_transform(numerator):
x_logp_fn(x_test_val),
sp.stats.invgamma(shape, scale=scale * numerator).logpdf(x_test_val),
)


def test_negated_rv_transform():
x_rv = -at.random.halfnormal()
x_rv.name = "x"

x_vv = x_rv.clone()
x_logp_fn = aesara.function([x_vv], joint_logprob({x_rv: x_vv}))

assert np.isclose(x_logp_fn(-1.5), sp.stats.halfnorm.logpdf(1.5))


def test_subtracted_rv_transform():
# Choose base RV that is assymetric around zero
x_rv = 5.0 - at.random.normal(1.0)
x_rv.name = "x"

x_vv = x_rv.clone()
x_logp_fn = aesara.function([x_vv], joint_logprob({x_rv: x_vv}))

assert np.isclose(x_logp_fn(7.3), sp.stats.norm.logpdf(5.0 - 7.3, 1.0))

0 comments on commit d0d8842

Please sign in to comment.