From 64054e56d3f4970f48ea3218392fafbf15080653 Mon Sep 17 00:00:00 2001 From: Joshua David Date: Mon, 8 Jul 2024 22:58:10 -0700 Subject: [PATCH] Added global_step to check the steps we are at --- train.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/train.py b/train.py index 90539f5..bda7d72 100644 --- a/train.py +++ b/train.py @@ -186,6 +186,9 @@ def train( total_loss = 0 for i, (inputs, targets) in enumerate(train_loader): + if max_steps and global_step >= max_steps: + break + # Move data to the appropriate device (CPU or GPU) inputs, targets = ( inputs.to(accelerator.device), @@ -207,9 +210,13 @@ def train( scaler.step(optimizer) scaler.update() optimizer.zero_grad() + global_step += 1 total_loss += loss.item() + if max_steps and global_step >= max_steps: + break + # Calculate average training loss and perplexity avg_train_loss = total_loss / len(train_loader) train_perplexity = compute_perplexity(avg_train_loss)