Skip to content

Commit

Permalink
Added global_step to check the steps we are at
Browse files Browse the repository at this point in the history
  • Loading branch information
jshuadvd committed Jul 9, 2024
1 parent d75f127 commit 64054e5
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,9 @@ def train(
total_loss = 0

for i, (inputs, targets) in enumerate(train_loader):
if max_steps and global_step >= max_steps:
break

# Move data to the appropriate device (CPU or GPU)
inputs, targets = (
inputs.to(accelerator.device),
Expand All @@ -207,9 +210,13 @@ def train(
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
global_step += 1

total_loss += loss.item()

if max_steps and global_step >= max_steps:
break

# Calculate average training loss and perplexity
avg_train_loss = total_loss / len(train_loader)
train_perplexity = compute_perplexity(avg_train_loss)
Expand Down

0 comments on commit 64054e5

Please sign in to comment.