Skip to content

Commit

Permalink
allow to skip loading model weights in build_model()
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #630

Currently, in runner **build_model()** method, when **eval_only=True**, we always try to load model weights.
This is quite restricted in some cases. For example, we may just wanna build a model in eval mode to profile its efficiency, and we have not trained the model or generated the model weights in a checkpoint file.

Thus, this diff adds an argument **skip_model_weights** to allow users to skip the loading of model weights.
Note, this diff is entirely back-compatible and is NOT expected to break existing implementations.

Reviewed By: navsud, wat3rBro

Differential Revision: D50623772

fbshipit-source-id: 282dc6f19e17a4dd9eb0048e068c5299bb3d47c2
  • Loading branch information
stephenyan1231 authored and facebook-github-bot committed Nov 5, 2023
1 parent 2d4d2f2 commit f2a0c52
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion tools/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def main(
device: str = "cpu",
compare_accuracy: bool = False,
skip_if_fail: bool = False,
skip_model_weights: bool = False,
) -> ExporterOutput:
if compare_accuracy:
raise NotImplementedError(
Expand All @@ -54,12 +55,16 @@ def main(
# ret["accuracy_comparison"] = accuracy_comparison

cfg = copy.deepcopy(cfg)
with temp_defrost(cfg):
if skip_model_weights:
cfg.merge_from_list(["MODEL.WEIGHTS", ""])

runner = setup_after_launch(cfg, output_dir, runner_class)

with temp_defrost(cfg):
cfg.merge_from_list(["MODEL.DEVICE", device])
model = runner.build_model(cfg, eval_only=True)

model = runner.build_model(cfg, eval_only=True)
# NOTE: train dataset is used to avoid leakage since the data might be used for
# running calibration for quantization. test_loader is used to make sure it follows
# the inference behaviour (augmentation will not be applied).
Expand Down Expand Up @@ -112,6 +117,7 @@ def run_with_cmdline_args(args):
device=args.device,
compare_accuracy=args.compare_accuracy,
skip_if_fail=args.skip_if_fail,
skip_model_weights=args.skip_model_weights,
)


Expand Down Expand Up @@ -139,6 +145,8 @@ def get_parser():
help="If set, suppress the exception for failed exporting and continue to"
" export the next type of model",
)
parser.add_argument("--skip-model-weights", action="store_true")

return parser


Expand Down

0 comments on commit f2a0c52

Please sign in to comment.