Skip to content

Commit

Permalink
add callbacks for inference_on_dataset
Browse files Browse the repository at this point in the history
Differential Revision: D51540498
  • Loading branch information
Yanghan Wang authored and facebook-github-bot committed Nov 23, 2023
1 parent 87649f4 commit 12f28a1
Showing 1 changed file with 14 additions and 1 deletion.
15 changes: 14 additions & 1 deletion d2go/runner/default_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,16 @@ def _create_evaluators(
)
return evaluator

# experimental API
@classmethod
def _get_inference_callbacks(cls):
return {
"on_start": lambda: None,
"on_end": lambda: None,
"before_inference": lambda: None,
"after_inference": lambda: None,
}

def _do_test(self, cfg, model, train_iter=None, model_tag="default"):
"""train_iter: Current iteration of the model, None means final iteration"""
assert len(cfg.DATASETS.TEST)
Expand Down Expand Up @@ -430,7 +440,10 @@ def _get_inference_dir_name(base_dir, inference_type, dataset_name):
else model,
)

results_per_dataset = inference_on_dataset(model, data_loader, evaluator)
inference_callbacks = self._get_inference_callbacks()
results_per_dataset = inference_on_dataset(
model, data_loader, evaluator, callbacks=inference_callbacks
)

if comm.is_main_process():
results[model_tag][dataset_name] = results_per_dataset
Expand Down

0 comments on commit 12f28a1

Please sign in to comment.