diff --git a/train.py b/train.py index bda7d72..e6c1f88 100644 --- a/train.py +++ b/train.py @@ -263,6 +263,7 @@ def train( wandb.log( { "epoch": epoch, + "global_step": global_step, "train_loss": avg_train_loss, "train_perplexity": train_perplexity, "val_loss": avg_val_loss,