Merge pull request #48 from gpauloski/assignment-bug
Fix WorkAssignment usage in BaseKFACPreconditioner
gpauloski authored Apr 26, 2022
2 parents cd70d9c + f87c126 commit 11f0b19
Showing 2 changed files with 30 additions and 13 deletions.
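
The change gates each stage of BaseKFACPreconditioner.step() on the WorkAssignment: only the assigned inverse worker computes a factor inverse, inverses are broadcast only within a layer's gradient-worker group, and only gradient workers compute preconditioned gradients. A minimal sketch of that per-layer control flow follows; the method names mirror the diff below, but the standalone helper function and its signature are illustrative assumptions rather than repository code.

# Illustrative sketch of the gating pattern introduced by this change.
# `assignment` stands in for self._assignment and `layer` for a KFACBaseLayer;
# this helper function and its signature are assumptions, not repository code.
from kfac.distributed import get_rank


def precondition_layer(name, layer, assignment, damping):
    # Only the rank assigned as the inverse worker computes each inverse.
    if get_rank() == assignment.inv_worker(name, 'A'):
        layer.compute_a_inv(damping=damping)
    # Inverses are broadcast only when enabled and only among gradient workers.
    if assignment.broadcast_inverses() and assignment.is_grad_worker(name):
        layer.broadcast_a_inv(
            src=assignment.inv_worker(name, 'A'),
            group=assignment.grad_worker_group(name),
        )

    # The same pattern applies to the G factor.
    if get_rank() == assignment.inv_worker(name, 'G'):
        layer.compute_g_inv(damping=damping)
    if assignment.broadcast_inverses() and assignment.is_grad_worker(name):
        layer.broadcast_g_inv(
            src=assignment.inv_worker(name, 'G'),
            group=assignment.grad_worker_group(name),
        )

    # Only gradient workers compute the preconditioned gradient; the
    # subsequent gradient broadcast (unchanged by this commit) then shares it.
    if assignment.is_grad_worker(name):
        layer.preconditioned_grad(damping=damping)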
20 changes: 15 additions & 5 deletions kfac/base_preconditioner.py
@@ -11,6 +11,7 @@
 import torch
 
 from kfac.assignment import WorkAssignment
+from kfac.distributed import get_rank
 from kfac.distributed import TorchDistributedCommunicator
 from kfac.layers.base import KFACBaseLayer
 
@@ -336,14 +337,22 @@ def step(self) -> None:
         # Compute Inverses
         if self.steps % self.inv_update_steps == 0:
             for name, layer in reversed(self._layers.values()):
-                layer.compute_a_inv(damping=self.damping)
-                if self._assignment.broadcast_inverses():
+                if get_rank() == self._assignment.inv_worker(name, 'A'):
+                    layer.compute_a_inv(damping=self.damping)
+                if (
+                    self._assignment.broadcast_inverses()
+                    and self._assignment.is_grad_worker(name)
+                ):
                     layer.broadcast_a_inv(
                         src=self._assignment.inv_worker(name, 'A'),
                         group=self._assignment.grad_worker_group(name),
                     )
-                layer.compute_g_inv(damping=self.damping)
-                if self._assignment.broadcast_inverses():
+                if get_rank() == self._assignment.inv_worker(name, 'G'):
+                    layer.compute_g_inv(damping=self.damping)
+                if (
+                    self._assignment.broadcast_inverses()
+                    and self._assignment.is_grad_worker(name)
+                ):
                     layer.broadcast_g_inv(
                         src=self._assignment.inv_worker(name, 'G'),
                         group=self._assignment.grad_worker_group(name),
@@ -352,7 +361,8 @@ def step(self) -> None:
 
         # Compute Preconditioned Gradients
        for name, layer in reversed(self._layers.values()):
-            layer.preconditioned_grad(damping=self.damping)
+            if self._assignment.is_grad_worker(name):
+                layer.preconditioned_grad(damping=self.damping)
             if self._assignment.broadcast_gradients():
                 layer.broadcast_grad(
                     src=self._assignment.src_grad_worker(name),
23 changes: 15 additions & 8 deletions tests/training_test.py
@@ -11,12 +11,12 @@
 from testing.models import TinyModel
 
 
-def train() -> None:
+def train(grad_worker_frac: float) -> None:
     """Train TinyModel with KFAC on random data."""
     batch_size = 4
     in_features = 10
     out_features = 10
-    epochs = 20
+    steps = 20
 
     x = torch.rand(batch_size, in_features)
     y = torch.rand(batch_size, out_features)
@@ -25,18 +25,21 @@ def train() -> None:
     torch.distributed.all_reduce(y)
 
     model = TinyModel()
+    if torch.distributed.is_initialized():
+        model = torch.nn.parallel.DistributedDataParallel(model)
     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
     preconditioner = KFACPreconditioner(
         model,
         factor_update_steps=5,
         inv_update_steps=10,
+        grad_worker_fraction=grad_worker_frac,
         allreduce_bucket_cap_mb=0,
         update_factors_in_hook=False,
     )
     criterion = torch.nn.MSELoss(reduction='sum')
 
     losses = []
-    for _ in range(epochs):
+    for _ in range(steps):
         y_pred = model(x)
         loss = criterion(y_pred, y)
         losses.append(loss.item())
@@ -49,19 +52,23 @@
 
 
 @pytest.mark.parametrize(
-    'distributed,world_size',
-    ((False, 1), (True, 1), (True, 2), (True, 4)),
+    'distributed,grad_worker_frac,world_size',
+    ((False, 1, 1), (True, 0, 1), (True, 0.5, 2), (True, 0.5, 4)),
 )
-def test_training(distributed: bool, world_size: int) -> None:
+def test_training(
+    distributed: bool,
+    grad_worker_frac: float,
+    world_size: int,
+) -> None:
     """Test end-to-end training with KFACPreconditioner."""
     if not distributed:
         # Note: torch does not allow forking if autograd has been used
         # in the parent process. So we perform the training in a separate
         # process to keep this parent process "clean". See
         # https://github.com/pytorch/pytorch/issues/69839#issuecomment-993686048
-        p = Process(target=train)
+        p = Process(target=train, args=(grad_worker_frac,))
         p.start()
         p.join()
     else:
         _train = distributed_test(world_size=world_size)(train)
-        _train()
+        _train(grad_worker_frac)
