From a094e2fc9663661ba0d0aabb91b573fc0890544e Mon Sep 17 00:00:00 2001 From: gcroci2 Date: Thu, 19 Oct 2023 15:01:35 +0200 Subject: [PATCH] add tests for testing when no test is provided and when no mmodel is loaded/trained --- tests/test_trainer.py | 53 +++++++++++++++++++++++++++++++++++++++---- 1 file changed, 49 insertions(+), 4 deletions(-) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 94fa1afcc..033f8a934 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -9,7 +9,11 @@ import h5py import pytest import torch + from deeprank2.dataset import GraphDataset, GridDataset +from deeprank2.domain import edgestorage as Efeat +from deeprank2.domain import nodestorage as Nfeat +from deeprank2.domain import targetstorage as targets from deeprank2.neuralnets.cnn.model3d import CnnClassification, CnnRegression from deeprank2.neuralnets.gnn.foutnet import FoutNet from deeprank2.neuralnets.gnn.ginet import GINet @@ -19,10 +23,6 @@ from deeprank2.utils.exporters import (HDF5OutputExporter, ScatterPlotExporter, TensorboardBinaryClassificationExporter) -from deeprank2.domain import edgestorage as Efeat -from deeprank2.domain import nodestorage as Nfeat -from deeprank2.domain import targetstorage as targets - _log = logging.getLogger(__name__) default_features = [Nfeat.RESTYPE, Nfeat.POLARITY, Nfeat.BSA, Nfeat.RESDEPTH, Nfeat.HSE, Nfeat.INFOCONTENT, Nfeat.PSSM] @@ -383,6 +383,31 @@ def test_incompatible_pretrained_no_Net(self): pretrained_model = self.save_path ) + def test_no_training_no_pretrained(self): + dataset_train = GraphDataset( + hdf5_path = "tests/data/hdf5/test.hdf5", + clustering_method = "mcl", + target = targets.BINARY, + ) + dataset_val = GraphDataset( + hdf5_path = "tests/data/hdf5/test.hdf5", + train = False, + dataset_train = dataset_train + ) + dataset_test = GraphDataset( + hdf5_path = "tests/data/hdf5/test.hdf5", + train = False, + dataset_train = dataset_train + ) + trainer = Trainer( + neuralnet = GINet, + dataset_train = dataset_train, + dataset_val = dataset_val, + dataset_test = dataset_test + ) + with pytest.raises(ValueError): + trainer.test() + def test_no_valid_provided(self): dataset = GraphDataset( hdf5_path = "tests/data/hdf5/test.hdf5", @@ -397,6 +422,26 @@ def test_no_valid_provided(self): assert len(trainer.train_loader) == int(0.75 * len(dataset)) assert len(trainer.valid_loader) == int(0.25 * len(dataset)) + def test_no_test_provided(self): + dataset_train = GraphDataset( + hdf5_path = "tests/data/hdf5/test.hdf5", + clustering_method = "mcl", + target = targets.BINARY, + ) + dataset_val = GraphDataset( + hdf5_path = "tests/data/hdf5/test.hdf5", + train = False, + dataset_train = dataset_train + ) + trainer = Trainer( + neuralnet = GINet, + dataset_train = dataset_train, + dataset_val = dataset_val, + ) + trainer.train(batch_size = 1, best_model=False, filename=None) + with pytest.raises(ValueError): + trainer.test() + def test_no_valid_full_train(self): dataset = GraphDataset( hdf5_path = "tests/data/hdf5/test.hdf5",