Skip to content

Commit

Permalink
add trainer tests for testing without defining the dataset_train
Browse files Browse the repository at this point in the history
  • Loading branch information
gcroci2 committed Oct 23, 2023
1 parent 11f826e commit 0fcea3a
Showing 1 changed file with 84 additions and 0 deletions.
84 changes: 84 additions & 0 deletions tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import warnings

import h5py
import pandas as pd
import pytest
import torch

Expand Down Expand Up @@ -706,6 +707,89 @@ def test_invalid_no_cuda_available(self):
warnings.warn('CUDA is available; test_invalid_no_cuda_available was skipped')
_log.info('CUDA is available; test_invalid_no_cuda_available was skipped')

def test_train_method_no_train(self):

# Graphs data
test_data_graph = "tests/data/hdf5/test.hdf5"
pretrained_model_graph = "tests/data/pretrained/testing_graph_model.pth.tar"

dataset_test = GraphDataset(
hdf5_path = test_data_graph,
train = False,
train_data = pretrained_model_graph
)
trainer = Trainer(
neuralnet = NaiveNetwork,
dataset_test = dataset_test,
pretrained_model = pretrained_model_graph
)

with pytest.raises(ValueError):
trainer.train()

# Grids data
test_data_grid = "tests/data/hdf5/1ATN_ppi.hdf5"
pretrained_model_grid = "tests/data/pretrained/testing_grid_model.pth.tar"

dataset_test = GridDataset(
hdf5_path = test_data_grid,
train = False,
train_data = pretrained_model_grid
)
trainer = Trainer(
neuralnet = CnnClassification,
dataset_test = dataset_test,
pretrained_model = pretrained_model_grid
)

with pytest.raises(ValueError):
trainer.train()

def test_test_method_pretrained_model_on_dataset_with_target(self):

# Graphs data
test_data_graph = "tests/data/hdf5/test.hdf5"
pretrained_model_graph = "tests/data/pretrained/testing_graph_model.pth.tar"

dataset_test = GraphDataset(
hdf5_path = test_data_graph,
train = False,
train_data = pretrained_model_graph
)

trainer = Trainer(
neuralnet = NaiveNetwork,
dataset_test = dataset_test,
pretrained_model = pretrained_model_graph,
output_exporters = [HDF5OutputExporter("./")]
)

trainer.test()

output = pd.read_hdf("output_exporter.hdf5", key="testing")
assert len(output) == len(dataset_test)

# Grids data
test_data_grid = "tests/data/hdf5/1ATN_ppi.hdf5"
pretrained_model_grid = "tests/data/pretrained/testing_grid_model.pth.tar"

dataset_test = GridDataset(
hdf5_path = test_data_grid,
train = False,
train_data = pretrained_model_grid
)

trainer = Trainer(
neuralnet = CnnClassification,
dataset_test = dataset_test,
pretrained_model = pretrained_model_grid,
output_exporters = [HDF5OutputExporter("./")]
)

trainer.test()

output = pd.read_hdf("output_exporter.hdf5", key="testing")
assert len(output) == len(dataset_test)


if __name__ == "__main__":
Expand Down

0 comments on commit 0fcea3a

Please sign in to comment.