Skip to content

Commit

Permalink
Merge pull request #79 from Genentech/multinomial_axis
Browse files Browse the repository at this point in the history
add multinomial axis option to Poisson multinomial loss
  • Loading branch information
avantikalal authored Nov 4, 2024
2 parents a1c3be4 + 8bbe77a commit c954dd0
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions src/grelu/lightning/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ class PoissonMultinomialLoss(nn.Module):
log_input: If True, the input is transformed with torch.exp to produce predicted
counts. Otherwise, the input is assumed to already represent predicted
counts.
multinomial_axis: Either "length" or "task", representing the axis along which the
multinomial distribution should be calculated.
reduction: "mean" or "none".
"""

Expand All @@ -26,12 +28,17 @@ def __init__(
eps: float = 1e-7,
log_input: bool = True,
reduction: str = "mean",
multinomial_axis: str = "length",
) -> None:
super().__init__()
self.eps = eps
self.total_weight = total_weight
self.log_input = log_input
self.reduction = reduction
if multinomial_axis == "length":
self.axis = 2
elif multinomial_axis == "task":
self.axis = 1

def forward(self, input: Tensor, target: Tensor) -> Tensor:
"""
Expand All @@ -54,8 +61,8 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor:
input += self.eps

# Assuming count predictions
total_target = target.sum(axis=-1, keepdim=True)
total_input = input.sum(axis=-1, keepdim=True)
total_target = target.sum(axis=self.axis, keepdim=True)
total_input = input.sum(axis=self.axis, keepdim=True)

# total count poisson loss, mean across targets
poisson_term = F.poisson_nll_loss(
Expand All @@ -68,11 +75,11 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor:
log_p_input = torch.log(p_input)

# multinomial loss
multinomial_dot = -torch.multiply(target, log_p_input) # B x T x L
multinomial_term = multinomial_dot.mean(axis=-1, keepdim=True) # B x T
multinomial_dot = -torch.multiply(target, log_p_input)
multinomial_term = multinomial_dot.mean(axis=self.axis, keepdim=True)

# Combine
loss = multinomial_term + self.total_weight * poisson_term
loss = multinomial_term + (self.total_weight * poisson_term)

if self.reduction == "mean":
return loss.mean()
Expand Down

0 comments on commit c954dd0

Please sign in to comment.