From edfa4eb3046925a6f092c3e838a65aa20d6d6f8b Mon Sep 17 00:00:00 2001 From: avantikalal Date: Thu, 31 Oct 2024 20:58:17 +0000 Subject: [PATCH 1/2] add multinomial axis option --- src/grelu/lightning/losses.py | 17 ++++++++++++----- tests/test_models.py | 2 +- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/src/grelu/lightning/losses.py b/src/grelu/lightning/losses.py index f53fd1a..0df819c 100644 --- a/src/grelu/lightning/losses.py +++ b/src/grelu/lightning/losses.py @@ -17,6 +17,8 @@ class PoissonMultinomialLoss(nn.Module): log_input: If True, the input is transformed with torch.exp to produce predicted counts. Otherwise, the input is assumed to already represent predicted counts. + multinomial_axis: Either "length" or "task", representing the axis along which the + multinomial distribution should be calculated. reduction: "mean" or "none". """ @@ -26,12 +28,17 @@ def __init__( eps: float = 1e-7, log_input: bool = True, reduction: str = "mean", + multinomial_axis: str = "length", ) -> None: super().__init__() self.eps = eps self.total_weight = total_weight self.log_input = log_input self.reduction = reduction + if multinomial_axis == "length": + self.axis = 2 + elif multinomial_axis == "task": + self.axis = 1 def forward(self, input: Tensor, target: Tensor) -> Tensor: """ @@ -54,8 +61,8 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor: input += self.eps # Assuming count predictions - total_target = target.sum(axis=-1, keepdim=True) - total_input = input.sum(axis=-1, keepdim=True) + total_target = target.sum(axis=self.axis, keepdim=True) + total_input = input.sum(axis=self.axis, keepdim=True) # total count poisson loss, mean across targets poisson_term = F.poisson_nll_loss( @@ -68,11 +75,11 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor: log_p_input = torch.log(p_input) # multinomial loss - multinomial_dot = -torch.multiply(target, log_p_input) # B x T x L - multinomial_term = multinomial_dot.mean(axis=-1, keepdim=True) # 
B x T + multinomial_dot = -torch.multiply(target, log_p_input) + multinomial_term = multinomial_dot.mean(axis=self.axis, keepdim=True) # Combine - loss = multinomial_term + self.total_weight * poisson_term + loss = multinomial_term + (self.total_weight * poisson_term) if self.reduction == "mean": return loss.mean() diff --git a/tests/test_models.py b/tests/test_models.py index f66d32f..cade4af 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,6 +1,6 @@ import torch -import wandb +import wandb from grelu.model.models import ( BorzoiModel, BorzoiPretrainedModel, From 8bbe77abc9557354dfb518666d351d6ee133941a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 31 Oct 2024 21:00:30 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_models.py b/tests/test_models.py index cade4af..f66d32f 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,6 +1,6 @@ import torch - import wandb + from grelu.model.models import ( BorzoiModel, BorzoiPretrainedModel,