From 7fcc03382a344d7a645725114de6a5353710e5f7 Mon Sep 17 00:00:00 2001 From: gcroci2 Date: Mon, 23 Oct 2023 10:59:59 +0200 Subject: [PATCH] add classes_to_index as inherited param and to the pre-trained model --- deeprank2/dataset.py | 7 ++++--- deeprank2/trainer.py | 2 ++ 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/deeprank2/dataset.py b/deeprank2/dataset.py index 5c091e6ef..f076c87c1 100644 --- a/deeprank2/dataset.py +++ b/deeprank2/dataset.py @@ -484,7 +484,7 @@ def __init__( # pylint: disable=too-many-arguments for k, v in inspect.signature(self.__init__).parameters.items() if v.default is not inspect.Parameter.empty } - + self.default_vars["classes_to_index"] = None self.features = features self.target_transform = target_transform self._check_features() @@ -510,7 +510,7 @@ def __init__( # pylint: disable=too-many-arguments Please provide a valid training GridDataset or the path to a valid DeepRank2 pre-trained model.""") #check inherited parameter with the ones in the training set - inherited_params = ["features", "target", "target_transform", "task", "classes"] + inherited_params = ["features", "target", "target_transform", "task", "classes", "classes_to_index"] self._check_inherited_params(inherited_params, data) elif train and train_data: @@ -737,6 +737,7 @@ def __init__( # noqa: MC0001, pylint: disable=too-many-arguments, too-many-local for k, v in inspect.signature(self.__init__).parameters.items() if v.default is not inspect.Parameter.empty } + self.default_vars["classes_to_index"] = None self.node_features = node_features self.edge_features = edge_features self.clustering_method = clustering_method @@ -769,7 +770,7 @@ def __init__( # noqa: MC0001, pylint: disable=too-many-arguments, too-many-local Please provide a valid training GraphDataset or the path to a valid DeepRank2 pre-trained model.""") #check inherited parameter with the ones in the training set - inherited_params = ["node_features", "edge_features", "features_transform", "target", "target_transform", "task", "classes"] + inherited_params = ["node_features", "edge_features", "features_transform", "target", "target_transform", "task", "classes", "classes_to_index"] self._check_inherited_params(inherited_params, data) elif train and train_data: diff --git a/deeprank2/trainer.py b/deeprank2/trainer.py index 536015029..ed7fdfb6a 100644 --- a/deeprank2/trainer.py +++ b/deeprank2/trainer.py @@ -880,6 +880,7 @@ def _load_params(self): self.target_transform = state["target_transform"] self.task = state["task"] self.classes = state["classes"] + self.classes_to_index = state["classes_to_index"] self.class_weights = state["class_weights"] self.batch_size_train = state["batch_size_train"] self.batch_size_test = state["batch_size_test"] @@ -917,6 +918,7 @@ def _save_model(self): "target_transform": self.target_transform, "task": self.task, "classes": self.classes, + "classes_to_index": self.classes_to_index, "class_weights": self.class_weights, "batch_size_train": self.batch_size_train, "batch_size_test": self.batch_size_test,