Skip to content

Commit

Permalink
Begin implementing a training edge case to pick up where the training …
Browse files Browse the repository at this point in the history
…left off if interrupted.
  • Loading branch information
jshuadvd committed Jul 7, 2024
1 parent 3631ec6 commit 66a364e
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,15 @@ def train(
max_patience = 3
start_epoch = 0

# Check if resuming from a checkpoint
if resume_from_checkpoint and os.path.exists(resume_from_checkpoint):
checkpoint = accelerator.load_state(resume_from_checkpoint)
start_epoch = checkpoint.get("epoch", 0) + 1
best_val_loss = checkpoint.get("best_val_loss", float("inf"))
logger.info(
f"Resumed training from {resume_from_checkpoint} at epoch {start_epoch}"
)

for epoch in range(epochs):
model.train()
total_loss = 0
Expand Down

0 comments on commit 66a364e

Please sign in to comment.