Skip to content

Commit

Permalink
add callbacks for inference_on_dataset
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #637

Reviewed By: tglik

Differential Revision: D51540498
  • Loading branch information
Yanghan Wang authored and facebook-github-bot committed Nov 30, 2023
1 parent 87649f4 commit 18e92ee
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
5 changes: 3 additions & 2 deletions d2go/evaluation/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,22 +36,23 @@ def inference_on_dataset(
model: torch.nn.Module,
data_loader: Iterable,
evaluator: Union[DatasetEvaluator, List[DatasetEvaluator], None],
**kwargs,
):
"""
A drop-in replacement for d2's inference_on_dataset to run inference on datasets,
supports customization for checkpointing
* has_finished_process(self) -> bool: return True if `self.process()` could be skipped
"""
if evaluator is None:
return inference_on_dataset_d2(model, data_loader, evaluator)
return inference_on_dataset_d2(model, data_loader, evaluator, **kwargs)

if isinstance(evaluator, abc.MutableSequence):
evaluator = DatasetEvaluators(evaluator)

if not (
hasattr(evaluator, "has_finished_process") and evaluator.has_finished_process()
):
return inference_on_dataset_d2(model, data_loader, evaluator)
return inference_on_dataset_d2(model, data_loader, evaluator, **kwargs)

evaluator.reset()
results = evaluator.evaluate()
Expand Down
15 changes: 14 additions & 1 deletion d2go/runner/default_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,16 @@ def _create_evaluators(
)
return evaluator

# experimental API
@classmethod
def _get_inference_callbacks(cls):
return {
"on_start": lambda: None,
"on_end": lambda: None,
"before_inference": lambda: None,
"after_inference": lambda: None,
}

def _do_test(self, cfg, model, train_iter=None, model_tag="default"):
"""train_iter: Current iteration of the model, None means final iteration"""
assert len(cfg.DATASETS.TEST)
Expand Down Expand Up @@ -430,7 +440,10 @@ def _get_inference_dir_name(base_dir, inference_type, dataset_name):
else model,
)

results_per_dataset = inference_on_dataset(model, data_loader, evaluator)
inference_callbacks = self._get_inference_callbacks()
results_per_dataset = inference_on_dataset(
model, data_loader, evaluator, callbacks=inference_callbacks
)

if comm.is_main_process():
results[model_tag][dataset_name] = results_per_dataset
Expand Down

0 comments on commit 18e92ee

Please sign in to comment.