diff --git a/d2go/runner/default_runner.py b/d2go/runner/default_runner.py index a5215baf..f398d693 100644 --- a/d2go/runner/default_runner.py +++ b/d2go/runner/default_runner.py @@ -572,7 +572,18 @@ def do_train(self, cfg, model, resume): # The checkpoint stores the training iteration that just finished, thus we start # at the next iteration (or iter zero if there's no checkpoint). start_iter += 1 - max_iter = cfg.SOLVER.MAX_ITER + + if "EARLY_STOPPING_FRACTION" in cfg.SOLVER: + assert ( + cfg.SOLVER.EARLY_STOPPING_FRACTION >= 0 + ), f"Early stopping fraction must be non-negative, but is {cfg.SOLVER.EARLY_STOPPING_FRACTION}" + assert ( + cfg.SOLVER.EARLY_STOPPING_FRACTION <= 1 + ), f"Early stopping fraction must not be larger than 1, but is {cfg.SOLVER.EARLY_STOPPING_FRACTION}" + max_iter = int(cfg.SOLVER.MAX_ITER * cfg.SOLVER.EARLY_STOPPING_FRACTION) + else: + max_iter = cfg.SOLVER.MAX_ITER + periodic_checkpointer = PeriodicCheckpointer( checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter )