From 66a364e25b58226dd5b6c1f5a2750863247d7910 Mon Sep 17 00:00:00 2001 From: Joshua David Date: Sat, 6 Jul 2024 22:23:24 -0700 Subject: [PATCH] Begin implementing a training edge case to pickup where the training left off if interrupeted. --- train.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/train.py b/train.py index c799ea6..fdbfba1 100644 --- a/train.py +++ b/train.py @@ -168,6 +168,15 @@ def train( max_patience = 3 start_epoch = 0 + # Check if resuming from a checkpoint + if resume_from_checkpoint and os.path.exists(resume_from_checkpoint): + checkpoint = accelerator.load_state(resume_from_checkpoint) + start_epoch = checkpoint.get("epoch", 0) + 1 + best_val_loss = checkpoint.get("best_val_loss", float("inf")) + logger.info( + f"Resumed training from {resume_from_checkpoint} at epoch {start_epoch}" + ) + for epoch in range(epochs): model.train() total_loss = 0