Skip to content

Commit

Permalink
Log the gradient norm
Browse files Browse the repository at this point in the history
  • Loading branch information
jshuadvd committed Jul 6, 2024
1 parent 6f20445 commit ee2e5f3
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,16 @@ def train(
f"Passkey retrieval accuracy at {length} tokens: {accuracy:.2f}"
)

# Log gradient norm
# Compute the global L2 norm over all parameter gradients:
# sqrt(sum over parameters of ||grad_p||_2^2). Parameters with no
# gradient (e.g. frozen layers, or params unused this step) are skipped.
total_norm = 0.0
for p in model.parameters():
    if p.grad is not None:
        # detach() instead of the deprecated .data attribute — same
        # values, but does not bypass autograd's version tracking.
        param_norm = p.grad.detach().norm(2)
        total_norm += param_norm.item() ** 2
total_norm = total_norm**0.5
wandb.log({"gradient_norm": total_norm})
logger.info(f"Gradient norm: {total_norm:.4f}")

# Log metrics
wandb.log(
{
Expand Down

0 comments on commit ee2e5f3

Please sign in to comment.