diff --git a/examples/cnn_utils/optimizers.py b/examples/cnn_utils/optimizers.py
index 985f7d31..1a549d18 100644
--- a/examples/cnn_utils/optimizers.py
+++ b/examples/cnn_utils/optimizers.py
@@ -59,7 +59,7 @@ def get_optimizer(
             damping=args.kfac_damping,
             factor_decay=args.kfac_factor_decay,
             kl_clip=args.kfac_kl_clip,
-            lr=lambda: optimizer.param_groups[0]['lr'],
+            lr=lambda x: optimizer.param_groups[0]['lr'],
             accumulation_steps=args.batches_per_allreduce,
             allreduce_bucket_cap_mb=25,
             colocate_factors=args.kfac_colocate_factors,
diff --git a/kfac/base_preconditioner.py b/kfac/base_preconditioner.py
index 2d3cc89d..2ae6c5d3 100644
--- a/kfac/base_preconditioner.py
+++ b/kfac/base_preconditioner.py
@@ -32,12 +32,12 @@ def __init__(
         assignment: WorkAssignment,
         tdc: TorchDistributedCommunicator,
         # KFAC hyperparameters
-        factor_update_steps: Callable[[], int] | int = 1,
-        inv_update_steps: Callable[[], int] | int = 1,
-        damping: Callable[[], float] | float = 0.001,
-        factor_decay: Callable[[], float] | float = 0.95,
-        kl_clip: Callable[[], float] | float = 0.001,
-        lr: Callable[[], float] | float = 0.1,
+        factor_update_steps: Callable[[int], int] | int = 1,
+        inv_update_steps: Callable[[int], int] | int = 1,
+        damping: Callable[[int], float] | float = 0.001,
+        factor_decay: Callable[[int], float] | float = 0.95,
+        kl_clip: Callable[[int], float] | float = 0.001,
+        lr: Callable[[int], float] | float = 0.1,
         # Other
         accumulation_steps: int = 1,
         update_factors_in_hook: bool = True,
@@ -54,21 +54,21 @@ def __init__(
             tdc (TorchDistributedCommunicator): communicator instance.
             factor_update_steps (Callable, int): steps between computing and
                 updating the running average of the Kronecker factors or
-                callable that returns the value.
+                callable that takes the K-FAC step and returns the value.
             inv_update_steps (Callble, int): steps between recomputing and
                 communicating the second-order information or callable that
-                returns the value.
+                takes the K-FAC step and returns the value.
             damping (Callable, float): Tikhonov damping parameter or a callable
-                that will return the damping parameter as a float
-                (default: 0.001).
+                that takes the K-FAC step and returns the damping parameter
+                as a float (default: 0.001).
             factor_decay (Callable, float): running average coefficient for
-                Kronecker factors or callable that will return the factor_decay
-                (default: 0.95).
+                Kronecker factors or callable that takes the K-FAC step and
+                returns the factor_decay (default: 0.95).
             kl_clip (Callable, float): clipping parameter for gradient scaling
-                or a callable that returns a float. If None, no
-                scaling/clipping will be applied (default: 0.001).
-            lr (Callable, float): learning rate or callable that will return
-                learning rate (default: 0.1).
+                or a callable that takes the K-FAC step and returns a float.
+                If None, no scaling/clipping will be applied (default: 0.001).
+            lr (Callable, float): learning rate or callable that takes the
+                K-FAC step and returns learning rate (default: 0.1).
             accumulation_steps (int): number of forward/backward passes
                 between optimization steps (default: 1).
             update_factors_in_hook (bool): If True, running average of factors
@@ -157,13 +157,17 @@ def __repr__(self) -> str:
     @property
     def damping(self) -> float:
         """Get damping value."""
-        return self._damping() if callable(self._damping) else self._damping
+        return (
+            self._damping(self.steps)
+            if callable(self._damping)
+            else self._damping
+        )
 
     @property
     def factor_decay(self) -> float:
         """Get factor decay value."""
         return (
-            self._factor_decay()
+            self._factor_decay(self.steps)
             if callable(self._factor_decay)
             else self._factor_decay
         )
@@ -171,18 +175,22 @@ def factor_decay(self) -> float:
     @property
     def kl_clip(self) -> float:
         """Get kl clip value."""
-        return self._kl_clip() if callable(self._kl_clip) else self._kl_clip
+        return (
+            self._kl_clip(self.steps)
+            if callable(self._kl_clip)
+            else self._kl_clip
+        )
 
     @property
     def lr(self) -> float:
         """Get lr value."""
-        return self._lr() if callable(self._lr) else self._lr
+        return self._lr(self.steps) if callable(self._lr) else self._lr
 
     @property
     def factor_update_steps(self) -> int:
         """Get factor update steps."""
         return (
-            self._factor_update_steps()
+            self._factor_update_steps(self.steps)
             if callable(self._factor_update_steps)
             else self._factor_update_steps
         )
@@ -191,7 +199,7 @@ def inv_update_steps(self) -> int:
         """Get inverse update steps."""
         return (
-            self._inv_update_steps()
+            self._inv_update_steps(self.steps)
             if callable(self._inv_update_steps)
             else self._inv_update_steps
         )
diff --git a/kfac/hyperparams.py b/kfac/hyperparams.py
new file mode 100644
index 00000000..2121fbdb
--- /dev/null
+++ b/kfac/hyperparams.py
@@ -0,0 +1,44 @@
+"""Common hyperparameter schedules."""
+from __future__ import annotations
+
+from typing import Callable
+
+
+def exp_decay_factor_averaging(
+    min_value: float = 0.95,
+) -> Callable[[int], float]:
+    """Exponentially decaying factor averaging schedule.
+
+    Implements the running average estimate strategy for the Kronecker factors
+    A and G from "Optimizing Neural Networks with Kronecker-factored
+    Approximate Curvature" (Martens et al., 2015).
+
+    The running average weight e at K-FAC step k is min(1 - 1/k, min_value)
+    where the min_value is 0.95 by default.
+
+    Args:
+        min_value (float): minimum value for the running average weight.
+
+    Returns:
+        callable that takes an integer value for the current K-FAC step and
+        returns a float value for the running average weight. This callable
+        can be passed as the value of `factor_decay` to instances of
+        `kfac.base_preconditioner.BaseKFACPreconditioner`. Note: that if the
+        current step is 0, 1 / k is undefined so k = 1 will be used,
+        and if the current step is negative, a ValueError will be raised.
+
+    Raises:
+        ValueError:
+            if `min_value` is less than or equal to zero.
+    """
+    if min_value <= 0:
+        raise ValueError('min_value must be greater than 0')
+
+    def _factor_weight(step: int) -> float:
+        if step < 0:
+            raise ValueError(f'step value cannot be negative. Got {step=}.')
+        if step == 0:
+            step = 1
+        return min(1 - (1 / step), min_value)
+
+    return _factor_weight
diff --git a/kfac/preconditioner.py b/kfac/preconditioner.py
index edeb90c8..5a062d55 100644
--- a/kfac/preconditioner.py
+++ b/kfac/preconditioner.py
@@ -49,13 +49,13 @@ def __init__(
         self,
         model: torch.nn.Module,
         *,
-        factor_update_steps: Callable[[], int] | int = 1,
-        inv_update_steps: Callable[[], int] | int = 1,
+        factor_update_steps: Callable[[int], int] | int = 1,
+        inv_update_steps: Callable[[int], int] | int = 1,
         # KFAC hyperparameters
-        damping: Callable[[], float] | float = 0.001,
-        factor_decay: Callable[[], float] | float = 0.95,
-        kl_clip: Callable[[], float] | float = 0.001,
-        lr: Callable[[], float] | float = 0.1,
+        damping: Callable[[int], float] | float = 0.001,
+        factor_decay: Callable[[int], float] | float = 0.95,
+        kl_clip: Callable[[int], float] | float = 0.001,
+        lr: Callable[[int], float] | float = 0.1,
         # Distribution strategy
         accumulation_steps: int = 1,
         allreduce_bucket_cap_mb: float = 25.0,
@@ -85,21 +85,21 @@ def __init__(
             model (torch.nn.Module): model to precondition with KFAC.
             factor_update_steps (Callable, int): steps between computing and
                 updating the running average of the Kronecker factors or
-                callable that returns the value.
+                callable that takes the K-FAC step and returns the value.
             inv_update_steps (Callble, int): steps between recomputing and
                 communicating the second-order information or callable that
-                returns the value.
+                takes the K-FAC step and returns the value.
             damping (Callable, float): Tikhonov damping parameter or a callable
-                that will return the damping parameter as a float
-                (default: 0.001).
+                that takes the K-FAC step and returns the damping parameter
+                as a float (default: 0.001).
             factor_decay (Callable, float): running average coefficient for
-                Kronecker factors or callable that will return the factor_decay
-                (default: 0.95).
+                Kronecker factors or callable that takes the K-FAC step and
+                returns the factor_decay (default: 0.95).
             kl_clip (Callable, float): clipping parameter for gradient scaling
-                or a callable that returns a float. If None, no
-                scaling/clipping will be applied (default: 0.001).
-            lr (Callable, float): learning rate or callable that will return
-                learning rate (default: 0.1).
+                or a callable that takes the K-FAC step and returns a float.
+                If None, no scaling/clipping will be applied (default: 0.001).
+            lr (Callable, float): learning rate or callable that takes the
+                K-FAC step and returns learning rate (default: 0.1).
             accumulation_steps (int): number of forward/backward passes
                 between optimization steps (default: 1).
             allreduce_bucket_cap_mb (float): maximum size in megabytes for
diff --git a/tests/base_preconditioner_test.py b/tests/base_preconditioner_test.py
index a7188d3a..cf3c2230 100644
--- a/tests/base_preconditioner_test.py
+++ b/tests/base_preconditioner_test.py
@@ -145,10 +145,10 @@ def test_base_preconditioner_init() -> None:
         layers=example_layers(),
         assignment=LazyAssignment(),
         tdc=TorchDistributedCommunicator(),
-        damping=lambda: damping,
-        factor_decay=lambda: factor_decay,
-        kl_clip=lambda: kl_clip,
-        lr=lambda: lr,
+        damping=lambda x: damping,
+        factor_decay=lambda x: factor_decay,
+        kl_clip=lambda x: kl_clip,
+        lr=lambda x: lr,
     )
     assert preconditioner.damping == damping
     assert preconditioner.factor_decay == factor_decay
@@ -220,12 +220,12 @@ def test_empty_state_dict() -> None:
         layers=example_layers(),
         assignment=LazyAssignment(),
         tdc=TorchDistributedCommunicator(),
-        factor_update_steps=lambda: 1,
-        inv_update_steps=lambda: 3,
-        damping=lambda: 5,
-        factor_decay=lambda: 0.7,
-        kl_clip=lambda: 11,
-        lr=lambda: 13,
+        factor_update_steps=lambda x: 1,
+        inv_update_steps=lambda x: 3,
+        damping=lambda x: 5,
+        factor_decay=lambda x: 0.7,
+        kl_clip=lambda x: 11,
+        lr=lambda x: 13,
     )
     state_dict = p3.state_dict()
     assert 'factor_update_steps' not in state_dict
@@ -381,3 +381,44 @@ def e2e() -> None:
         assert mem == 0
 
     e2e()
+
+
+def test_base_preconditioner_callable_hyperparams() -> None:
+    """Test BaseKFACPreconditioner supports callable hyperparams."""
+    p = BaseKFACPreconditioner(
+        example_layers(),
+        assignment=LazyAssignment(),
+        tdc=TorchDistributedCommunicator(),
+        factor_update_steps=lambda x: x * 2,
+        inv_update_steps=lambda x: x * 3,
+        damping=lambda x: x * 5,
+        factor_decay=lambda x: x * 7,
+        kl_clip=lambda x: x * 9,
+    )
+
+    for x in range(0, 10):
+        p._steps = x
+        assert p.factor_update_steps == x * 2
+        assert p.inv_update_steps == x * 3
+        assert p.damping == x * 5
+        assert p.factor_decay == x * 7
+        assert p.kl_clip == x * 9
+
+    p = BaseKFACPreconditioner(
+        example_layers(),
+        assignment=LazyAssignment(),
+        tdc=TorchDistributedCommunicator(),
+        factor_update_steps=lambda x: 2,
+        inv_update_steps=lambda x: 3,
+        damping=lambda x: 5,
+        factor_decay=lambda x: 7,
+        kl_clip=lambda x: 9,
+    )
+
+    for x in range(0, 10):
+        p._steps = x
+        assert p.factor_update_steps == 2
+        assert p.inv_update_steps == 3
+        assert p.damping == 5
+        assert p.factor_decay == 7
+        assert p.kl_clip == 9
diff --git a/tests/hyperparams_test.py b/tests/hyperparams_test.py
new file mode 100644
index 00000000..5144b0d1
--- /dev/null
+++ b/tests/hyperparams_test.py
@@ -0,0 +1,46 @@
+"""Unit tests for kfac/hyperparams.py."""
+from __future__ import annotations
+
+import pytest
+
+from kfac.hyperparams import exp_decay_factor_averaging
+
+
+def test_exp_decay_factor_averaging_types() -> None:
+    """Test types and exceptions of exp_decay_factor_averaging()."""
+    assert callable(exp_decay_factor_averaging())
+    assert isinstance(exp_decay_factor_averaging()(1), float)
+    with pytest.raises(ValueError):
+        exp_decay_factor_averaging(0)
+    with pytest.raises(ValueError):
+        exp_decay_factor_averaging(-1)
+    with pytest.raises(ValueError):
+        exp_decay_factor_averaging()(-1)
+
+
+def test_exp_decay_factor_averaging_non_decreasing() -> None:
+    """Test exp_decay_factor_averaging() produces non decreasing values."""
+    func = exp_decay_factor_averaging()
+    values = [func(step) for step in range(1000)]
+    assert all(a <= b for a, b in zip(values, values[1:]))
+
+
+@pytest.mark.parametrize(
+    'min_value,values',
+    (
+        (
+            0.95,
+            [(0, 0), (1, 0), (5, 0.8), (10, 0.9), (100, 0.95), (1000, 0.95)],
+        ),
+        (0.1, [(1, 0), (10, 0.1), (100, 0.1), (1000, 0.1)]),
+        (1, [(1, 0), (10, 0.9), (100, 0.99)]),
+    ),
+)
+def test_exp_decay_factor_averaging_values(
+    min_value: float,
+    values: list[tuple[int, float]],
+) -> None:
+    """Test exp_decay_factor_averaging() input/outputs."""
+    func = exp_decay_factor_averaging(min_value)
+    for step, expected_value in values:
+        assert func(step) == expected_value
diff --git a/tests/integration/mnist_integration_test.py b/tests/integration/mnist_integration_test.py
index c293f39e..22acc1d4 100644
--- a/tests/integration/mnist_integration_test.py
+++ b/tests/integration/mnist_integration_test.py
@@ -124,7 +124,7 @@ def train_and_eval(precondition: bool, epochs: int) -> float:
             model,
             factor_update_steps=10,
             inv_update_steps=100,
-            lr=lambda: optimizer.param_groups[0]['lr'],
+            lr=lambda x: optimizer.param_groups[0]['lr'],
             update_factors_in_hook=False,
         )
     else:
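
Usage note (not part of the patch): a minimal sketch of how the new step-aware callables and `kfac.hyperparams.exp_decay_factor_averaging` fit together in a training script. The `KFACPreconditioner` class name and the toy `model`/`optimizer` objects are assumptions made for illustration; the only behavior taken from this diff is that hyperparameter callables now receive the current K-FAC step as their single argument.

```python
# Hypothetical wiring example; assumes the public preconditioner class defined
# in kfac/preconditioner.py is named KFACPreconditioner. Depending on the
# setup, torch.distributed may need to be initialized before constructing it.
import torch

from kfac.hyperparams import exp_decay_factor_averaging
from kfac.preconditioner import KFACPreconditioner

model = torch.nn.Linear(10, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

preconditioner = KFACPreconditioner(
    model,
    factor_update_steps=10,
    inv_update_steps=100,
    # Callables are now passed the current K-FAC step, so factor averaging
    # can follow the min(1 - 1/k, 0.95) schedule from Martens et al. (2015).
    factor_decay=exp_decay_factor_averaging(min_value=0.95),
    # The step argument can simply be ignored when the value mirrors external
    # state, e.g. the optimizer's current learning rate.
    lr=lambda step: optimizer.param_groups[0]['lr'],
)
```

Passing the step explicitly keeps schedules as plain functions of training progress rather than relying on hidden counters inside each callable.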