diff --git a/src/grelu/lightning/losses.py b/src/grelu/lightning/losses.py index f53fd1a..0df819c 100644 --- a/src/grelu/lightning/losses.py +++ b/src/grelu/lightning/losses.py @@ -17,6 +17,8 @@ class PoissonMultinomialLoss(nn.Module): log_input: If True, the input is transformed with torch.exp to produce predicted counts. Otherwise, the input is assumed to already represent predicted counts. + multinomial_axis: Either "length" or "task", representing the axis along which the + multinomial distribution should be calculated. reduction: "mean" or "none". """ @@ -26,12 +28,17 @@ def __init__( eps: float = 1e-7, log_input: bool = True, reduction: str = "mean", + multinomial_axis: str = "length", ) -> None: super().__init__() self.eps = eps self.total_weight = total_weight self.log_input = log_input self.reduction = reduction + if multinomial_axis == "length": + self.axis = 2 + elif multinomial_axis == "task": + self.axis = 1 def forward(self, input: Tensor, target: Tensor) -> Tensor: """ @@ -54,8 +61,8 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor: input += self.eps # Assuming count predictions - total_target = target.sum(axis=-1, keepdim=True) - total_input = input.sum(axis=-1, keepdim=True) + total_target = target.sum(axis=self.axis, keepdim=True) + total_input = input.sum(axis=self.axis, keepdim=True) # total count poisson loss, mean across targets poisson_term = F.poisson_nll_loss( @@ -68,11 +75,11 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor: log_p_input = torch.log(p_input) # multinomial loss - multinomial_dot = -torch.multiply(target, log_p_input) # B x T x L - multinomial_term = multinomial_dot.mean(axis=-1, keepdim=True) # B x T + multinomial_dot = -torch.multiply(target, log_p_input) + multinomial_term = multinomial_dot.mean(axis=self.axis, keepdim=True) # Combine - loss = multinomial_term + self.total_weight * poisson_term + loss = multinomial_term + (self.total_weight * poisson_term) if self.reduction == "mean": return loss.mean()