Skip to content

Commit

Permalink
add tests for testing when no test is provided and when no mmodel is …
Browse files Browse the repository at this point in the history
…loaded/trained
  • Loading branch information
gcroci2 committed Oct 19, 2023
1 parent a5ad04a commit a094e2f
Showing 1 changed file with 49 additions and 4 deletions.
53 changes: 49 additions & 4 deletions tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@
import h5py
import pytest
import torch

from deeprank2.dataset import GraphDataset, GridDataset
from deeprank2.domain import edgestorage as Efeat
from deeprank2.domain import nodestorage as Nfeat
from deeprank2.domain import targetstorage as targets
from deeprank2.neuralnets.cnn.model3d import CnnClassification, CnnRegression
from deeprank2.neuralnets.gnn.foutnet import FoutNet
from deeprank2.neuralnets.gnn.ginet import GINet
Expand All @@ -19,10 +23,6 @@
from deeprank2.utils.exporters import (HDF5OutputExporter, ScatterPlotExporter,
TensorboardBinaryClassificationExporter)

from deeprank2.domain import edgestorage as Efeat
from deeprank2.domain import nodestorage as Nfeat
from deeprank2.domain import targetstorage as targets

_log = logging.getLogger(__name__)

default_features = [Nfeat.RESTYPE, Nfeat.POLARITY, Nfeat.BSA, Nfeat.RESDEPTH, Nfeat.HSE, Nfeat.INFOCONTENT, Nfeat.PSSM]
Expand Down Expand Up @@ -383,6 +383,31 @@ def test_incompatible_pretrained_no_Net(self):
pretrained_model = self.save_path
)

def test_no_training_no_pretrained(self):
dataset_train = GraphDataset(
hdf5_path = "tests/data/hdf5/test.hdf5",
clustering_method = "mcl",
target = targets.BINARY,
)
dataset_val = GraphDataset(
hdf5_path = "tests/data/hdf5/test.hdf5",
train = False,
dataset_train = dataset_train
)
dataset_test = GraphDataset(
hdf5_path = "tests/data/hdf5/test.hdf5",
train = False,
dataset_train = dataset_train
)
trainer = Trainer(
neuralnet = GINet,
dataset_train = dataset_train,
dataset_val = dataset_val,
dataset_test = dataset_test
)
with pytest.raises(ValueError):
trainer.test()

def test_no_valid_provided(self):
dataset = GraphDataset(
hdf5_path = "tests/data/hdf5/test.hdf5",
Expand All @@ -397,6 +422,26 @@ def test_no_valid_provided(self):
assert len(trainer.train_loader) == int(0.75 * len(dataset))
assert len(trainer.valid_loader) == int(0.25 * len(dataset))

def test_no_test_provided(self):
dataset_train = GraphDataset(
hdf5_path = "tests/data/hdf5/test.hdf5",
clustering_method = "mcl",
target = targets.BINARY,
)
dataset_val = GraphDataset(
hdf5_path = "tests/data/hdf5/test.hdf5",
train = False,
dataset_train = dataset_train
)
trainer = Trainer(
neuralnet = GINet,
dataset_train = dataset_train,
dataset_val = dataset_val,
)
trainer.train(batch_size = 1, best_model=False, filename=None)
with pytest.raises(ValueError):
trainer.test()

def test_no_valid_full_train(self):
dataset = GraphDataset(
hdf5_path = "tests/data/hdf5/test.hdf5",
Expand Down

0 comments on commit a094e2f

Please sign in to comment.