diff --git a/tools/exporter.py b/tools/exporter.py index f1dd5d10..53a8db4d 100644 --- a/tools/exporter.py +++ b/tools/exporter.py @@ -45,6 +45,7 @@ def main( device: str = "cpu", compare_accuracy: bool = False, skip_if_fail: bool = False, + skip_model_weights: bool = False, ) -> ExporterOutput: if compare_accuracy: raise NotImplementedError( @@ -54,12 +55,16 @@ def main( # ret["accuracy_comparison"] = accuracy_comparison cfg = copy.deepcopy(cfg) + with temp_defrost(cfg): + if skip_model_weights: + cfg.merge_from_list(["MODEL.WEIGHTS", ""]) + runner = setup_after_launch(cfg, output_dir, runner_class) with temp_defrost(cfg): cfg.merge_from_list(["MODEL.DEVICE", device]) - model = runner.build_model(cfg, eval_only=True) + model = runner.build_model(cfg, eval_only=True) # NOTE: train dataset is used to avoid leakage since the data might be used for # running calibration for quantization. test_loader is used to make sure it follows # the inference behaviour (augmentation will not be applied). @@ -112,6 +117,7 @@ def run_with_cmdline_args(args): device=args.device, compare_accuracy=args.compare_accuracy, skip_if_fail=args.skip_if_fail, + skip_model_weights=args.skip_model_weights, ) @@ -139,6 +145,8 @@ def get_parser(): help="If set, suppress the exception for failed exporting and continue to" " export the next type of model", ) + parser.add_argument("--skip-model-weights", action="store_true") + return parser