diff --git a/train.py b/train.py index c799ea6..fdbfba1 100644 --- a/train.py +++ b/train.py @@ -168,6 +168,15 @@ def train( max_patience = 3 start_epoch = 0 + # Check if resuming from a checkpoint + if resume_from_checkpoint and os.path.exists(resume_from_checkpoint): + checkpoint = accelerator.load_state(resume_from_checkpoint) + start_epoch = checkpoint.get("epoch", 0) + 1 + best_val_loss = checkpoint.get("best_val_loss", float("inf")) + logger.info( + f"Resumed training from {resume_from_checkpoint} at epoch {start_epoch}" + ) + for epoch in range(epochs): model.train() total_loss = 0