diff --git a/deeprank2/trainer.py b/deeprank2/trainer.py index 3e9f14933..15ee8cc2c 100644 --- a/deeprank2/trainer.py +++ b/deeprank2/trainer.py @@ -1,5 +1,6 @@ import copy import logging +import warnings from time import time from typing import List, Optional, Tuple, Union @@ -577,6 +578,7 @@ def train( # pylint: disable=too-many-arguments, too-many-branches, too-many-loc train_losses = [] valid_losses = [] + saved_model = False if earlystop_patience or earlystop_maxgap: early_stopping = EarlyStopping(patience=earlystop_patience, maxgap=earlystop_maxgap, min_epoch=min_epoch, trace_func=_log.info) @@ -609,6 +611,7 @@ def train( # pylint: disable=too-many-arguments, too-many-branches, too-many-loc if best_model: if min(valid_losses) == loss_: checkpoint_model = self._save_model() + saved_model = True self.epoch_saved_model = epoch _log.info(f'Best model saved at epoch # {self.epoch_saved_model}.') # check early stopping criteria (in validation case only) @@ -626,14 +629,19 @@ def train( # pylint: disable=too-many-arguments, too-many-branches, too-many-loc "Training data is used both for learning and model selection, which will to overfitting." + "\n\tIt is preferable to use an independent training and validation data sets.") checkpoint_model = self._save_model() + saved_model = True self.epoch_saved_model = epoch _log.info(f'Best model saved at epoch # {self.epoch_saved_model}.') # Save the last model - if best_model is False: + if best_model is False or not saved_model: checkpoint_model = self._save_model() self.epoch_saved_model = epoch _log.info(f'Last model saved at epoch # {self.epoch_saved_model}.') + if not saved_model: + warnings.warn("A model has been saved but the validation and/or the training losses were NaN;" + + "\n\ttry to increase the cutoff distance during the data processing or the number of data points " + + "during the training.") # Now that the training loop is over, save the model if filename: diff --git a/tests/test_integration.py b/tests/test_integration.py index d8a37a58c..fd21413c3 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -4,6 +4,9 @@ from tempfile import mkdtemp import h5py +import pandas as pd +import pytest +import torch from deeprank2.dataset import GraphDataset, GridDataset from deeprank2.domain import edgestorage as Efeat @@ -11,6 +14,7 @@ from deeprank2.domain import targetstorage as targets from deeprank2.neuralnets.cnn.model3d import CnnClassification from deeprank2.neuralnets.gnn.ginet import GINet +from deeprank2.neuralnets.gnn.naive_gnn import NaiveNetwork from deeprank2.query import (ProteinProteinInterfaceResidueQuery, QueryCollection) from deeprank2.tools.target import compute_ppi_scores @@ -196,3 +200,77 @@ def test_gnn(): # pylint: disable=too-many-locals finally: rmtree(hdf5_directory) rmtree(output_directory) + +@pytest.fixture(scope='session') +def hdf5_files_for_nan(tmpdir_factory): + # For testing cases in which the loss function is nan for the validation and/or for + # the training sets. It doesn't matter if the dataset is a GraphDataset or a GridDataset, + # since it is a functionality of the trainer module, which does not depend on the dataset type. + # The settings and the parameters have been carefully chosen to result in nan losses. + pdb_paths = [ + "tests/data/pdb/3C8P/3C8P.pdb", + "tests/data/pdb/1A0Z/1A0Z.pdb", + "tests/data/pdb/1ATN/1ATN_1w.pdb" + ] + chain_id1 = "A" + chain_id2 = "B" + targets_values = [0, 1, 1] + prefix = os.path.join(tmpdir_factory.mktemp("data"), "test-queries-process") + + queries = QueryCollection() + for idx, pdb_path in enumerate(pdb_paths): + query = ProteinProteinInterfaceResidueQuery( + pdb_path, + chain_id1, + chain_id2, + # A very low cutoff distance helps for not making the network to learn + distance_cutoff=3, + targets = {targets.BINARY: targets_values[idx]} + ) + queries.add(query) + + hdf5_paths = queries.process(prefix = prefix) + return hdf5_paths + +@pytest.mark.parametrize("validate, best_model", [(True, True), (False, True), (False, False), (True, False)]) +def test_nan_loss_cases(validate, best_model, hdf5_files_for_nan): + mols = [] + for fname in hdf5_files_for_nan: + with h5py.File(fname, 'r') as hdf5: + for mol in hdf5.keys(): + mols.append(mol) + + dataset_train = GraphDataset( + hdf5_path = hdf5_files_for_nan, + subset = mols[1:], + target = targets.BINARY, + task = targets.CLASSIF + ) + dataset_valid = GraphDataset( + hdf5_path = hdf5_files_for_nan, + subset = [mols[0]], + dataset_train=dataset_train, + train=False + ) + + trainer = Trainer( + NaiveNetwork, + dataset_train, + dataset_valid) + + optimizer = torch.optim.SGD + lr = 10000 + weight_decay = 10000 + + trainer.configure_optimizers(optimizer, lr, weight_decay=weight_decay) + w_msg = "A model has been saved but the validation and/or the training losses were NaN;" + \ + "\n\ttry to increase the cutoff distance during the data processing or the number of data points " + \ + "during the training." + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("ignore", category=pd.errors.PerformanceWarning) + trainer.train( + nepoch=5, batch_size=1, validate=validate, + best_model=best_model, filename='test.pth.tar') + assert len(w) == 1 + assert issubclass(w[-1].category, UserWarning) + assert w_msg in str(w[-1].message)