Commit

improve logic for warning about nan losses
gcroci2 committed Dec 19, 2023
1 parent bbdf64c · commit 992bc48
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions deeprank2/trainer.py
@@ -577,7 +577,7 @@ def train( # pylint: disable=too-many-arguments, too-many-branches, too-many-loc
 
         train_losses = []
         valid_losses = []
-        checkpoint_model = None
+        saved_model = False
 
         if earlystop_patience or earlystop_maxgap:
             early_stopping = EarlyStopping(patience=earlystop_patience, maxgap=earlystop_maxgap, min_epoch=min_epoch, trace_func=_log.info)
@@ -610,6 +610,7 @@ def train( # pylint: disable=too-many-arguments, too-many-branches, too-many-loc
                 if best_model:
                     if min(valid_losses) == loss_:
                         checkpoint_model = self._save_model()
+                        saved_model = True
                         self.epoch_saved_model = epoch
                         _log.info(f'Best model saved at epoch # {self.epoch_saved_model}.')
                 # check early stopping criteria (in validation case only)
@@ -627,17 +628,19 @@ def train( # pylint: disable=too-many-arguments, too-many-branches, too-many-loc
                             "Training data is used both for learning and model selection, which will lead to overfitting." +
                             "\n\tIt is preferable to use independent training and validation data sets.")
                         checkpoint_model = self._save_model()
+                        saved_model = True
                         self.epoch_saved_model = epoch
                         _log.info(f'Best model saved at epoch # {self.epoch_saved_model}.')
 
         # Save the last model
-        if best_model is False or checkpoint_model is None:
+        if best_model is False or not saved_model:
             checkpoint_model = self._save_model()
             self.epoch_saved_model = epoch
             _log.info(f'Last model saved at epoch # {self.epoch_saved_model}.')
-        if checkpoint_model is None:
-            _log.warning("A model has been saved but the training and validation losses were NaN;" +
-                "make sure that you are using enough data points during the trainig.")
+        if not saved_model:
+            _log.warning("A model has been saved but the validation and/or the training losses were NaN; " +
+                "try to increase the cutoff distance during the data processing or the number of data points " +
+                "during the training.")
 
         # Now that the training loop is over, save the model
         if filename:
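
Why a separate boolean rather than the old `checkpoint_model is None` test: once the fallback "Last model" branch runs, `checkpoint_model` is populated (assuming `self._save_model()` returns a non-None checkpoint), so the old warning could never fire; and when the losses are NaN, `min(valid_losses) == loss_` is always False, so no best model is ever selected. A minimal standalone sketch of that NaN comparison behaviour follows; the values and prints are illustrative and not part of the commit:

import math

# Illustrative values, not taken from the commit: suppose every epoch
# produced a NaN loss (e.g. too few data points).
valid_losses = [float("nan"), float("nan")]
loss_ = valid_losses[-1]

# Any comparison involving NaN evaluates to False, so the
# "min(valid_losses) == loss_" best-model condition is never met and no
# best checkpoint is ever written -- exactly the situation the new
# saved_model flag makes detectable.
print(min(valid_losses) == loss_)  # False
print(math.isnan(loss_))           # True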
