diff --git a/aemcmc/transforms.py b/aemcmc/transforms.py new file mode 100644 index 0000000..0121c21 --- /dev/null +++ b/aemcmc/transforms.py @@ -0,0 +1,75 @@ +import aesara.tensor as at +from etuples import etuple, etuplize +from kanren import eq, lall +from kanren.facts import Relation, fact +from unification import var + +location_scale_family = Relation("location-scale-family") +fact(location_scale_family, at.random.cauchy) +fact(location_scale_family, at.random.gumbel) +fact(location_scale_family, at.random.laplace) +fact(location_scale_family, at.random.logistic) +fact(location_scale_family, at.random.normal) + + +def location_scale_transform(in_expr, out_expr): + r"""Produce a goal that represents the action of lifting and sinking + the scale and location parameters of distributions in the location-scale + family. + + For instance + + .. math:: + + \begin{equation*} + Y \sim \operatorname{Normal}(\mu, \sigma) + \end{equation*} + + can also be written + + .. math:: + + \begin{align*} + \epsilon &\sim \operatorname{Normal}(0, 1)\\ + Y = \mu + \sigma\,\epsilon + \end{align*} + + Parameters + ---------- + in_expr + An expression that represents a random variable whose distribution belongs + to the location-scale family. + out_expr + An expression for the non-centered representation of this random variable. + + """ + + # Centered representation + rng_lv, size_lv, type_idx_lv = var(), var(), var() + mu_lv, sd_lv = var(), var() + distribution_lv = var() + centered_et = etuple(distribution_lv, rng_lv, size_lv, type_idx_lv, mu_lv, sd_lv) + + # Non-centered representation + noncentered_et = etuple( + etuplize(at.add), + mu_lv, + etuple( + etuplize(at.mul), + sd_lv, + etuple( + distribution_lv, + 0.0, + 1.0, + rng=rng_lv, + size=size_lv, + dtype=type_idx_lv, + ), + ), + ) + + return lall( + eq(in_expr, centered_et), + eq(out_expr, noncentered_et), + location_scale_family(distribution_lv), + ) diff --git a/tests/test_transforms.py b/tests/test_transforms.py new file mode 100644 index 0000000..04c16f3 --- /dev/null +++ b/tests/test_transforms.py @@ -0,0 +1,28 @@ +import aesara.tensor as at +from aesara.graph.fg import FunctionGraph +from aesara.graph.kanren import KanrenRelationSub +from aesara.graph.opt import EquilibriumOptimizer +from aesara.graph.opt_utils import optimize_graph + +from aemcmc.transforms import location_scale_transform + + +def test_normal_scale_loc_transform(): + """""" + + srng = at.random.RandomStream(0) + mu_rv = srng.normal(0, 1) + sigma_rv = srng.halfcauchy(1) + Y_rv = srng.normal(mu_rv, sigma_rv) + + fgraph = FunctionGraph(outputs=[Y_rv], clone=False) + + location_scale_opt = EquilibriumOptimizer( + [KanrenRelationSub(location_scale_transform)], max_use_ratio=10 + ) + res = optimize_graph(fgraph, include=[], custom_opt=location_scale_opt, clone=False) + + # Make sure that Y_rv gets replaced with an addition + assert res.owner.op == at.add + rhs = res.owner.inputs[1].owner.inputs[0] + assert rhs.owner.op == at.mul