Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor MeasurableElemwise and its use #208

Merged
merged 4 commits into from
Dec 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 1 addition & 11 deletions aeppl/abstract.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import abc
from copy import copy
from functools import singledispatch
from typing import Callable, List, Tuple
from typing import Callable, List

from aesara.graph.basic import Apply, Variable
from aesara.graph.op import Op
Expand Down Expand Up @@ -122,15 +122,5 @@ def assign_custom_measurable_outputs(
class MeasurableElemwise(Elemwise):
"""Base class for Measurable Elemwise variables"""

valid_scalar_types: Tuple[MetaType, ...] = ()

def __init__(self, scalar_op, *args, **kwargs):
if not isinstance(scalar_op, self.valid_scalar_types):
raise TypeError(
f"scalar_op {scalar_op} is not valid for class {self.__class__}. "
f"Acceptable types are {self.valid_scalar_types}"
)
super().__init__(scalar_op, *args, **kwargs)


MeasurableVariable.register(MeasurableElemwise)
96 changes: 55 additions & 41 deletions aeppl/censoring.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from typing import List, Optional
from typing import TYPE_CHECKING, List, Optional

import aesara.tensor as at
import numpy as np
from aesara.graph.basic import Node
from aesara.graph.fg import FunctionGraph
from aesara.graph.rewriting.basic import node_rewriter
from aesara.scalar.basic import Ceil, Clip, Floor, RoundHalfToEven
from aesara.scalar.basic import ceil as scalar_ceil
from aesara.scalar.basic import clip as scalar_clip
from aesara.tensor.elemwise import Elemwise
from aesara.scalar.basic import floor as scalar_floor
from aesara.scalar.basic import round_half_to_even as scalar_round_half_to_even
from aesara.tensor.math import ceil, clip, floor, round_half_to_even
from aesara.tensor.var import TensorConstant

from aeppl.abstract import (
Expand All @@ -18,32 +20,27 @@
from aeppl.logprob import CheckParameterValue, _logcdf, _logprob, logdiffexp
from aeppl.rewriting import measurable_ir_rewrites_db

if TYPE_CHECKING:
from aesara.graph.basic import Op, Variable


class MeasurableClip(MeasurableElemwise):
"""A placeholder used to specify a log-likelihood for a clipped RV sub-graph."""

valid_scalar_types = (Clip,)


measurable_clip = MeasurableClip(scalar_clip)


@node_rewriter(tracks=[Elemwise])
@node_rewriter([clip])
def find_measurable_clips(
fgraph: FunctionGraph, node: Node
) -> Optional[List[MeasurableClip]]:
) -> Optional[List["Variable"]]:
# TODO: Canonicalize x[x>ub] = ub -> clip(x, x, ub)

rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
if rv_map_feature is None:
return None # pragma: no cover

if isinstance(node.op, MeasurableClip):
return None # pragma: no cover

if not (isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, Clip)):
return None

clipped_var = node.outputs[0]
base_var, lower_bound, upper_bound = node.inputs

