Skip to content

Commit

Permalink
Merge pull request #8 from Genentech/suragnair-patch-1
Browse files Browse the repository at this point in the history
Reset val/test metrics and losses after every epoch
  • Loading branch information
avantikalal authored Jul 10, 2024
2 parents 174cb10 + a2cebed commit 1477f4c
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions src/grelu/lightning/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,9 @@ def on_validation_epoch_end(self):
self.log_dict(mean_val_metrics)
self.log("val_loss", mean_losses)

self.val_metrics.reset()
self.val_losses = []

def test_step(self, batch: Tensor, batch_idx: int) -> Tensor:
"""
Calculate metrics after a single test step
Expand All @@ -368,6 +371,9 @@ def on_test_epoch_end(self) -> None:
losses = torch.stack(self.test_losses)
self.log("test_loss", torch.mean(losses))

self.test_metrics.reset()
self.test_losses = []

def configure_optimizers(self) -> None:
"""
Configure oprimizer for training
Expand Down Expand Up @@ -549,6 +555,7 @@ def train_on_dataset(
if checkpoint_path is None:
# First validation pass
trainer.validate(model=self, dataloaders=val_dataloader)
self.val_metrics.reset()

# Add data parameters
self.data_params["tasks"] = train_dataset.tasks.reset_index(
Expand Down

0 comments on commit 1477f4c

Please sign in to comment.