From b49b8f3b9da372fe5e5da0423e8d55619bd187b9 Mon Sep 17 00:00:00 2001 From: Zhicheng Yan Date: Tue, 31 Oct 2023 19:11:02 -0700 Subject: [PATCH] allow to skip loading model weights in build_model() Summary: Currently, in runner **build_model()** method, when **eval_only=True**, we always try to load model weights. This is quite restricted in some cases. For example, we may just wanna build a model in eval mode to profile its efficiency, and we have not trained the model or generated the model weights in a checkpoint file. Thus, this diff adds an argument **skip_model_weights** to allow users to skip the loading of model weights. Note, this diff is entirely back-compatible and is NOT expected to break existing implementations. Reviewed By: wat3rBro Differential Revision: D50623772 --- tools/exporter.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tools/exporter.py b/tools/exporter.py index f1dd5d10..53a8db4d 100644 --- a/tools/exporter.py +++ b/tools/exporter.py @@ -45,6 +45,7 @@ def main( device: str = "cpu", compare_accuracy: bool = False, skip_if_fail: bool = False, + skip_model_weights: bool = False, ) -> ExporterOutput: if compare_accuracy: raise NotImplementedError( @@ -54,12 +55,16 @@ def main( # ret["accuracy_comparison"] = accuracy_comparison cfg = copy.deepcopy(cfg) + with temp_defrost(cfg): + if skip_model_weights: + cfg.merge_from_list(["MODEL.WEIGHTS", ""]) + runner = setup_after_launch(cfg, output_dir, runner_class) with temp_defrost(cfg): cfg.merge_from_list(["MODEL.DEVICE", device]) - model = runner.build_model(cfg, eval_only=True) + model = runner.build_model(cfg, eval_only=True) # NOTE: train dataset is used to avoid leakage since the data might be used for # running calibration for quantization. test_loader is used to make sure it follows # the inference behaviour (augmentation will not be applied). @@ -112,6 +117,7 @@ def run_with_cmdline_args(args): device=args.device, compare_accuracy=args.compare_accuracy, skip_if_fail=args.skip_if_fail, + skip_model_weights=args.skip_model_weights, ) @@ -139,6 +145,8 @@ def get_parser(): help="If set, suppress the exception for failed exporting and continue to" " export the next type of model", ) + parser.add_argument("--skip-model-weights", action="store_true") + return parser