Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add relation for recentering and rescaling #31

Merged
merged 2 commits into from
Jul 7, 2022

Conversation

rlouf
Copy link
Member

@rlouf rlouf commented Apr 24, 2022

In this PR we add a relation that represents loc-scale transformations (the relation between the "centered" and "non-centered" parametrization). The performance of some sampling algorithms (eg HMC) can be affected by the parametrization of the model provided by the user (see this example).

We would like aemcmc to be aware of the different parametrizations so users don't have to worry of these implementation details.

  • Create a relation for the normal loc-scale transformation
  • Add test with multiple, nested substitutions
  • Generalize to several distributions in the location-scale family

I see that the Symbolic-PyMC implementation excludes the cases where loc=0 and sd=1 but I am not sure that it is necessary since the relation still holds in this case (even though pointless but algebraic simplification after transformation should take care of this).

Note : We should implement the location and scale transformations separately: distributions like the exponential distributions are part of the scale but not location-scale family.

Closes #5

@rlouf rlouf changed the base branch from scale-loc-transform to main April 24, 2022 12:51
@rlouf rlouf added enhancement New feature or request miniKanren labels Apr 25, 2022
@rlouf rlouf self-assigned this Apr 25, 2022
@rlouf rlouf force-pushed the scale-loc-transform branch from 0bd0d22 to 9d24e90 Compare April 25, 2022 07:49
@codecov
Copy link

codecov bot commented Apr 25, 2022

Codecov Report

Merging #31 (0da0387) into main (5287dd8) will not change coverage.
The diff coverage is 100.00%.

@@            Coverage Diff            @@
##              main       #31   +/-   ##
=========================================
  Coverage   100.00%   100.00%           
=========================================
  Files            4         5    +1     
  Lines          241       313   +72     
  Branches        19        20    +1     
=========================================
+ Hits           241       313   +72     
Impacted Files Coverage Δ
aemcmc/transforms.py 100.00% <100.00%> (ø)
aemcmc/gibbs.py 100.00% <0.00%> (ø)

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 5287dd8...0da0387. Read the comment docs.

@rlouf rlouf force-pushed the scale-loc-transform branch 2 times, most recently from d64bf4b to 7fddeca Compare June 7, 2022 10:53
@rlouf
Copy link
Member Author

rlouf commented Jun 7, 2022

I tried to use KanrenRelationSub, but I get the following exception when running the test:

self = <aesara.graph.opt.EquilibriumOptimizer object at 0x7f6d2bbd3880>
fgraph = FunctionGraph(normal_rv{0, (0, 0), floatX, False}(RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F6D2BB1B140>),...Generator(PCG64) at 0x7F6D2BB1AC00>), TensorConstant{[]}, TensorConstant{11}, TensorConstant{1}, TensorConstant{1.0})))
node = normal_rv{0, (0, 0), floatX, False}(RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F6D2BB1B140>), TensorConstant{[]}, TensorConstant{11}, normal_rv{0, (0, 0), floatX, False}.out, halfcauchy_rv{0, (0, 0), floatX, False}.out)
lopt = <aesara.graph.kanren.KanrenRelationSub object at 0x7f6d2bbd0df0>

    def process_node(self, fgraph, node, lopt=None):
        r"""Apply `lopt` to `node`.
    
        The :meth:`lopt.transform` method will return either ``False`` or a
        list of `Variable`\s that are intended to replace :attr:`node.outputs`.
    
        If the `fgraph` accepts the replacement, then the optimization is
        successful, and this function returns ``True``.
    
        If there are no replacement candidates or the `fgraph` rejects the
        replacements, this function returns ``False``.
    
        Parameters
        ----------
        fgraph :
            A `FunctionGraph`.
        node :
            An `Apply` instance in `fgraph`
        lopt :
            A `LocalOptimizer` instance that may have a better idea for
            how to compute node's outputs.
    
        Returns
        -------
        bool
            ``True`` iff the `node`'s outputs were replaced in the `fgraph`.
    
        """
        lopt = lopt or self.local_opt
        try:
            replacements = lopt.transform(fgraph, node)
        except Exception as e:
            if self.failure_callback is not None:
                self.failure_callback(
                    e, self, [(x, None) for x in node.outputs], lopt, node
                )
                return False
            else:
                raise
        if replacements is False or replacements is None:
            return False
        old_vars = node.outputs
        remove = []
        if isinstance(replacements, dict):
            if "remove" in replacements:
                remove = replacements.pop("remove")
            old_vars = list(replacements.keys())
            replacements = list(replacements.values())
        elif not isinstance(replacements, (tuple, list)):
            raise TypeError(
                f"Local optimizer {lopt} gave wrong type of replacement. "
                f"Expected list or tuple; got {replacements}"
            )
        if len(old_vars) != len(replacements):
