diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 5a3e7f7ca..629d409cc 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -7,6 +7,7 @@ import warnings import h5py +import pandas as pd import pytest import torch @@ -706,6 +707,89 @@ def test_invalid_no_cuda_available(self): warnings.warn('CUDA is available; test_invalid_no_cuda_available was skipped') _log.info('CUDA is available; test_invalid_no_cuda_available was skipped') + def test_train_method_no_train(self): + + # Graphs data + test_data_graph = "tests/data/hdf5/test.hdf5" + pretrained_model_graph = "tests/data/pretrained/testing_graph_model.pth.tar" + + dataset_test = GraphDataset( + hdf5_path = test_data_graph, + train = False, + train_data = pretrained_model_graph + ) + trainer = Trainer( + neuralnet = NaiveNetwork, + dataset_test = dataset_test, + pretrained_model = pretrained_model_graph + ) + + with pytest.raises(ValueError): + trainer.train() + + # Grids data + test_data_grid = "tests/data/hdf5/1ATN_ppi.hdf5" + pretrained_model_grid = "tests/data/pretrained/testing_grid_model.pth.tar" + + dataset_test = GridDataset( + hdf5_path = test_data_grid, + train = False, + train_data = pretrained_model_grid + ) + trainer = Trainer( + neuralnet = CnnClassification, + dataset_test = dataset_test, + pretrained_model = pretrained_model_grid + ) + + with pytest.raises(ValueError): + trainer.train() + + def test_test_method_pretrained_model_on_dataset_with_target(self): + + # Graphs data + test_data_graph = "tests/data/hdf5/test.hdf5" + pretrained_model_graph = "tests/data/pretrained/testing_graph_model.pth.tar" + + dataset_test = GraphDataset( + hdf5_path = test_data_graph, + train = False, + train_data = pretrained_model_graph + ) + + trainer = Trainer( + neuralnet = NaiveNetwork, + dataset_test = dataset_test, + pretrained_model = pretrained_model_graph, + output_exporters = [HDF5OutputExporter("./")] + ) + + trainer.test() + + output = pd.read_hdf("output_exporter.hdf5", key="testing") + assert len(output) == len(dataset_test) + + # Grids data + test_data_grid = "tests/data/hdf5/1ATN_ppi.hdf5" + pretrained_model_grid = "tests/data/pretrained/testing_grid_model.pth.tar" + + dataset_test = GridDataset( + hdf5_path = test_data_grid, + train = False, + train_data = pretrained_model_grid + ) + + trainer = Trainer( + neuralnet = CnnClassification, + dataset_test = dataset_test, + pretrained_model = pretrained_model_grid, + output_exporters = [HDF5OutputExporter("./")] + ) + + trainer.test() + + output = pd.read_hdf("output_exporter.hdf5", key="testing") + assert len(output) == len(dataset_test) if __name__ == "__main__":