Skip to content

Commit

Permalink
add classes_to_index as inherited param and to the pre-trained model
Browse files Browse the repository at this point in the history
  • Loading branch information
gcroci2 committed Oct 23, 2023
1 parent 5280ece commit 7fcc033
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
7 changes: 4 additions & 3 deletions deeprank2/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ def __init__( # pylint: disable=too-many-arguments
for k, v in inspect.signature(self.__init__).parameters.items()
if v.default is not inspect.Parameter.empty
}

self.default_vars["classes_to_index"] = None
self.features = features
self.target_transform = target_transform
self._check_features()
Expand All @@ -510,7 +510,7 @@ def __init__( # pylint: disable=too-many-arguments
Please provide a valid training GridDataset or the path to a valid DeepRank2 pre-trained model.""")

#check inherited parameter with the ones in the training set
inherited_params = ["features", "target", "target_transform", "task", "classes"]
inherited_params = ["features", "target", "target_transform", "task", "classes", "classes_to_index"]
self._check_inherited_params(inherited_params, data)

elif train and train_data:
Expand Down Expand Up @@ -737,6 +737,7 @@ def __init__( # noqa: MC0001, pylint: disable=too-many-arguments, too-many-local
for k, v in inspect.signature(self.__init__).parameters.items()
if v.default is not inspect.Parameter.empty
}
self.default_vars["classes_to_index"] = None
self.node_features = node_features
self.edge_features = edge_features
self.clustering_method = clustering_method
Expand Down Expand Up @@ -769,7 +770,7 @@ def __init__( # noqa: MC0001, pylint: disable=too-many-arguments, too-many-local
Please provide a valid training GraphDataset or the path to a valid DeepRank2 pre-trained model.""")

#check inherited parameter with the ones in the training set
inherited_params = ["node_features", "edge_features", "features_transform", "target", "target_transform", "task", "classes"]
inherited_params = ["node_features", "edge_features", "features_transform", "target", "target_transform", "task", "classes", "classes_to_index"]
self._check_inherited_params(inherited_params, data)

elif train and train_data:
Expand Down
2 changes: 2 additions & 0 deletions deeprank2/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -880,6 +880,7 @@ def _load_params(self):
self.target_transform = state["target_transform"]
self.task = state["task"]
self.classes = state["classes"]
self.classes_to_index = state["classes_to_index"]
self.class_weights = state["class_weights"]
self.batch_size_train = state["batch_size_train"]
self.batch_size_test = state["batch_size_test"]
Expand Down Expand Up @@ -917,6 +918,7 @@ def _save_model(self):
"target_transform": self.target_transform,
"task": self.task,
"classes": self.classes,
"classes_to_index": self.classes_to_index,
"class_weights": self.class_weights,
"batch_size_train": self.batch_size_train,
"batch_size_test": self.batch_size_test,
Expand Down

0 comments on commit 7fcc033

Please sign in to comment.