diff --git a/d2go/quantization/modeling.py b/d2go/quantization/modeling.py
index 1807a81d..ffde9809 100644
--- a/d2go/quantization/modeling.py
+++ b/d2go/quantization/modeling.py
@@ -22,7 +22,6 @@
 from mobile_cv.arch.utils import fuse_utils
 from mobile_cv.common.misc.iter_utils import recursive_iterate
 from torch import nn
-from torch._export import capture_pre_autograd_graph
 from torch.ao.quantization.quantize_pt2e import (
     convert_pt2e,
     prepare_pt2e,
@@ -32,6 +31,7 @@
     get_symmetric_quantization_config,
     XNNPACKQuantizer,
 )
+from torch.export import export_for_training
 
 TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
 # some tests still import prepare/convert from below. So don't remove these.
@@ -39,8 +39,7 @@
     from torch.ao.quantization.quantize import convert
     from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx
 else:
-    from torch.quantization.quantize import convert
-    from torch.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx
+    pass
 
 logger = logging.getLogger(__name__)
 
@@ -368,7 +367,7 @@ def prepare_fake_quant_model(cfg, model, is_qat, example_input=None):
         )
     else:
         logger.info("Using default pt2e quantization APIs with XNNPACKQuantizer")
-        captured_model = capture_pre_autograd_graph(model, example_input)
+        captured_model = export_for_training(model, example_input).module()
         quantizer = _get_symmetric_xnnpack_quantizer()
         if is_qat:
             model = prepare_qat_pt2e(captured_model, quantizer)
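
Note: for context, below is a minimal sketch of the pt2e quantization flow this diff
migrates to. The ToyModel, input shape, and calibration step are illustrative
assumptions, not part of d2go; the real entry point is prepare_fake_quant_model()
above, which builds the quantizer via d2go's internal
_get_symmetric_xnnpack_quantizer() helper rather than inline as shown here.

    import torch
    from torch import nn
    from torch.export import export_for_training
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
        get_symmetric_quantization_config,
        XNNPACKQuantizer,
    )

    class ToyModel(nn.Module):  # illustrative stand-in for a d2go model
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 8, 3)

        def forward(self, x):
            return self.conv(x)

    model = ToyModel().eval()
    example_input = (torch.randn(1, 3, 32, 32),)

    # export_for_training() returns an ExportedProgram; .module() recovers the
    # captured graph module that prepare_pt2e()/prepare_qat_pt2e() expect. This
    # is why the diff appends .module() to what used to be a one-step
    # capture_pre_autograd_graph() call.
    captured = export_for_training(model, example_input).module()

    quantizer = XNNPACKQuantizer()
    quantizer.set_global(get_symmetric_quantization_config(is_qat=False))

    prepared = prepare_pt2e(captured, quantizer)  # insert observers
    prepared(*example_input)                      # calibrate on example data
    quantized = convert_pt2e(prepared)            # rewrite to quantized ops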