Expand Down Expand Up @@ -75,7 +72,6 @@ def find_measurable_clips(
measurable_ir_rewrites_db.register(
"find_measurable_clips",
find_measurable_clips,
0,
"basic",
"censoring",
)
Expand Down Expand Up @@ -147,27 +143,55 @@ def clip_logprob(op, values, base_rv, lower_bound, upper_bound, **kwargs):
class MeasurableRound(MeasurableElemwise):
"""A placeholder used to specify a log-likelihood for a clipped RV sub-graph."""

valid_scalar_types = (RoundHalfToEven, Floor, Ceil)

measurable_ceil = MeasurableRound(scalar_ceil)
measurable_floor = MeasurableRound(scalar_floor)
measurable_round_half_to_even = MeasurableRound(scalar_round_half_to_even)

@node_rewriter(tracks=[Elemwise])
def find_measurable_roundings(
fgraph: FunctionGraph, node: Node
) -> Optional[List[MeasurableRound]]:

@node_rewriter([ceil])
def find_measurable_ceil(fgraph: FunctionGraph, node: Node):
return construct_measurable_rounding(fgraph, node, measurable_ceil)


@node_rewriter([floor])
def find_measurable_floor(fgraph: FunctionGraph, node: Node):
return construct_measurable_rounding(fgraph, node, measurable_floor)


@node_rewriter([round_half_to_even])
def find_measurable_round_half_to_even(fgraph: FunctionGraph, node: Node):
return construct_measurable_rounding(fgraph, node, measurable_round_half_to_even)


measurable_ir_rewrites_db.register(
"find_measurable_ceil",
find_measurable_ceil,
"basic",
"censoring",
)
measurable_ir_rewrites_db.register(
"find_measurable_floor",
find_measurable_floor,
"basic",
"censoring",
)
measurable_ir_rewrites_db.register(
"find_measurable_round_half_to_even",
find_measurable_round_half_to_even,
"basic",
"censoring",
)


def construct_measurable_rounding(
fgraph: FunctionGraph, node: Node, rounded_op: "Op"
) -> Optional[List["Variable"]]:

rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
if rv_map_feature is None:
return None # pragma: no cover

if isinstance(node.op, MeasurableRound):
return None # pragma: no cover

if not (
isinstance(node.op, Elemwise)
and isinstance(node.op.scalar_op, MeasurableRound.valid_scalar_types)
):
return None

(rounded_var,) = node.outputs
(base_var,) = node.inputs

Expand All @@ -183,21 +207,11 @@ def find_measurable_roundings(
# Make base_var unmeasurable
unmeasurable_base_var = assign_custom_measurable_outputs(base_var.owner)

rounded_op = MeasurableRound(node.op.scalar_op)
rounded_rv = rounded_op.make_node(unmeasurable_base_var).default_output()
rounded_rv.name = rounded_var.name
return [rounded_rv]


measurable_ir_rewrites_db.register(
"find_measurable_roundings",
find_measurable_roundings,
0,
"basic",
"censoring",
)


@_logprob.register(MeasurableRound)
def round_logprob(op, values, base_rv, **kwargs):
r"""Logprob of a rounded censored distribution
Expand Down Expand Up @@ -226,15 +240,15 @@ def round_logprob(op, values, base_rv, **kwargs):
"""
(value,) = values

if isinstance(op.scalar_op, RoundHalfToEven):
if op == measurable_round_half_to_even:
value = at.round(value)
value_upper = value + 0.5
value_lower = value - 0.5
elif isinstance(op.scalar_op, Floor):
elif op == measurable_floor:
value = at.floor(value)
value_upper = value + 1.0
value_lower = value
elif isinstance(op.scalar_op, Ceil):
elif op == measurable_ceil:
value = at.ceil(value)
value_upper = value
value_lower = value - 1.0
Expand Down
1 change: 0 additions & 1 deletion aeppl/cumsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ def find_measurable_cumsums(fgraph, node) -> Optional[List[MeasurableCumsum]]:
measurable_ir_rewrites_db.register(
"find_measurable_cumsums",
find_measurable_cumsums,
0,
"basic",
"cumsum",
)
2 changes: 1 addition & 1 deletion aeppl/joint_logprob.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def conditional_logprob(
vv.name = f"{rv.name}_vv"
original_rv_values[rv] = vv

# Value variables are not cloned when constructing the conditional log-proprobability
# Value variables are not cloned when constructing the conditional log-probability
# graphs. We can thus use them to recover the original random variables to index the
# maps to the logprob graphs and value variables before returning them.
rv_values = {**original_rv_values, **realized}
Expand Down
1 change: 0 additions & 1 deletion aeppl/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,6 @@ def logprob_MixtureRV(
[mixture_replace, switch_mixture_replace],
max_use_ratio=aesara.config.optdb__max_use_ratio,
),
0,
"basic",
"mixture",
)
18 changes: 6 additions & 12 deletions aeppl/rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,7 @@ def incsubtensor_rv_replace(fgraph, node):

logprob_rewrites_db = SequenceDB()
logprob_rewrites_db.name = "logprob_rewrites_db"
logprob_rewrites_db.register(
"pre-canonicalize", optdb.query("+canonicalize"), -10, "basic"
)
logprob_rewrites_db.register("pre-canonicalize", optdb.query("+canonicalize"), "basic")

# These rewrites convert un-measurable variables into their measurable forms,
# but they need to be reapplied, because some of the measurable forms require
Expand All @@ -229,22 +227,18 @@ def incsubtensor_rv_replace(fgraph, node):
measurable_ir_rewrites_db.name = "measurable_ir_rewrites_db"

logprob_rewrites_db.register(
"measurable_ir_rewrites", measurable_ir_rewrites_db, -10, "basic"
"measurable_ir_rewrites", measurable_ir_rewrites_db, "basic"
)

# These rewrites push random/measurable variables "down", making them closer to
# (or eventually) the graph outputs. Often this is done by lifting other `Op`s
# "up" through the random/measurable variables and into their inputs.
measurable_ir_rewrites_db.register("subtensor_lift", local_subtensor_rv_lift, "basic")
measurable_ir_rewrites_db.register(
"subtensor_lift", local_subtensor_rv_lift, -5, "basic"
)
measurable_ir_rewrites_db.register(
"incsubtensor_lift", incsubtensor_rv_replace, -5, "basic"
"incsubtensor_lift", incsubtensor_rv_replace, "basic"
)

logprob_rewrites_db.register(
"post-canonicalize", optdb.query("+canonicalize"), 10, "basic"
)
logprob_rewrites_db.register("post-canonicalize", optdb.query("+canonicalize"), "basic")


def construct_ir_fgraph(
Expand Down Expand Up @@ -326,6 +320,6 @@ def construct_ir_fgraph(
new_to_old = tuple(
(v, k) for k, v in rv_remapper.measurable_conversions.items()
)
fgraph.replace_all(new_to_old)
fgraph.replace_all(new_to_old, reason="construct_ir_fgraph")

return fgraph, rv_values, memo
6 changes: 2 additions & 4 deletions aeppl/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,19 +513,17 @@ def _get_measurable_outputs_MeasurableScan(op, node):
# out2in(
# add_opts_to_inner_graphs, name="add_opts_to_inner_graphs", ignore_newtrees=True
# ),
-100,
"basic",
"scan",
)

measurable_ir_rewrites_db.register(
"find_measurable_scans",
find_measurable_scans,
0,
"basic",
"scan",
)

# Add scan canonicalizations that aren't in the canonicalization DB
logprob_rewrites_db.register("scan_eqopt1", scan_eqopt1, -9, "basic", "scan")
logprob_rewrites_db.register("scan_eqopt2", scan_eqopt2, -9, "basic", "scan")
logprob_rewrites_db.register("scan_eqopt1", scan_eqopt1, "basic", "scan")
logprob_rewrites_db.register("scan_eqopt2", scan_eqopt2, "basic", "scan")
7 changes: 3 additions & 4 deletions aeppl/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,25 +273,24 @@ def find_measurable_dimshuffles(fgraph, node) -> Optional[List[MeasurableDimShuf


measurable_ir_rewrites_db.register(
"dimshuffle_lift", local_dimshuffle_rv_lift, -5, "basic", "tensor"
"dimshuffle_lift", local_dimshuffle_rv_lift, "basic", "tensor"
)


# We register this later than `dimshuffle_lift` so that it is only applied as a fallback
measurable_ir_rewrites_db.register(
"find_measurable_dimshuffles", find_measurable_dimshuffles, 0, "basic", "tensor"
"find_measurable_dimshuffles", find_measurable_dimshuffles, "basic", "tensor"
)


measurable_ir_rewrites_db.register(
"broadcast_to_lift", naive_bcast_rv_lift, -5, "basic", "tensor"
"broadcast_to_lift", naive_bcast_rv_lift, "basic", "tensor"
)


measurable_ir_rewrites_db.register(
"find_measurable_stacks",
find_measurable_stacks,
0,
"basic",
"tensor",
)
Loading