Commit
Update the main() function to make the training process more robust and easier to monitor.

Key differences and improvements:

- Weights & Biases integration: the updated function initializes wandb for logging and visualization, allowing better tracking of training progress and results.
- Batch size: reduced from 32 to 8, possibly to accommodate larger models or to work better with gradient accumulation.
- Optimizer: changed from Adam to AdamW, which decouples weight decay from the gradient update and can help with regularization.
- Learning rate scheduler: added a CosineAnnealingLR scheduler, which can improve convergence by adjusting the learning rate over the course of training.
- Accelerator preparation: the scheduler is now also prepared with the accelerator, so it behaves correctly in distributed training setups.
- Model extension: recover_short_context is now called on extended_model instead of the original model, ensuring short-context recovery is performed on the extended model.
- Final training: the train function now takes the scheduler as an argument, so the learning rate can be adjusted during training.
- Wandb finish: the wandb.finish() call at the end ensures all logs are synced and the run is closed.

A sketch of the updated function follows the list above.
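The following is a minimal, hypothetical sketch of what the updated main() could look like, not the repository's actual code: build_model(), build_dataset(), extend_model(), recover_short_context(), train(), and all hyperparameter values are assumptions standing in for definitions this commit message does not show.

```python
# Hypothetical sketch of the updated main(); helpers and hyperparameters
# are placeholders, not the repository's real definitions.
import wandb
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from accelerate import Accelerator


def main():
    # Weights & Biases integration: log config and training metrics.
    wandb.init(project="context-extension",      # assumed project name
               config={"batch_size": 8, "lr": 2e-5})

    accelerator = Accelerator()
    model = build_model()        # assumed helper
    dataset = build_dataset()    # assumed helper

    # Batch size reduced from 32 to 8 (e.g. for larger models or
    # gradient accumulation).
    dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

    # AdamW decouples weight decay from the gradient update.
    optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)

    # Cosine annealing lowers the learning rate smoothly over training.
    scheduler = CosineAnnealingLR(optimizer, T_max=len(dataloader))

    # The scheduler is prepared alongside model, optimizer, and data so it
    # works correctly under distributed training.
    model, optimizer, dataloader, scheduler = accelerator.prepare(
        model, optimizer, dataloader, scheduler
    )

    # Short-context recovery now runs on the extended model
    # (both calls below are assumed signatures).
    extended_model = extend_model(model)
    recover_short_context(extended_model)

    # train() now receives the scheduler so the LR is stepped during training.
    train(extended_model, dataloader, optimizer, scheduler, accelerator)

    # Sync remaining logs and close the run.
    wandb.finish()


if __name__ == "__main__":
    main()
```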