Skip to content

Commit

Permalink
do not fuse model again for a QAT model
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #643

For a QAT model, it contains observers. After QAT training, those observers already contain updated statistics, such as min_val, max_val.

When we want to export FP32 QAT model for a sanity check, if we call **fuse_utils.fuse_model()** again (which is often already called when we build the QAT model before QAT training), it will remove statistics in the observers.

Reviewed By: wat3rBro

Differential Revision: D52152688

fbshipit-source-id: 08aa16f2aa72b3809e0ba2d346f1b806c0e6ede7
  • Loading branch information
stephenyan1231 authored and facebook-github-bot committed Dec 15, 2023
1 parent da53aa1 commit 8f13023
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions d2go/export/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,12 @@ def _convert_fp_model(
cfg: CfgNode, pytorch_model: nn.Module, data_loader: Iterable
) -> nn.Module:
"""Converts floating point predictor"""
pytorch_model = fuse_utils.fuse_model(pytorch_model)
logger.info(f"Fused Model:\n{pytorch_model}")
if fuse_utils.count_bn_exist(pytorch_model) > 0:
logger.warning("BN existed in pytorch model after fusing.")
if not isinstance(cfg, CfgNode) or (not cfg.QUANTIZATION.QAT.ENABLED):
# Do not fuse model again for QAT model since it will remove observer statistics (e.g. min_val, max_val)
pytorch_model = fuse_utils.fuse_model(pytorch_model)
logger.info(f"Fused Model:\n{pytorch_model}")
if fuse_utils.count_bn_exist(pytorch_model) > 0:
logger.warning("BN existed in pytorch model after fusing.")
return pytorch_model


Expand Down

0 comments on commit 8f13023

Please sign in to comment.