diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 7b269f3f5..25dabe2b1 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -1019,6 +1019,11 @@ def test_inherit_info_pretrained_model_graphdataset(self): # in the test should be inherited from the pre-trained model inherited_params = ["node_features", "edge_features", "features_transform", "target", "target_transform", "task", "classes", "classes_to_index"] data = torch.load(pretrained_model, pickle_module = dill, map_location=torch.device('cpu')) + if data["features_transform"]: + for _, key in data["features_transform"].items(): + if key['transform'] is None: + continue + key['transform'] = eval(key['transform']) dataset_test_vars = vars(dataset_test) for param in inherited_params: