diff --git a/aeppl/abstract.py b/aeppl/abstract.py index d8ccb2f0..312b5745 100644 --- a/aeppl/abstract.py +++ b/aeppl/abstract.py @@ -1,7 +1,7 @@ import abc from copy import copy from functools import singledispatch -from typing import Callable, List, Tuple +from typing import Callable, List from aesara.graph.basic import Apply, Variable from aesara.graph.op import Op @@ -122,15 +122,5 @@ def assign_custom_measurable_outputs( class MeasurableElemwise(Elemwise): """Base class for Measurable Elemwise variables""" - valid_scalar_types: Tuple[MetaType, ...] = () - - def __init__(self, scalar_op, *args, **kwargs): - if not isinstance(scalar_op, self.valid_scalar_types): - raise TypeError( - f"scalar_op {scalar_op} is not valid for class {self.__class__}. " - f"Acceptable types are {self.valid_scalar_types}" - ) - super().__init__(scalar_op, *args, **kwargs) - MeasurableVariable.register(MeasurableElemwise) diff --git a/aeppl/censoring.py b/aeppl/censoring.py index 4d46a83f..d12ad484 100644 --- a/aeppl/censoring.py +++ b/aeppl/censoring.py @@ -1,13 +1,15 @@ -from typing import List, Optional +from typing import TYPE_CHECKING, List, Optional import aesara.tensor as at import numpy as np from aesara.graph.basic import Node from aesara.graph.fg import FunctionGraph from aesara.graph.rewriting.basic import node_rewriter -from aesara.scalar.basic import Ceil, Clip, Floor, RoundHalfToEven +from aesara.scalar.basic import ceil as scalar_ceil from aesara.scalar.basic import clip as scalar_clip -from aesara.tensor.elemwise import Elemwise +from aesara.scalar.basic import floor as scalar_floor +from aesara.scalar.basic import round_half_to_even as scalar_round_half_to_even +from aesara.tensor.math import ceil, clip, floor, round_half_to_even from aesara.tensor.var import TensorConstant from aeppl.abstract import ( @@ -18,32 +20,27 @@ from aeppl.logprob import CheckParameterValue, _logcdf, _logprob, logdiffexp from aeppl.rewriting import measurable_ir_rewrites_db +if 
TYPE_CHECKING: + from aesara.graph.basic import Op, Variable + class MeasurableClip(MeasurableElemwise): """A placeholder used to specify a log-likelihood for a clipped RV sub-graph.""" - valid_scalar_types = (Clip,) - measurable_clip = MeasurableClip(scalar_clip) -@node_rewriter(tracks=[Elemwise]) +@node_rewriter([clip]) def find_measurable_clips( fgraph: FunctionGraph, node: Node -) -> Optional[List[MeasurableClip]]: +) -> Optional[List["Variable"]]: # TODO: Canonicalize x[x>ub] = ub -> clip(x, x, ub) rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None) if rv_map_feature is None: return None # pragma: no cover - if isinstance(node.op, MeasurableClip): - return None # pragma: no cover - - if not (isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, Clip)): - return None - clipped_var = node.outputs[0] base_var, lower_bound, upper_bound = node.inputs @@ -75,7 +72,6 @@ def find_measurable_clips( measurable_ir_rewrites_db.register( "find_measurable_clips", find_measurable_clips, - 0, "basic", "censoring", ) @@ -147,27 +143,55 @@ def clip_logprob(op, values, base_rv, lower_bound, upper_bound, **kwargs): class MeasurableRound(MeasurableElemwise): """A placeholder used to specify a log-likelihood for a clipped RV sub-graph.""" - valid_scalar_types = (RoundHalfToEven, Floor, Ceil) +measurable_ceil = MeasurableRound(scalar_ceil) +measurable_floor = MeasurableRound(scalar_floor) +measurable_round_half_to_even = MeasurableRound(scalar_round_half_to_even) -@node_rewriter(tracks=[Elemwise]) -def find_measurable_roundings( - fgraph: FunctionGraph, node: Node -) -> Optional[List[MeasurableRound]]: + +@node_rewriter([ceil]) +def find_measurable_ceil(fgraph: FunctionGraph, node: Node): + return construct_measurable_rounding(fgraph, node, measurable_ceil) + + +@node_rewriter([floor]) +def find_measurable_floor(fgraph: FunctionGraph, node: Node): + return construct_measurable_rounding(fgraph, node, measurable_floor) + + +@node_rewriter([round_half_to_even]) 
+def find_measurable_round_half_to_even(fgraph: FunctionGraph, node: Node): + return construct_measurable_rounding(fgraph, node, measurable_round_half_to_even) + + +measurable_ir_rewrites_db.register( + "find_measurable_ceil", + find_measurable_ceil, + "basic", + "censoring", +) +measurable_ir_rewrites_db.register( + "find_measurable_floor", + find_measurable_floor, + "basic", + "censoring", +) +measurable_ir_rewrites_db.register( + "find_measurable_round_half_to_even", + find_measurable_round_half_to_even, + "basic", + "censoring", +) + + +def construct_measurable_rounding( + fgraph: FunctionGraph, node: Node, rounded_op: "Op" +) -> Optional[List["Variable"]]: rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None) if rv_map_feature is None: return None # pragma: no cover - if isinstance(node.op, MeasurableRound): - return None # pragma: no cover - - if not ( - isinstance(node.op, Elemwise) - and isinstance(node.op.scalar_op, MeasurableRound.valid_scalar_types) - ): - return None - (rounded_var,) = node.outputs (base_var,) = node.inputs @@ -183,21 +207,11 @@ def find_measurable_roundings( # Make base_var unmeasurable unmeasurable_base_var = assign_custom_measurable_outputs(base_var.owner) - rounded_op = MeasurableRound(node.op.scalar_op) rounded_rv = rounded_op.make_node(unmeasurable_base_var).default_output() rounded_rv.name = rounded_var.name return [rounded_rv] -measurable_ir_rewrites_db.register( - "find_measurable_roundings", - find_measurable_roundings, - 0, - "basic", - "censoring", -) - - @_logprob.register(MeasurableRound) def round_logprob(op, values, base_rv, **kwargs): r"""Logprob of a rounded censored distribution @@ -226,15 +240,15 @@ def round_logprob(op, values, base_rv, **kwargs): """ (value,) = values - if isinstance(op.scalar_op, RoundHalfToEven): + if op == measurable_round_half_to_even: value = at.round(value) value_upper = value + 0.5 value_lower = value - 0.5 - elif isinstance(op.scalar_op, Floor): + elif op == measurable_floor: value 
= at.floor(value) value_upper = value + 1.0 value_lower = value - elif isinstance(op.scalar_op, Ceil): + elif op == measurable_ceil: value = at.ceil(value) value_upper = value value_lower = value - 1.0 diff --git a/aeppl/cumsum.py b/aeppl/cumsum.py index ded2f4c1..f66fc3e9 100644 --- a/aeppl/cumsum.py +++ b/aeppl/cumsum.py @@ -83,7 +83,6 @@ def find_measurable_cumsums(fgraph, node) -> Optional[List[MeasurableCumsum]]: measurable_ir_rewrites_db.register( "find_measurable_cumsums", find_measurable_cumsums, - 0, "basic", "cumsum", ) diff --git a/aeppl/joint_logprob.py b/aeppl/joint_logprob.py index 61e61f71..597ddcbc 100644 --- a/aeppl/joint_logprob.py +++ b/aeppl/joint_logprob.py @@ -102,7 +102,7 @@ def conditional_logprob( vv.name = f"{rv.name}_vv" original_rv_values[rv] = vv - # Value variables are not cloned when constructing the conditional log-proprobability + # Value variables are not cloned when constructing the conditional log-probability # graphs. We can thus use them to recover the original random variables to index the # maps to the logprob graphs and value variables before returning them. 
rv_values = {**original_rv_values, **realized} diff --git a/aeppl/mixture.py b/aeppl/mixture.py index 99bcbd78..457fa976 100644 --- a/aeppl/mixture.py +++ b/aeppl/mixture.py @@ -423,7 +423,6 @@ def logprob_MixtureRV( [mixture_replace, switch_mixture_replace], max_use_ratio=aesara.config.optdb__max_use_ratio, ), - 0, "basic", "mixture", ) diff --git a/aeppl/rewriting.py b/aeppl/rewriting.py index fbbcdce9..7326359f 100644 --- a/aeppl/rewriting.py +++ b/aeppl/rewriting.py @@ -218,9 +218,7 @@ def incsubtensor_rv_replace(fgraph, node): logprob_rewrites_db = SequenceDB() logprob_rewrites_db.name = "logprob_rewrites_db" -logprob_rewrites_db.register( - "pre-canonicalize", optdb.query("+canonicalize"), -10, "basic" -) +logprob_rewrites_db.register("pre-canonicalize", optdb.query("+canonicalize"), "basic") # These rewrites convert un-measurable variables into their measurable forms, # but they need to be reapplied, because some of the measurable forms require @@ -229,22 +227,18 @@ def incsubtensor_rv_replace(fgraph, node): measurable_ir_rewrites_db.name = "measurable_ir_rewrites_db" logprob_rewrites_db.register( - "measurable_ir_rewrites", measurable_ir_rewrites_db, -10, "basic" + "measurable_ir_rewrites", measurable_ir_rewrites_db, "basic" ) # These rewrites push random/measurable variables "down", making them closer to # (or eventually) the graph outputs. Often this is done by lifting other `Op`s # "up" through the random/measurable variables and into their inputs. 
+measurable_ir_rewrites_db.register("subtensor_lift", local_subtensor_rv_lift, "basic") measurable_ir_rewrites_db.register( - "subtensor_lift", local_subtensor_rv_lift, -5, "basic" -) -measurable_ir_rewrites_db.register( - "incsubtensor_lift", incsubtensor_rv_replace, -5, "basic" + "incsubtensor_lift", incsubtensor_rv_replace, "basic" ) -logprob_rewrites_db.register( - "post-canonicalize", optdb.query("+canonicalize"), 10, "basic" -) +logprob_rewrites_db.register("post-canonicalize", optdb.query("+canonicalize"), "basic") def construct_ir_fgraph( @@ -326,6 +320,6 @@ def construct_ir_fgraph( new_to_old = tuple( (v, k) for k, v in rv_remapper.measurable_conversions.items() ) - fgraph.replace_all(new_to_old) + fgraph.replace_all(new_to_old, reason="construct_ir_fgraph") return fgraph, rv_values, memo diff --git a/aeppl/scan.py b/aeppl/scan.py index 653b22ef..8b38fd33 100644 --- a/aeppl/scan.py +++ b/aeppl/scan.py @@ -513,7 +513,6 @@ def _get_measurable_outputs_MeasurableScan(op, node): # out2in( # add_opts_to_inner_graphs, name="add_opts_to_inner_graphs", ignore_newtrees=True # ), - -100, "basic", "scan", ) @@ -521,11 +520,10 @@ def _get_measurable_outputs_MeasurableScan(op, node): measurable_ir_rewrites_db.register( "find_measurable_scans", find_measurable_scans, - 0, "basic", "scan", ) # Add scan canonicalizations that aren't in the canonicalization DB -logprob_rewrites_db.register("scan_eqopt1", scan_eqopt1, -9, "basic", "scan") -logprob_rewrites_db.register("scan_eqopt2", scan_eqopt2, -9, "basic", "scan") +logprob_rewrites_db.register("scan_eqopt1", scan_eqopt1, "basic", "scan") +logprob_rewrites_db.register("scan_eqopt2", scan_eqopt2, "basic", "scan") diff --git a/aeppl/tensor.py b/aeppl/tensor.py index 0fac7f2b..5292b4ef 100644 --- a/aeppl/tensor.py +++ b/aeppl/tensor.py @@ -273,25 +273,24 @@ def find_measurable_dimshuffles(fgraph, node) -> Optional[List[MeasurableDimShuf measurable_ir_rewrites_db.register( - "dimshuffle_lift", local_dimshuffle_rv_lift, -5, 
"basic", "tensor" + "dimshuffle_lift", local_dimshuffle_rv_lift, "basic", "tensor" ) # We register this later than `dimshuffle_lift` so that it is only applied as a fallback measurable_ir_rewrites_db.register( - "find_measurable_dimshuffles", find_measurable_dimshuffles, 0, "basic", "tensor" + "find_measurable_dimshuffles", find_measurable_dimshuffles, "basic", "tensor" ) measurable_ir_rewrites_db.register( - "broadcast_to_lift", naive_bcast_rv_lift, -5, "basic", "tensor" + "broadcast_to_lift", naive_bcast_rv_lift, "basic", "tensor" ) measurable_ir_rewrites_db.register( "find_measurable_stacks", find_measurable_stacks, - 0, "basic", "tensor", ) diff --git a/aeppl/transforms.py b/aeppl/transforms.py index 35082e75..6abe62b6 100644 --- a/aeppl/transforms.py +++ b/aeppl/transforms.py @@ -1,23 +1,23 @@ import abc from copy import copy from functools import partial, singledispatch -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union import aesara.tensor as at from aesara.gradient import DisconnectedType, jacobian -from aesara.graph.basic import Apply, Node, Variable +from aesara.graph.basic import Apply, Variable from aesara.graph.features import AlreadyThere, Feature from aesara.graph.fg import FunctionGraph from aesara.graph.op import Op from aesara.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter -from aesara.scalar import Add, Exp, Log, Mul -from aesara.tensor.elemwise import Elemwise +from aesara.tensor.math import add, exp, log, mul, reciprocal, sub, true_div from aesara.tensor.rewriting.basic import ( register_specialize, register_stabilize, register_useless, ) from aesara.tensor.var import TensorVariable +from typing_extensions import Protocol from aeppl.abstract import ( MeasurableElemwise, @@ -29,11 +29,32 @@ from aeppl.rewriting import PreserveRVMappings, measurable_ir_rewrites_db from aeppl.utils import walk_model +if TYPE_CHECKING: + from 
aesara.graph.rewriting.basic import NodeRewriter + + class TransformFnType(Protocol): + def __call__( + self, measurable_input: MeasurableVariable, *other_inputs: Variable + ) -> Tuple["RVTransform", Tuple[TensorVariable, ...]]: + pass + + +def register_measurable_ir( + node_rewriter: "NodeRewriter", + *tags: str, + **kwargs, +): + name = kwargs.pop("name", None) or node_rewriter.__name__ + measurable_ir_rewrites_db.register( + name, node_rewriter, "basic", "transform", *tags, **kwargs + ) + return node_rewriter + @singledispatch def _default_transformed_rv( op: Op, - node: Node, + node: Apply, ) -> Optional[Apply]: """Create a node for a transformed log-probability of a `MeasurableVariable`. @@ -109,7 +130,7 @@ class DefaultTransformSentinel: @node_rewriter(tracks=None) -def transform_values(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]: +def transform_values(fgraph: FunctionGraph, node: Apply) -> Optional[List[Variable]]: """Apply transforms to value variables. It is assumed that the input value variables correspond to forward @@ -245,10 +266,8 @@ def apply(self, fgraph: FunctionGraph): return self.default_transform_rewrite.rewrite(fgraph) -class MeasurableTransform(MeasurableElemwise): - """A placeholder used to specify a log-likelihood for a transformed measurable variable""" - - valid_scalar_types = (Exp, Log, Add, Mul) +class MeasurableElemwiseTransform(MeasurableElemwise): + """A placeholder used to specify a log-likelihood for a transformed `Elemwise`.""" # Cannot use `transform` as name because it would clash with the property added by # the `TransformValuesRewrite` @@ -263,14 +282,16 @@ def __init__( super().__init__(*args, **kwargs) -@_get_measurable_outputs.register(MeasurableTransform) -def _get_measurable_outputs_Transform(op, node): +@_get_measurable_outputs.register(MeasurableElemwiseTransform) +def _get_measurable_outputs_ElemwiseTransform(op, node): return [node.default_output()] -@_logprob.register(MeasurableTransform) -def 
measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwargs): - """Compute the log-probability graph for a `MeasurabeTransform`.""" +@_logprob.register(MeasurableElemwiseTransform) +def measurable_elemwise_logprob( + op: MeasurableElemwiseTransform, values, *inputs, **kwargs +): + """Compute the log-probability graph for a `MeasurableElemwiseTransform`.""" # TODO: Could other rewrites affect the order of inputs? (value,) = values other_inputs = list(inputs) @@ -286,18 +307,151 @@ def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwa return input_logprob + jacobian -@node_rewriter([Elemwise]) -def find_measurable_transforms( - fgraph: FunctionGraph, node: Node -) -> Optional[List[Node]]: - """Find measurable transformations from Elemwise operators.""" - scalar_op = node.op.scalar_op - if not isinstance(scalar_op, MeasurableTransform.valid_scalar_types): - return None +@register_measurable_ir +@node_rewriter([true_div]) +def measurable_true_div(fgraph, node): + r"""Rewrite a `true_div` node to a `MeasurableVariable`. - # Node was already converted - if isinstance(node.op, MeasurableVariable): - return None # pragma: no cover + TODO FIXME: We need to update/clarify the canonicalization situation so that + these can be reliably rewritten as products of reciprocals. + + """ + numerator, denominator = node.inputs + + reciprocal_denominator = at.reciprocal(denominator) + # `denominator` is measurable + res = measurable_reciprocal.transform(fgraph, reciprocal_denominator.owner) + if res: + return measurable_mul.transform(fgraph, at.mul(numerator, res[0]).owner) + + # `numerator` is measurable + return measurable_mul.transform( + fgraph, at.mul(numerator, reciprocal_denominator).owner + ) + + +@register_measurable_ir +@node_rewriter([sub]) +def measurable_sub(fgraph, node): + r"""Rewrite a `sub` node to a `MeasurableVariable`.
+ + TODO FIXME: We need to update/clarify the canonicalization situation so that + these can be reliably rewritten as sums of negations. + + """ + minuend, subtrahend = node.inputs + + mul_subtrahend = at.mul(-1, subtrahend) + + # `subtrahend` is measurable + res = measurable_mul.transform(fgraph, mul_subtrahend.owner) + if res: + return measurable_add.transform(fgraph, at.add(minuend, res[0]).owner) + + # TODO FIXME: `local_add_canonizer` will unreliably rewrite expressions like + # `x - y` to `-y + x` (e.g. apparently when `y` is a constant?) and, as a result, + # this will not be reached. We're leaving this in just in case, but we + # ultimately need to fix Aesara's canonicalizations. + + # `minuend` is measurable + return measurable_add.transform(fgraph, at.add(minuend, mul_subtrahend).owner) + + +@register_measurable_ir +@node_rewriter([exp]) +def measurable_exp(fgraph, node): + """Rewrite an `exp` node to a `MeasurableVariable`.""" + + def transform(measurable_input, *args): + return ExpTransform(), (measurable_input,) + + return construct_elemwise_transform(fgraph, node, transform) + + +@register_measurable_ir +@node_rewriter([log]) +def measurable_log(fgraph, node): + """Rewrite a `log` node to a `MeasurableVariable`.""" + + def transform(measurable_input, *args): + return LogTransform(), (measurable_input,) + + return construct_elemwise_transform(fgraph, node, transform) + + +@register_measurable_ir +@node_rewriter([add]) +def measurable_add(fgraph, node): + """Rewrite an `add` node to a `MeasurableVariable`.""" + + def transform(measurable_input, *other_inputs): + transform_inputs = ( + measurable_input, + at.add(*other_inputs) if len(other_inputs) > 1 else other_inputs[0], + ) + transform = LocTransform( + transform_args_fn=lambda *inputs: inputs[-1], + ) + return transform, transform_inputs + + return construct_elemwise_transform(fgraph, node, transform) + + +@register_measurable_ir +@node_rewriter([mul]) +def measurable_mul(fgraph, node): +
"""Rewrite a `mul` node to a `MeasurableVariable`.""" + + def transform(measurable_input, *other_inputs): + transform_inputs = ( + measurable_input, + at.mul(*other_inputs) if len(other_inputs) > 1 else other_inputs[0], + ) + return ( + ScaleTransform( + transform_args_fn=lambda *inputs: inputs[-1], + ), + transform_inputs, + ) + + return construct_elemwise_transform(fgraph, node, transform) + + +@register_measurable_ir +@node_rewriter([reciprocal]) +def measurable_reciprocal(fgraph, node): + """Rewrite a `reciprocal` node to a `MeasurableVariable`.""" + + def transform(measurable_input, *other_inputs): + return ReciprocalTransform(), (measurable_input,) + + return construct_elemwise_transform(fgraph, node, transform) + + +def construct_elemwise_transform( + fgraph: FunctionGraph, + node: Apply, + transform_fn: "TransformFnType", +) -> Optional[List[Variable]]: + """Construct a measurable transformation for an `Elemwise` node. + + Parameters + ---------- + fgraph + The `FunctionGraph` in which `node` resides. + node + The `Apply` node to be converted. + transform_fn + A function that takes a single measurable input and all the remaining + inputs and returns a transform object and transformed inputs. + + Returns + ------- + A new variable with an `Apply` node with a `MeasurableElemwiseTransform` + that replaces `node`. + + """ + scalar_op = node.op.scalar_op rv_map_feature: Optional[PreserveRVMappings] = getattr( fgraph, "preserve_rv_mappings", None @@ -320,11 +474,14 @@ def find_measurable_transforms( measurable_input: TensorVariable = measurable_inputs[0] # Do not apply rewrite to discrete variables + # TODO: Formalize this restriction better. if measurable_input.type.dtype.startswith("int"): return None # Check that other inputs are not potentially measurable, in which case this rewrite # would be invalid + # TODO FIXME: This is rather costly and redundant; find a way to avoid it + # or make it cheaper. 
other_inputs = tuple(inp for inp in node.inputs if inp is not measurable_input) if any( ancestor_node @@ -346,24 +503,9 @@ def find_measurable_transforms( measurable_input = assign_custom_measurable_outputs(measurable_input.owner) measurable_input_idx = 0 - transform_inputs: Tuple[TensorVariable, ...] = (measurable_input,) - transform: RVTransform - if isinstance(scalar_op, Exp): - transform = ExpTransform() - elif isinstance(scalar_op, Log): - transform = LogTransform() - elif isinstance(scalar_op, Add): - transform_inputs = (measurable_input, at.add(*other_inputs)) - transform = LocTransform( - transform_args_fn=lambda *inputs: inputs[-1], - ) - else: - transform_inputs = (measurable_input, at.mul(*other_inputs)) - transform = ScaleTransform( - transform_args_fn=lambda *inputs: inputs[-1], - ) + transform, transform_inputs = transform_fn(measurable_input, *other_inputs) - transform_op = MeasurableTransform( + transform_op = MeasurableElemwiseTransform( scalar_op=scalar_op, transform=transform, measurable_input_idx=measurable_input_idx, @@ -374,15 +516,6 @@ def find_measurable_transforms( return [transform_out] -measurable_ir_rewrites_db.register( - "find_measurable_transforms", - find_measurable_transforms, - 0, - "basic", - "transform", -) - - class LocTransform(RVTransform): name = "loc" @@ -446,6 +579,19 @@ def log_jac_det(self, value, *inputs): return -at.log(value) +class ReciprocalTransform(RVTransform): + name = "reciprocal" + + def forward(self, value, *inputs): + return at.reciprocal(value) + + def backward(self, value, *inputs): + return at.reciprocal(value) + + def log_jac_det(self, value, *inputs): + return -2 * at.log(value) + + class IntervalTransform(RVTransform): name = "interval" diff --git a/aeppl/utils.py b/aeppl/utils.py index cb28d194..d250167f 100644 --- a/aeppl/utils.py +++ b/aeppl/utils.py @@ -188,7 +188,9 @@ def expand_replace(var: TensorVariable) -> List[TensorVariable]: clone=False, ) - fg.replace_all(replacements.items(), 
import_missing=True) + fg.replace_all( + replacements.items(), import_missing=True, reason="replace_rvs_in_graphs" + ) graphs = list(fg.outputs) diff --git a/setup.cfg b/setup.cfg index 18f19783..e60be32e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -33,6 +33,7 @@ omit = tests/* exclude_lines = pragma: no cover + if TYPE_CHECKING: show_missing = 1 [isort] diff --git a/setup.py b/setup.py index 40f48110..26b36b0c 100644 --- a/setup.py +++ b/setup.py @@ -39,6 +39,7 @@ def get_versions(): "numpy>=1.18.1", "scipy>=1.4.0", "aesara >= 2.8.8", + "typing_extensions", ], tests_require=["pytest"], long_description=open("README.rst").read() if exists("README.rst") else "", diff --git a/tests/test_abstract.py b/tests/test_abstract.py index b869bd66..3806b00b 100644 --- a/tests/test_abstract.py +++ b/tests/test_abstract.py @@ -1,13 +1,8 @@ -import re - import aesara.tensor as at import pytest -from aesara.scalar import Exp, exp from aesara.tensor.random.basic import NormalRV from aeppl.abstract import ( - MeasurableElemwise, - MeasurableVariable, UnmeasurableVariable, _get_measurable_outputs, assign_custom_measurable_outputs, @@ -97,16 +92,3 @@ def test_assign_custom_measurable_outputs(): with pytest.raises(ValueError): assign_custom_measurable_outputs(unmeas_X_rv.owner, lambda x: x) - - -def test_measurable_elemwise(): - # Default does not accept any scalar_op - with pytest.raises(TypeError, match=re.escape("scalar_op exp is not valid")): - MeasurableElemwise(exp) - - class TestMeasurableElemwise(MeasurableElemwise): - valid_scalar_types = (Exp,) - - measurable_exp_op = TestMeasurableElemwise(scalar_op=exp) - measurable_exp = measurable_exp_op(0.0) - assert isinstance(measurable_exp.owner.op, MeasurableVariable) diff --git a/tests/test_censoring.py b/tests/test_censoring.py index 78a2669f..36c821fb 100644 --- a/tests/test_censoring.py +++ b/tests/test_censoring.py @@ -131,7 +131,7 @@ def test_fail_multiple_clip_single_base(): cens_rv2 = at.clip(base_rv, -1, 1) cens_rv2.name = 
"cens2" - with pytest.raises(RuntimeError, match="could not be derived: {cens2}"): + with pytest.raises(RuntimeError, match=r"could not be derived: {cens\d}"): conditional_logprob(cens_rv1, cens_rv2) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index db75eabd..6670404f 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -573,25 +573,32 @@ def test_log_transform_rv(): ((2, 1), at.col), ], ) -def test_loc_transform_rv(rv_size, loc_type): +@pytest.mark.parametrize("right", [True, False]) +def test_transform_measurable_add(rv_size, loc_type, right): loc = loc_type("loc") - y_rv = loc + at.random.normal(0, 1, size=rv_size, name="base_rv") - y_rv.name = "y" - - logps, (y_vv,) = conditional_logprob(y_rv) - logp = logps[y_rv] + X_rv = at.random.normal(0, 1, size=rv_size, name="X") + if right: + Z_rv = loc + X_rv + else: + Z_rv = X_rv + loc + + logps, (z_vv,) = conditional_logprob(Z_rv) + logp = logps[Z_rv] assert_no_rvs(logp) - logp_fn = aesara.function([loc, y_vv], logp) + logp_fn = aesara.function([loc, z_vv], logp) loc_test_val = np.full(rv_size or (), 4.0) - y_test_val = np.full(rv_size or (), 1.0) + z_test_val = np.full(rv_size or (), 1.0) np.testing.assert_allclose( - logp_fn(loc_test_val, y_test_val), - sp.stats.norm(loc_test_val, 1).logpdf(y_test_val), + logp_fn(loc_test_val, z_test_val), + sp.stats.norm(loc_test_val, 1).logpdf(z_test_val), ) + with pytest.raises(RuntimeError, match="The logprob terms"): + joint_logprob(Z_rv, X_rv) + @pytest.mark.parametrize( "rv_size, scale_type", @@ -601,25 +608,33 @@ def test_loc_transform_rv(rv_size, loc_type): ((2, 3), at.matrix), ], ) -def test_scale_transform_rv(rv_size, scale_type): +@pytest.mark.parametrize("right", [True, False]) +def test_scale_transform_rv(rv_size, scale_type, right): scale = scale_type("scale") - y_rv = at.random.normal(0, 1, size=rv_size, name="base_rv") * scale - y_rv.name = "y" - - logps, (y_vv,) = conditional_logprob(y_rv) - logp = logps[y_rv] + X_rv = 
at.random.normal(0, 1, size=rv_size, name="X") + if right: + Z_rv = scale * X_rv + else: + # Z_rv = at.random.normal(0, 1, size=rv_size, name="base_rv") * scale + Z_rv = X_rv / at.reciprocal(scale) + + logps, (z_vv,) = conditional_logprob(Z_rv) + logp = logps[Z_rv] assert_no_rvs(logp) - logp_fn = aesara.function([scale, y_vv], logp) + logp_fn = aesara.function([scale, z_vv], logp) scale_test_val = np.full(rv_size or (), 4.0) - y_test_val = np.full(rv_size or (), 1.0) + z_val = np.full(rv_size or (), 1.0) np.testing.assert_allclose( - logp_fn(scale_test_val, y_test_val), - sp.stats.norm(0, scale_test_val).logpdf(y_test_val), + logp_fn(scale_test_val, z_val), + sp.stats.norm(0, scale_test_val).logpdf(z_val), ) + with pytest.raises(RuntimeError, match="The logprob terms"): + joint_logprob(Z_rv, X_rv) + def test_transformed_rv_and_value(): y_rv = at.random.halfnormal(-1, 1, name="base_rv") + 1 @@ -678,3 +693,73 @@ def test_invalid_broadcasted_transform_rv_fails(): logp, (y_vv,) = joint_logprob(y_rv) logp.eval({y_vv: [0, 0, 0, 0], loc: [0, 0, 0, 0]}) assert False, "Should have failed before" + + +@pytest.mark.parametrize("a", (1.0, 2.0)) +def test_transform_measurable_true_div(a): + shape, scale = 3, 5 + X_rv = at.random.gamma(shape, scale, name="X") + + Z_rv = a / X_rv + + logp, (z_vv,) = joint_logprob(Z_rv) + z_logp_fn = aesara.function([z_vv], logp) + + z_test_val = 1.5 + assert np.isclose( + z_logp_fn(z_test_val), + sp.stats.invgamma(shape, scale=scale * a).logpdf(z_test_val), + ) + + with pytest.raises(RuntimeError, match="The logprob terms"): + joint_logprob(Z_rv, X_rv) + + Z_rv = X_rv / a + + logp, (z_vv,) = joint_logprob(Z_rv) + z_logp_fn = aesara.function([z_vv], logp) + + z_test_val = 1.5 + assert np.isclose( + z_logp_fn(z_test_val), + sp.stats.gamma(shape, scale=1 / (scale * a)).logpdf(z_test_val), + ) + + with pytest.raises(RuntimeError, match="The logprob terms"): + joint_logprob(Z_rv, X_rv) + + +def test_transform_measurable_neg(): + X_rv = 
at.random.halfnormal(name="X") + Z_rv = -X_rv + + logp, (z_vv,) = joint_logprob(Z_rv) + z_logp_fn = aesara.function([z_vv], logp) + + assert np.isclose(z_logp_fn(-1.5), sp.stats.halfnorm.logpdf(1.5)) + + with pytest.raises(RuntimeError, match="The logprob terms"): + joint_logprob(Z_rv, X_rv) + + +def test_transform_measurable_sub(): + # We use a base RV that is asymmetric around zero + X_rv = at.random.normal(1.0, name="X") + + Z_rv = 5.0 - X_rv + + logp, (z_vv,) = joint_logprob(Z_rv) + z_logp_fn = aesara.function([z_vv], logp) + assert np.isclose(z_logp_fn(7.3), sp.stats.norm.logpdf(5.0 - 7.3, 1.0)) + + with pytest.raises(RuntimeError, match="The logprob terms"): + joint_logprob(Z_rv, X_rv) + + Z_rv = X_rv - 5.0 + + logp, (z_vv,) = joint_logprob(Z_rv) + z_logp_fn = aesara.function([z_vv], logp) + assert np.isclose(z_logp_fn(7.3), sp.stats.norm.logpdf(7.3, loc=-4.0)) + + with pytest.raises(RuntimeError, match="The logprob terms"): + joint_logprob(Z_rv, X_rv)