>           raise ValueError(
                f"Local optimizer {lopt} gave wrong number of replacements"
            )
E           ValueError: Local optimizer <aesara.graph.kanren.KanrenRelationSub object at 0x7f6d2bbd0df0> gave wrong number of replacements

The current value of the node is

node
# normal_rv{0, (0, 0), floatX, False}(RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F6D2BB1B140>), TensorConstant{[]}, TensorConstant{11}, normal_rv{0, (0, 0), floatX, False}.out, halfcauchy_rv{0, (0, 0), floatX, False}.out)

kanren seems to have successully matched the input expression with its non-centered equivalent:

replacements
# [Elemwise{add,no_inplace}.0]
replacements[0].owner
# Elemwise{add,no_inplace}(normal_rv{0, (0, 0), floatX, False}.out, Elemwise{mul,no_inplace}.0)
replacements[0].owner.inputs
# [normal_rv{0, (0, 0), floatX, False}.out, Elemwise{mul,no_inplace}.0]
replacements[0].owner.inputs[1].owner
# Elemwise{mul,no_inplace}(halfcauchy_rv{0, (0, 0), floatX, False}.out, normal_rv{0, (0, 0), floatX, False}.out)

So the issue seems to come from the fact that RandomVariables outputs (contained in old_vars in the function that raises the exception) is a tuple.

It looks like this is something that should be fixed in aesara ? I understand why the rng is currently part of the RandomVariable Op's outputs, but is there another way to deal with this issue than checking the type in EquilibriumOptimizer?

fact(location_scale_family, at.random.laplace)
fact(location_scale_family, at.random.logistic)
fact(location_scale_family, at.random.normal)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this list meant to be exhaustive? If so then I'd add the uniform and t distributions as well.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just added the ones on the top of my head when I wrote the code, but good point

Copy link
Member

@brandonwillard brandonwillard Jun 7, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't forget, a lot of these distributions are given by more general distribution(s) (e.g. generalized hyperbolic distributions), so we could simplify the recentering/rescaling rewrite implementations via the use of intermediate rewrites that convert them to their generalized distribution(s).

@rlouf rlouf force-pushed the scale-loc-transform branch from 7fddeca to ff705b8 Compare June 16, 2022 10:24
@rlouf
Copy link
Member Author

rlouf commented Jun 16, 2022

I fixed the problem related to the fact that RandomVariable Ops have multiple outputs, and the substitution works in the centered -> non-centered direction.

It does not work yet in the other direction, I added a test that currently fails. This may be another occurence of the issue we observed in #29, in which case I will mark this test as xfail, merge and fix the issue upstream in aesara.

@rlouf rlouf force-pushed the scale-loc-transform branch 2 times, most recently from ef61402 to a9379aa Compare June 16, 2022 12:51
@rlouf
Copy link
Member Author

rlouf commented Jun 16, 2022

The problem indeed comes from what we observed in #29. I'm opening an issue on aesara, we can merge this with the test marked as xfail.

@rlouf rlouf force-pushed the scale-loc-transform branch from a9379aa to 4fcf34d Compare June 16, 2022 13:40
tests/test_transforms.py Outdated Show resolved Hide resolved
tests/test_transforms.py Outdated Show resolved Hide resolved
aemcmc/transforms.py Outdated Show resolved Hide resolved
aemcmc/transforms.py Outdated Show resolved Hide resolved
aemcmc/transforms.py Outdated Show resolved Hide resolved
@rlouf rlouf force-pushed the scale-loc-transform branch from 6aad4d7 to 0da0387 Compare June 23, 2022 11:54
@rlouf rlouf marked this pull request as ready for review June 23, 2022 12:03
@rlouf rlouf requested a review from brandonwillard June 23, 2022 12:03
@rlouf
Copy link
Member Author

rlouf commented Jul 2, 2022

@brandonwillard this is ready to merge. I will open a PR to make sure that the backward transformation works here and for the beta-binomial transformation once aesara-devs/aesara#1002 has been addressed.

@rlouf
Copy link
Member Author

rlouf commented Jul 5, 2022

We can remove the xfail mark once aesara-devs/aesara#1002 is merged.

@brandonwillard brandonwillard merged commit f43519b into aesara-devs:main Jul 7, 2022
@rlouf rlouf deleted the scale-loc-transform branch August 9, 2022 02:09
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add support for recentering and rescaling
3 participants