Commit

Merge pull request #42 from gpauloski/hyperparams
Hyperparam callable improvements
gpauloski authored Apr 14, 2022
2 parents 89d74a5 + 8cf9665 commit 77d256e
Showing 7 changed files with 189 additions and 50 deletions.
2 changes: 1 addition & 1 deletion examples/cnn_utils/optimizers.py
@@ -59,7 +59,7 @@ def get_optimizer(
damping=args.kfac_damping,
factor_decay=args.kfac_factor_decay,
kl_clip=args.kfac_kl_clip,
-        lr=lambda: optimizer.param_groups[0]['lr'],
+        lr=lambda x: optimizer.param_groups[0]['lr'],
accumulation_steps=args.batches_per_allreduce,
allreduce_bucket_cap_mb=25,
colocate_factors=args.kfac_colocate_factors,
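
Note the signature change above: hyperparameter callables are now invoked with the current K-FAC step, so this lambda gains a parameter that it simply ignores while continuing to mirror the optimizer's learning rate. A minimal sketch of the pattern in isolation (the toy model and optimizer are hypothetical, not from this repository):

    import torch

    model = torch.nn.Linear(10, 10)  # hypothetical toy model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # The callable receives the K-FAC step but ignores it; it simply
    # reads whatever learning rate the optimizer currently holds.
    def lr(step: int) -> float:
        return optimizer.param_groups[0]['lr']

    optimizer.param_groups[0]['lr'] = 0.01  # e.g., a scheduler adjustment
    assert lr(0) == 0.01  # the callable tracks the live value
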
52 changes: 30 additions & 22 deletions kfac/base_preconditioner.py
@@ -32,12 +32,12 @@ def __init__(
assignment: WorkAssignment,
tdc: TorchDistributedCommunicator,
# KFAC hyperparameters
-        factor_update_steps: Callable[[], int] | int = 1,
-        inv_update_steps: Callable[[], int] | int = 1,
-        damping: Callable[[], float] | float = 0.001,
-        factor_decay: Callable[[], float] | float = 0.95,
-        kl_clip: Callable[[], float] | float = 0.001,
-        lr: Callable[[], float] | float = 0.1,
+        factor_update_steps: Callable[[int], int] | int = 1,
+        inv_update_steps: Callable[[int], int] | int = 1,
+        damping: Callable[[int], float] | float = 0.001,
+        factor_decay: Callable[[int], float] | float = 0.95,
+        kl_clip: Callable[[int], float] | float = 0.001,
+        lr: Callable[[int], float] | float = 0.1,
# Other
accumulation_steps: int = 1,
update_factors_in_hook: bool = True,
@@ -54,21 +54,21 @@ def __init__(
tdc (TorchDistributedCommunicator): communicator instance.
factor_update_steps (Callable, int): steps between computing and
updating the running average of the Kronecker factors or
-            callable that returns the value.
+            callable that takes the K-FAC step and returns the value.
inv_update_steps (Callable, int): steps between recomputing and
communicating the second-order information or callable that
-            returns the value.
+            takes the K-FAC step and returns the value.
damping (Callable, float): Tikhonov damping parameter or a callable
-            that will return the damping parameter as a float
-            (default: 0.001).
+            that takes the K-FAC step and returns the damping parameter
+            as a float (default: 0.001).
factor_decay (Callable, float): running average coefficient for
-            Kronecker factors or callable that will return the factor_decay
-            (default: 0.95).
+            Kronecker factors or callable that takes the K-FAC step and
+            returns the factor_decay (default: 0.95).
kl_clip (Callable, float): clipping parameter for gradient scaling
-            or a callable that returns a float. If None, no
-            scaling/clipping will be applied (default: 0.001).
-        lr (Callable, float): learning rate or callable that will return
-            learning rate (default: 0.1).
+            or a callable that takes the K-FAC step and returns a float.
+            If None, no scaling/clipping will be applied (default: 0.001).
+        lr (Callable, float): learning rate or callable that takes the
+            K-FAC step and returns learning rate (default: 0.1).
accumulation_steps (int): number of forward/backward passes
between optimization steps (default: 1).
update_factors_in_hook (bool): If True, running average of factors
@@ -157,32 +157,40 @@ def __repr__(self) -> str:
@property
def damping(self) -> float:
"""Get damping value."""
-        return self._damping() if callable(self._damping) else self._damping
+        return (
+            self._damping(self.steps)
+            if callable(self._damping)
+            else self._damping
+        )

@property
def factor_decay(self) -> float:
"""Get factor decay value."""
return (
-            self._factor_decay()
+            self._factor_decay(self.steps)
if callable(self._factor_decay)
else self._factor_decay
)

@property
def kl_clip(self) -> float:
"""Get kl clip value."""
-        return self._kl_clip() if callable(self._kl_clip) else self._kl_clip
+        return (
+            self._kl_clip(self.steps)
+            if callable(self._kl_clip)
+            else self._kl_clip
+        )

@property
def lr(self) -> float:
"""Get lr value."""
-        return self._lr() if callable(self._lr) else self._lr
+        return self._lr(self.steps) if callable(self._lr) else self._lr

@property
def factor_update_steps(self) -> int:
"""Get factor update steps."""
return (
-            self._factor_update_steps()
+            self._factor_update_steps(self.steps)
if callable(self._factor_update_steps)
else self._factor_update_steps
)
@@ -191,7 +199,7 @@ def factor_update_steps(self) -> int:
def inv_update_steps(self) -> int:
"""Get inverse update steps."""
return (
-            self._inv_update_steps()
+            self._inv_update_steps(self.steps)
if callable(self._inv_update_steps)
else self._inv_update_steps
)
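
Because each property now calls its hyperparameter with self.steps, step-dependent schedules can be written as plain functions. A hedged sketch of a damping warmup (the breakpoints are illustrative only, not values from this PR):

    def damping_schedule(step: int) -> float:
        """Linearly decay damping from 0.01 to 0.001 over 100 K-FAC steps."""
        if step >= 100:
            return 0.001
        return 0.01 - 0.009 * (step / 100)

    # Passed like a constant; the preconditioner evaluates
    # damping_schedule(self.steps) each time the damping property is read:
    # BaseKFACPreconditioner(layers, assignment=..., tdc=...,
    #                        damping=damping_schedule)
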
44 changes: 44 additions & 0 deletions kfac/hyperparams.py
@@ -0,0 +1,44 @@
"""Common hyperparameter schedules."""
from __future__ import annotations

from typing import Callable


def exp_decay_factor_averaging(
min_value: float = 0.95,
) -> Callable[[int], float]:
"""Exponentially decaying factor averaging schedule.
Implements the running average estimate strategy for the Kronecker factors
A and G from "Optimizing Neural Networks with Kronecker-factored
Approximate Curvature" (Martens et al., 2015).
The running average weight e at K-FAC step k is min(1 - 1/k, min_value)
where the min_value is 0.95 by default.
Args:
min_value (float): minimum value for the running average weight.
Returns:
callable that takes an integer value for the current K-FAC step and
returns a float value for the running average weight. This callable
can be passed as the value of `factor_decay` to instances of
`kfac.base_preconditioner.BaseKFACPreconditioner`. Note: that if the
current step is 0, 1 / k is undefined so k = 1 will be used,
and if the current step is negative, a ValueError will be raised.
Raises:
ValueError:
if `min_value` is less than or equal to zero.
"""
if min_value <= 0:
raise ValueError('min_value must be greater than 0')

def _factor_weight(step: int) -> float:
if step < 0:
raise ValueError(f'step value cannot be negative. Got {step=}.')
if step == 0:
step = 1
return min(1 - (1 / step), min_value)

return _factor_weight
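
As a usage sketch, the returned callable plugs straight into the factor_decay argument; min(1 - 1/k, 0.95) gives 0.0 at step 1, 0.8 at step 5, 0.9 at step 10, and caps at 0.95 from step 20 onward. The toy model is hypothetical, and KFACPreconditioner is assumed to be the public class defined in kfac/preconditioner.py:

    import torch

    from kfac.hyperparams import exp_decay_factor_averaging
    from kfac.preconditioner import KFACPreconditioner  # assumed class name

    model = torch.nn.Linear(10, 10)  # hypothetical toy model

    # Early steps weight fresh factor estimates heavily; later steps settle
    # at the usual 0.95 running-average coefficient.
    preconditioner = KFACPreconditioner(
        model,
        factor_decay=exp_decay_factor_averaging(),
    )
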
32 changes: 16 additions & 16 deletions kfac/preconditioner.py
@@ -49,13 +49,13 @@ def __init__(
self,
model: torch.nn.Module,
*,
-        factor_update_steps: Callable[[], int] | int = 1,
-        inv_update_steps: Callable[[], int] | int = 1,
+        factor_update_steps: Callable[[int], int] | int = 1,
+        inv_update_steps: Callable[[int], int] | int = 1,
# KFAC hyperparameters
-        damping: Callable[[], float] | float = 0.001,
-        factor_decay: Callable[[], float] | float = 0.95,
-        kl_clip: Callable[[], float] | float = 0.001,
-        lr: Callable[[], float] | float = 0.1,
+        damping: Callable[[int], float] | float = 0.001,
+        factor_decay: Callable[[int], float] | float = 0.95,
+        kl_clip: Callable[[int], float] | float = 0.001,
+        lr: Callable[[int], float] | float = 0.1,
# Distribution strategy
accumulation_steps: int = 1,
allreduce_bucket_cap_mb: float = 25.0,
@@ -85,21 +85,21 @@ def __init__(
model (torch.nn.Module): model to precondition with KFAC.
factor_update_steps (Callable, int): steps between computing and
updating the running average of the Kronecker factors or
-            callable that returns the value.
+            callable that takes the K-FAC step and returns the value.
inv_update_steps (Callable, int): steps between recomputing and
communicating the second-order information or callable that
-            returns the value.
+            takes the K-FAC step and returns the value.
damping (Callable, float): Tikhonov damping parameter or a callable
-            that will return the damping parameter as a float
-            (default: 0.001).
+            that takes the K-FAC step and returns the damping parameter
+            as a float (default: 0.001).
factor_decay (Callable, float): running average coefficient for
-            Kronecker factors or callable that will return the factor_decay
-            (default: 0.95).
+            Kronecker factors or callable that takes the K-FAC step and
+            returns the factor_decay (default: 0.95).
kl_clip (Callable, float): clipping parameter for gradient scaling
-            or a callable that returns a float. If None, no
-            scaling/clipping will be applied (default: 0.001).
-        lr (Callable, float): learning rate or callable that will return
-            learning rate (default: 0.1).
+            or a callable that takes the K-FAC step and returns a float.
+            If None, no scaling/clipping will be applied (default: 0.001).
+        lr (Callable, float): learning rate or callable that takes the
+            K-FAC step and returns learning rate (default: 0.1).
accumulation_steps (int): number of forward/backward passes
between optimization steps (default: 1).
allreduce_bucket_cap_mb (float): maximum size in megabytes for
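
A sketch of the wiring this signature enables, mirroring the pattern used in examples/cnn_utils/optimizers.py and the MNIST integration test: the lr callable ignores the K-FAC step and reads the optimizer's current rate, so LR scheduler updates propagate automatically. The toy model and schedule are hypothetical:

    import torch

    from kfac.preconditioner import KFACPreconditioner  # assumed class name

    model = torch.nn.Linear(10, 10)  # hypothetical toy model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30)

    # K-FAC reads the scheduler-adjusted learning rate on every step.
    preconditioner = KFACPreconditioner(
        model,
        lr=lambda step: optimizer.param_groups[0]['lr'],
    )
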
61 changes: 51 additions & 10 deletions tests/base_preconditioner_test.py
@@ -145,10 +145,10 @@ def test_base_preconditioner_init() -> None:
layers=example_layers(),
assignment=LazyAssignment(),
tdc=TorchDistributedCommunicator(),
-        damping=lambda: damping,
-        factor_decay=lambda: factor_decay,
-        kl_clip=lambda: kl_clip,
-        lr=lambda: lr,
+        damping=lambda x: damping,
+        factor_decay=lambda x: factor_decay,
+        kl_clip=lambda x: kl_clip,
+        lr=lambda x: lr,
)
assert preconditioner.damping == damping
assert preconditioner.factor_decay == factor_decay
@@ -220,12 +220,12 @@ def test_empty_state_dict() -> None:
layers=example_layers(),
assignment=LazyAssignment(),
tdc=TorchDistributedCommunicator(),
-        factor_update_steps=lambda: 1,
-        inv_update_steps=lambda: 3,
-        damping=lambda: 5,
-        factor_decay=lambda: 0.7,
-        kl_clip=lambda: 11,
-        lr=lambda: 13,
+        factor_update_steps=lambda x: 1,
+        inv_update_steps=lambda x: 3,
+        damping=lambda x: 5,
+        factor_decay=lambda x: 0.7,
+        kl_clip=lambda x: 11,
+        lr=lambda x: 13,
)
state_dict = p3.state_dict()
assert 'factor_update_steps' not in state_dict
@@ -381,3 +381,44 @@ def e2e() -> None:
assert mem == 0

e2e()


def test_base_preconditioner_callable_hyperparams() -> None:
"""Test BaseKFACPreconditioner supports callable hyperparams."""
p = BaseKFACPreconditioner(
example_layers(),
assignment=LazyAssignment(),
tdc=TorchDistributedCommunicator(),
factor_update_steps=lambda x: x * 2,
inv_update_steps=lambda x: x * 3,
damping=lambda x: x * 5,
factor_decay=lambda x: x * 7,
kl_clip=lambda x: x * 9,
)

for x in range(0, 10):
p._steps = x
assert p.factor_update_steps == x * 2
assert p.inv_update_steps == x * 3
assert p.damping == x * 5
assert p.factor_decay == x * 7
assert p.kl_clip == x * 9

p = BaseKFACPreconditioner(
example_layers(),
assignment=LazyAssignment(),
tdc=TorchDistributedCommunicator(),
factor_update_steps=lambda x: 2,
inv_update_steps=lambda x: 3,
damping=lambda x: 5,
factor_decay=lambda x: 7,
kl_clip=lambda x: 9,
)

for x in range(0, 10):
p._steps = x
assert p.factor_update_steps == 2
assert p.inv_update_steps == 3
assert p.damping == 5
assert p.factor_decay == 7
assert p.kl_clip == 9
46 changes: 46 additions & 0 deletions tests/hyperparams_test.py
@@ -0,0 +1,46 @@
"""Unit tests for kfac/hyperparams.py."""
from __future__ import annotations

import pytest

from kfac.hyperparams import exp_decay_factor_averaging


def test_exp_decay_factor_averaging_types() -> None:
"""Test types and exceptions of exp_decay_factor_averaging()."""
assert callable(exp_decay_factor_averaging())
assert isinstance(exp_decay_factor_averaging()(1), float)
with pytest.raises(ValueError):
exp_decay_factor_averaging(0)
with pytest.raises(ValueError):
exp_decay_factor_averaging(-1)
with pytest.raises(ValueError):
exp_decay_factor_averaging()(-1)


def test_exp_decay_factor_averaging_non_decreasing() -> None:
"""Test exp_decay_factor_averaging() produces non decreasing values."""
func = exp_decay_factor_averaging()
values = [func(step) for step in range(1000)]
assert all(a <= b for a, b in zip(values, values[1:]))


@pytest.mark.parametrize(
'min_value,values',
(
(
0.95,
[(0, 0), (1, 0), (5, 0.8), (10, 0.9), (100, 0.95), (1000, 0.95)],
),
(0.1, [(1, 0), (10, 0.1), (100, 0.1), (1000, 0.1)]),
(1, [(1, 0), (10, 0.9), (100, 0.99)]),
),
)
def test_exp_decay_factor_averaging_values(
min_value: float,
values: list[tuple[int, float]],
) -> None:
"""Test exp_decay_factor_averaging() input/outputs."""
func = exp_decay_factor_averaging(min_value)
for step, expected_value in values:
assert func(step) == expected_value
2 changes: 1 addition & 1 deletion tests/integration/mnist_integration_test.py
@@ -124,7 +124,7 @@ def train_and_eval(precondition: bool, epochs: int) -> float:
model,
factor_update_steps=10,
inv_update_steps=100,
-            lr=lambda: optimizer.param_groups[0]['lr'],
+            lr=lambda x: optimizer.param_groups[0]['lr'],
update_factors_in_hook=False,
)
else:
