From 0bd0d220f7fc99b4a10e5bc4af51c8b2bff15402 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Thu, 21 Apr 2022 18:14:46 +0200 Subject: [PATCH] Add goal for normal loc-scale transformation --- aemcmc/transforms.py | 41 ++++++++++++++++++++++++++++++++++++++++ tests/test_transforms.py | 37 ++++++++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+) create mode 100644 aemcmc/transforms.py create mode 100644 tests/test_transforms.py diff --git a/aemcmc/transforms.py b/aemcmc/transforms.py new file mode 100644 index 0000000..7d943fd --- /dev/null +++ b/aemcmc/transforms.py @@ -0,0 +1,41 @@ +import aesara.tensor as at +from etuples import etuple, etuplize +from kanren import eq, lall +from kanren.facts import Relation, fact +from unification import var + +loc_scale_family = Relation("loc-scale") +fact(loc_scale_family, at.random.normal) + +def normal_scale_loc_goal(in_expr, out_expr): + """Create a relation to lift and sink scale and location parameters of distributions.""" + + # Centered representation + rng_lv, size_lv, type_idx_lv = var(), var(), var() + mu_lv, sd_lv = var(), var() + normal_centered_et = etuple( + etuplize(at.random.normal), rng_lv, size_lv, type_idx_lv, mu_lv, sd_lv + ) + + # Non-centered representation + normal_nc_et = etuple( + etuplize(at.add), + mu_lv, + etuple( + etuplize(at.mul), + sd_lv, + etuple( + etuplize(at.random.normal), + 0.0, + 1.0, + rng=rng_lv, + size=size_lv, + dtype=type_idx_lv, + ), + ), + ) + + return lall( + eq(in_expr, normal_centered_et), + eq(out_expr, normal_nc_et), + ) diff --git a/tests/test_transforms.py b/tests/test_transforms.py new file mode 100644 index 0000000..f0619f8 --- /dev/null +++ b/tests/test_transforms.py @@ -0,0 +1,37 @@ +from functools import partial + +import aesara.tensor as at +from aesara.graph.unify import eval_if_etuple +from kanren import run +from kanren.graph import reduceo, walko +from unification import var + +from aemcmc.transforms import normal_scale_loc_goal + + +def test_normal_scale_loc_transform(): + """""" + + srng = at.random.RandomStream(0) + mu_a_rv = srng.normal(0, 1) + sigma_a_rv = srng.halfcauchy(1) + a_rv = srng.normal(mu_a_rv, sigma_a_rv, size=(10,)) + + mu_b_rv = srng.normal(0, 1) + sigma_b_rv = srng.halfcauchy(1) + b_rv = srng.normal(mu_b_rv, sigma_b_rv, size=(10)) + + mu = a_rv + b_rv + sigma_rv = srng.halfcauchy(5.0) + Y_rv = srng.normal(mu, sigma_rv) + + q_lv = var() + (expr_graph,) = run( + 1, q_lv, walko(partial(reduceo, normal_scale_loc_goal), Y_rv, q_lv) + ) + Y_nc_rv = eval_if_etuple(expr_graph) + + # Make sure that Y_rv gets replaced with an addition + assert Y_nc_rv.owner.op == at.add + rhs = Y_nc_rv.owner.inputs[1].owner.inputs[0] + assert rhs.owner.op == at.mul