diff --git a/aemcmc/transforms.py b/aemcmc/transforms.py new file mode 100644 index 0000000..aed2ed8 --- /dev/null +++ b/aemcmc/transforms.py @@ -0,0 +1,75 @@ +import aesara.tensor as at +from etuples import etuple, etuplize +from kanren import eq, lall +from kanren.facts import Relation, fact +from unification import var + +location_scale_family = Relation("location-scale-family") +fact(location_scale_family, at.random.cauchy) +fact(location_scale_family, at.random.gumbel) +fact(location_scale_family, at.random.laplace) +fact(location_scale_family, at.random.logistic) +fact(location_scale_family, at.random.normal) + + +def location_scale_transform(in_expr, out_expr): + r"""Produce a goal that represents the action of lifting and sinking + the scale and location parameters of distributions in the location-scale + family. + + For instance + + .. math:: + + \begin{equation*} + Y \sim \operatorname{Normal}(\mu, \sigma) + \end{equation*} + + can also be written + + .. math:: + + \begin{align*} + \epsilon &\sim \operatorname{Normal}(0, 1)\\ + Y = \mu + \sigma\,\epsilon + \end{align*} + + Parameters + ---------- + in_expr + An expression that represents a random variable whose distribution belongs + to the location-scale family. + out_expr + An expression for the non-centered representation of this random variable. + + """ + + # Centered representation + rng_lv, size_lv, type_idx_lv = var(), var(), var() + mu_lv, sd_lv = var(), var() + distribution_lv = var() + centered_et = etuple(distribution_lv, rng_lv, size_lv, type_idx_lv, mu_lv, sd_lv) + + # Non-centered representation + noncentered_et = etuple( + etuplize(at.add), + mu_lv, + etuple( + etuplize(at.mul), + sd_lv, + etuple( + distribution_lv, + 0, + 1, + rng=rng_lv, + size=size_lv, + dtype=type_idx_lv, + ), + ), + ) + + return lall( + eq(in_expr, centered_et), + eq(out_expr, noncentered_et), + location_scale_family(distribution_lv), + ) diff --git a/tests/test_transforms.py b/tests/test_transforms.py new file mode 100644 index 0000000..2e3786f --- /dev/null +++ b/tests/test_transforms.py @@ -0,0 +1,45 @@ +import aesara.tensor as at +from aesara.graph.fg import FunctionGraph +from aesara.graph.kanren import KanrenRelationSub + +from aemcmc.transforms import location_scale_transform + + +def test_normal_scale_loc_transform_lift(): + """ "Lift the loc and scale parameters""" + + srng = at.random.RandomStream(0) + mu_rv = srng.halfnormal(1.0) + sigma_rv = srng.halfcauchy(1) + Y_rv = srng.normal(mu_rv, sigma_rv) + + fgraph = FunctionGraph(outputs=[Y_rv], clone=False) + res = KanrenRelationSub(location_scale_transform).transform( + fgraph, fgraph.outputs[0].owner + )[0] + + # Make sure that Y_rv gets replaced with an addition + assert res.owner.op == at.add + lhs = res.owner.inputs[0] + assert isinstance(lhs.owner.op, type(at.random.halfnormal)) + rhs = res.owner.inputs[1] + assert rhs.owner.op == at.mul + assert isinstance(rhs.owner.inputs[0].owner.op, type(at.random.halfcauchy)) + assert isinstance(rhs.owner.inputs[1].owner.op, type(at.random.normal)) + + +def test_normal_scale_loc_transform_sink(): + """Sink the loc and scale parameters.""" + + srng = at.random.RandomStream(0) + mu_rv = srng.halfnormal(1.0) + sigma_rv = srng.halfcauchy(1) + std_normal_rv = srng.normal(0, 1) + Y_at = mu_rv + sigma_rv * std_normal_rv + + fgraph = FunctionGraph(outputs=[Y_at], clone=False) + res = KanrenRelationSub(lambda x, y: location_scale_transform(y, x)).transform( + fgraph, fgraph.outputs[0].owner + )[0] + + assert isinstance(res.owner.op, type(at.random.normal))