Skip to content

Commit

Permalink
fix datasets for cases in which there is a target attribute but no target values are present in the hdf5 file/s
Browse files Browse the repository at this point in the history
  • Loading branch information
gcroci2 committed Oct 24, 2023
1 parent 060e6bf commit 1d3c0f5
Showing 1 changed file with 25 additions and 9 deletions.
34 changes: 25 additions & 9 deletions deeprank2/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,22 +614,36 @@ def load_one_grid(self, hdf5_path: str, entry_name: str) -> Data:
"""

feature_data = []
target_value = None

with h5py.File(hdf5_path, 'r') as hdf5_file:
entry_group = hdf5_file[entry_name]
grp = hdf5_file[entry_name]

mapped_features_group = entry_group[gridstorage.MAPPED_FEATURES]
mapped_features_group = grp[gridstorage.MAPPED_FEATURES]
for feature_name in self.features:
if feature_name[0] != '_': # ignore metafeatures
feature_data.append(mapped_features_group[feature_name][:])
x=torch.tensor(np.expand_dims(np.array(feature_data), axis=0), dtype=torch.float)

target_value = entry_group[targets.VALUES][self.target][()]
# target
if self.target is None:
y = None
else:
if targets.VALUES in grp and self.target in grp[targets.VALUES]:
y = torch.tensor([grp[targets.VALUES][self.target][()]], dtype=torch.float)

# Wrap up the data in this object, for the collate_fn to handle it properly:
data = Data(x=torch.tensor(np.expand_dims(np.array(feature_data), axis=0), dtype=torch.float),
y=torch.tensor([target_value], dtype=torch.float))
if self.task == targets.REGRESS and self.target_transform is True:
y = torch.sigmoid(torch.log(y))
elif self.task is not targets.REGRESS and self.target_transform is True:
raise ValueError(f"Task is set to {self.task}. Please set it to regress to transform the target with a sigmoid.")
else:
y = None
possible_targets = grp[targets.VALUES].keys()
if self.train:
raise ValueError(f"Target {self.target} missing in entry {entry_name} in file {hdf5_path}, possible targets are {possible_targets}." +
"\n Use the query class to add more target values to input data.")

# Wrap up the data in this object, for the collate_fn to handle it properly:
data = Data(x=x, y=y)
data.entry_names = entry_name

return data
Expand Down Expand Up @@ -948,9 +962,11 @@ def load_one_graph(self, fname: str, entry_name: str) -> Data: # pylint: disabl
raise ValueError(f"Task is set to {self.task}. Please set it to regress to transform the target with a sigmoid.")

else:
y = None
possible_targets = grp[targets.VALUES].keys()
raise ValueError(f"Target {self.target} missing in entry {entry_name} in file {fname}, possible targets are {possible_targets}." +
"\n Use the query class to add more target values to input data.")
if self.train:
raise ValueError(f"Target {self.target} missing in entry {entry_name} in file {fname}, possible targets are {possible_targets}." +
"\n Use the query class to add more target values to input data.")

# positions
pos = torch.tensor(grp[f"{Nfeat.NODE}/{Nfeat.POSITION}/"][()], dtype=torch.float).contiguous()
Expand Down

0 comments on commit 1d3c0f5

Please sign in to comment.