
Commit

Merge pull request #535 from DeepRank/hotfix_533_unboundlocalerror_gcroci2

fix: handle cases in which losses are nan during the training
gcroci2 authored Dec 20, 2023
2 parents 599ba61 + 1b3b6b6 commit f9681aa
Showing 2 changed files with 87 additions and 1 deletion.
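
Why the fix is needed: comparisons involving NaN are always False in Python, so once a validation loss is NaN the min(valid_losses) == loss_ check in the training loop can never select a best model, checkpoint_model is never assigned, and later code that uses it fails (the UnboundLocalError referenced in the branch name for issue #533). A minimal sketch of that comparison behavior, independent of deeprank2 and purely illustrative:

    import math

    valid_losses = [float("nan"), float("nan")]
    loss_ = valid_losses[-1]

    # Any comparison with NaN is False, so the "best model" branch never runs
    # and, before this fix, no checkpoint was ever saved.
    print(min(valid_losses) == loss_)  # False
    print(math.isnan(loss_))           # True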
10 changes: 9 additions & 1 deletion deeprank2/trainer.py
@@ -1,5 +1,6 @@
import copy
import logging
import warnings
from time import time
from typing import List, Optional, Tuple, Union

@@ -577,6 +578,7 @@ def train( # pylint: disable=too-many-arguments, too-many-branches, too-many-loc

train_losses = []
valid_losses = []
saved_model = False

if earlystop_patience or earlystop_maxgap:
early_stopping = EarlyStopping(patience=earlystop_patience, maxgap=earlystop_maxgap, min_epoch=min_epoch, trace_func=_log.info)
@@ -609,6 +611,7 @@ def train( # pylint: disable=too-many-arguments, too-many-branches, too-many-loc
if best_model:
if min(valid_losses) == loss_:
checkpoint_model = self._save_model()
saved_model = True
self.epoch_saved_model = epoch
_log.info(f'Best model saved at epoch # {self.epoch_saved_model}.')
# check early stopping criteria (in validation case only)
@@ -626,14 +629,19 @@ def train( # pylint: disable=too-many-arguments, too-many-branches, too-many-loc
"Training data is used both for learning and model selection, which will to overfitting." +
"\n\tIt is preferable to use an independent training and validation data sets.")
checkpoint_model = self._save_model()
saved_model = True
self.epoch_saved_model = epoch
_log.info(f'Best model saved at epoch # {self.epoch_saved_model}.')

# Save the last model
if best_model is False:
if best_model is False or not saved_model:
checkpoint_model = self._save_model()
self.epoch_saved_model = epoch
_log.info(f'Last model saved at epoch # {self.epoch_saved_model}.')
if not saved_model:
warnings.warn("A model has been saved but the validation and/or the training losses were NaN;" +
"\n\ttry to increase the cutoff distance during the data processing or the number of data points " +
"during the training.")

# Now that the training loop is over, save the model
if filename:
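
The remediation that the new warning suggests, increasing the cutoff distance during data processing, is applied when the queries are built. A hedged sketch based on the query API used in the test below; the cutoff value of 10 and the file path are illustrative, not part of this PR:

    from deeprank2.domain import targetstorage as targets
    from deeprank2.query import ProteinProteinInterfaceResidueQuery, QueryCollection

    queries = QueryCollection()
    queries.add(
        ProteinProteinInterfaceResidueQuery(
            "tests/data/pdb/3C8P/3C8P.pdb",
            "A",
            "B",
            distance_cutoff=10,  # illustrative: larger than the cutoff of 3 used below to provoke NaN losses
            targets={targets.BINARY: 0},
        )
    )
    hdf5_paths = queries.process(prefix="train-data")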
78 changes: 78 additions & 0 deletions tests/test_integration.py
@@ -4,13 +4,17 @@
from tempfile import mkdtemp

import h5py
import pandas as pd
import pytest
import torch

from deeprank2.dataset import GraphDataset, GridDataset
from deeprank2.domain import edgestorage as Efeat
from deeprank2.domain import nodestorage as Nfeat
from deeprank2.domain import targetstorage as targets
from deeprank2.neuralnets.cnn.model3d import CnnClassification
from deeprank2.neuralnets.gnn.ginet import GINet
from deeprank2.neuralnets.gnn.naive_gnn import NaiveNetwork
from deeprank2.query import (ProteinProteinInterfaceResidueQuery,
QueryCollection)
from deeprank2.tools.target import compute_ppi_scores
@@ -196,3 +200,77 @@ def test_gnn(): # pylint: disable=too-many-locals
finally:
rmtree(hdf5_directory)
rmtree(output_directory)

@pytest.fixture(scope='session')
def hdf5_files_for_nan(tmpdir_factory):
# For testing cases in which the loss function is NaN for the validation and/or the
# training set. Whether the dataset is a GraphDataset or a GridDataset does not matter,
# since this is functionality of the trainer module, which does not depend on the dataset type.
# The settings and the parameters have been carefully chosen to result in NaN losses.
pdb_paths = [
"tests/data/pdb/3C8P/3C8P.pdb",
"tests/data/pdb/1A0Z/1A0Z.pdb",
"tests/data/pdb/1ATN/1ATN_1w.pdb"
]
chain_id1 = "A"
chain_id2 = "B"
targets_values = [0, 1, 1]
prefix = os.path.join(tmpdir_factory.mktemp("data"), "test-queries-process")

queries = QueryCollection()
for idx, pdb_path in enumerate(pdb_paths):
query = ProteinProteinInterfaceResidueQuery(
pdb_path,
chain_id1,
chain_id2,
# A very low cutoff distance helps prevent the network from learning
distance_cutoff=3,
targets = {targets.BINARY: targets_values[idx]}
)
queries.add(query)

hdf5_paths = queries.process(prefix = prefix)
return hdf5_paths

@pytest.mark.parametrize("validate, best_model", [(True, True), (False, True), (False, False), (True, False)])
def test_nan_loss_cases(validate, best_model, hdf5_files_for_nan):
mols = []
for fname in hdf5_files_for_nan:
with h5py.File(fname, 'r') as hdf5:
for mol in hdf5.keys():
mols.append(mol)

dataset_train = GraphDataset(
hdf5_path = hdf5_files_for_nan,
subset = mols[1:],
target = targets.BINARY,
task = targets.CLASSIF
)
dataset_valid = GraphDataset(
hdf5_path = hdf5_files_for_nan,
subset = [mols[0]],
dataset_train=dataset_train,
train=False
)

trainer = Trainer(
NaiveNetwork,
dataset_train,
dataset_valid)

optimizer = torch.optim.SGD
lr = 10000
weight_decay = 10000

trainer.configure_optimizers(optimizer, lr, weight_decay=weight_decay)
w_msg = "A model has been saved but the validation and/or the training losses were NaN;" + \
"\n\ttry to increase the cutoff distance during the data processing or the number of data points " + \
"during the training."
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("ignore", category=pd.errors.PerformanceWarning)
trainer.train(
nepoch=5, batch_size=1, validate=validate,
best_model=best_model, filename='test.pth.tar')
assert len(w) == 1
assert issubclass(w[-1].category, UserWarning)
assert w_msg in str(w[-1].message)
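
The NaN losses in this test come from deliberately extreme optimizer settings: with lr and weight_decay both at 10000, the SGD updates explode within a few steps, the weights overflow, and the cross-entropy loss turns into NaN. A toy sketch of the same effect, independent of deeprank2, with all names and values purely illustrative:

    import torch

    torch.manual_seed(0)
    model = torch.nn.Linear(4, 2)
    # Deliberately extreme settings, mirroring the test above
    optimizer = torch.optim.SGD(model.parameters(), lr=10000, weight_decay=10000)
    loss_fn = torch.nn.CrossEntropyLoss()

    x = torch.randn(8, 4)
    y = torch.randint(0, 2, (8,))

    for step in range(20):
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        if torch.isnan(loss):
            print(f"loss became NaN at step {step}")
            break
        loss.backward()
        optimizer.step()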
