Skip to content

Commit

Permalink
fix Trainer _eval method for cases in which there is a target attribu…
Browse files Browse the repository at this point in the history
…te but no target values are present in the hdf5 file/s
  • Loading branch information
gcroci2 committed Oct 24, 2023
1 parent 1d3c0f5 commit 4d588e1
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion deeprank2/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,6 +748,9 @@ def _eval( # pylint: disable=too-many-locals
loss_ = loss_func(pred, y)
count_predictions += pred.shape[0]
sum_of_losses += loss_.detach().item() * pred.shape[0]
else:
target_vals = ['None'] * pred.shape[0]
eval_loss = 'None'

# Get the outputs for export
# Remember that non-linear activation is automatically applied in CrossEntropyLoss
Expand All @@ -764,7 +767,7 @@ def _eval( # pylint: disable=too-many-locals
if count_predictions > 0:
eval_loss = sum_of_losses / count_predictions
else:
eval_loss = 0.0
eval_loss = 'None'

self._output_exporters.process(
pass_name, epoch_number, entry_names, outputs, target_vals, eval_loss)
Expand Down

0 comments on commit 4d588e1

Please sign in to comment.