-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #42 from gpauloski/hyperparams
Hyperparam callable improvements
- Loading branch information
Showing
7 changed files
with
189 additions
and
50 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
"""Common hyperparameter schedules.""" | ||
from __future__ import annotations | ||
|
||
from typing import Callable | ||
|
||
|
||
def exp_decay_factor_averaging( | ||
min_value: float = 0.95, | ||
) -> Callable[[int], float]: | ||
"""Exponentially decaying factor averaging schedule. | ||
Implements the running average estimate strategy for the Kronecker factors | ||
A and G from "Optimizing Neural Networks with Kronecker-factored | ||
Approximate Curvature" (Martens et al., 2015). | ||
The running average weight e at K-FAC step k is min(1 - 1/k, min_value) | ||
where the min_value is 0.95 by default. | ||
Args: | ||
min_value (float): minimum value for the running average weight. | ||
Returns: | ||
callable that takes an integer value for the current K-FAC step and | ||
returns a float value for the running average weight. This callable | ||
can be passed as the value of `factor_decay` to instances of | ||
`kfac.base_preconditioner.BaseKFACPreconditioner`. Note: that if the | ||
current step is 0, 1 / k is undefined so k = 1 will be used, | ||
and if the current step is negative, a ValueError will be raised. | ||
Raises: | ||
ValueError: | ||
if `min_value` is less than or equal to zero. | ||
""" | ||
if min_value <= 0: | ||
raise ValueError('min_value must be greater than 0') | ||
|
||
def _factor_weight(step: int) -> float: | ||
if step < 0: | ||
raise ValueError(f'step value cannot be negative. Got {step=}.') | ||
if step == 0: | ||
step = 1 | ||
return min(1 - (1 / step), min_value) | ||
|
||
return _factor_weight |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
"""Unit tests for kfac/hyperparams.py.""" | ||
from __future__ import annotations | ||
|
||
import pytest | ||
|
||
from kfac.hyperparams import exp_decay_factor_averaging | ||
|
||
|
||
def test_exp_decay_factor_averaging_types() -> None: | ||
"""Test types and exceptions of exp_decay_factor_averaging().""" | ||
assert callable(exp_decay_factor_averaging()) | ||
assert isinstance(exp_decay_factor_averaging()(1), float) | ||
with pytest.raises(ValueError): | ||
exp_decay_factor_averaging(0) | ||
with pytest.raises(ValueError): | ||
exp_decay_factor_averaging(-1) | ||
with pytest.raises(ValueError): | ||
exp_decay_factor_averaging()(-1) | ||
|
||
|
||
def test_exp_decay_factor_averaging_non_decreasing() -> None: | ||
"""Test exp_decay_factor_averaging() produces non decreasing values.""" | ||
func = exp_decay_factor_averaging() | ||
values = [func(step) for step in range(1000)] | ||
assert all(a <= b for a, b in zip(values, values[1:])) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
'min_value,values', | ||
( | ||
( | ||
0.95, | ||
[(0, 0), (1, 0), (5, 0.8), (10, 0.9), (100, 0.95), (1000, 0.95)], | ||
), | ||
(0.1, [(1, 0), (10, 0.1), (100, 0.1), (1000, 0.1)]), | ||
(1, [(1, 0), (10, 0.9), (100, 0.99)]), | ||
), | ||
) | ||
def test_exp_decay_factor_averaging_values( | ||
min_value: float, | ||
values: list[tuple[int, float]], | ||
) -> None: | ||
"""Test exp_decay_factor_averaging() input/outputs.""" | ||
func = exp_decay_factor_averaging(min_value) | ||
for step, expected_value in values: | ||
assert func(step) == expected_value |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters