Skip to content

Commit

Permalink
Add gradient clipping
Browse files Browse the repository at this point in the history
  • Loading branch information
jshuadvd committed Jul 13, 2024
1 parent f07081c commit 440f184
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@


# %%
# Streaming dataset for long sequences
class StreamingDataset(IterableDataset):
def __init__(self, dataset, tokenizer, max_length, overlap):
self.dataset = dataset
Expand Down Expand Up @@ -244,6 +245,10 @@ def train(
scaler.scale(loss).backward()

if (i + 1) % gradient_accumulation_steps == 0:
# Gradient clipping
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# Update weights and reset gradients
scaler.step(optimizer)
scaler.update()
Expand Down

0 comments on commit 440f184

Please sign in to comment.