
Commit 2bfa8e8

Convert modules to qmodules at the beginning of QuantizationSimModel.__init__

Signed-off-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
quic-kyunggeu committed Oct 14, 2024
1 parent 14812ea commit 2bfa8e8
Showing 7 changed files with 112 additions and 54 deletions.
@@ -1246,6 +1246,10 @@ def _is_recursive_parsing_needed(self, module: torch.nn.Module,
:param trace: torch.jit trace of the module
:return: Boolean whether recursive parsing needed or not. If needed returns True, False otherwise.
"""
from aimet_torch.v2.nn import BaseQuantizationMixin
if isinstance(module, BaseQuantizationMixin):
return self._is_recursive_parsing_needed(module.get_original_module(), trace)

recursive_parsing_needed = True
if is_torch_nn_leaf_module(module) or \
is_custom_leaf_module(module, self._find_aten_nodes_in_forward_pass(trace)) or \
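For context, a minimal sketch of what this unwrapping buys (get_original_module() is assumed to return an unquantized copy of the wrapped module, as its use above implies): the connected-graph parser makes the same recursion decision for a qmodule as it would for the float module it mirrors.

import torch
from aimet_torch.v2.nn import BaseQuantizationMixin, QuantizationMixin

qconv = QuantizationMixin.from_module(torch.nn.Conv2d(3, 8, kernel_size=3))
assert isinstance(qconv, BaseQuantizationMixin)
# The leaf/recursion check is now made against the unquantized copy rather than the qmodule itself
assert isinstance(qconv.get_original_module(), torch.nn.Conv2d)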
@@ -42,7 +42,7 @@

from aimet_common.defs import QuantScheme, QuantizationDataType, MAP_ROUND_MODE_TO_PYMO
from aimet_common.utils import AimetLogger, log_with_error_and_assert_if_false
from aimet_torch.utils import get_v1_quant_scheme_for_initialization
from aimet_torch.utils import get_v1_quant_scheme_for_initialization, is_leaf_module
from aimet_torch.qc_quantize_op import QcQuantizeOpMode, QcQuantizeWrapper, StaticGridQuantWrapper, tensor_quantizer_factory
from aimet_torch.tensor_quantizer import TensorQuantizer, StaticGridPerChannelQuantizer
import aimet_torch.fp_quantization as v1_fp_quantization
@@ -92,7 +92,19 @@ def __init__(self, module_to_wrap: torch.nn.Module, weight_bw: int, activation_b

# Create quantizer for each parameter and compute encodings
self.param_quantizers = {}
for name, param in module_to_wrap.named_parameters():

from aimet_torch.v2.nn import BaseQuantizationMixin
if isinstance(module_to_wrap, BaseQuantizationMixin):
# NOTE: An AIMET v2 qmodule only quantizes the parameters that it directly owns
recurse = False
else:
# NOTE: This is only for backwards-compatibility with the v1 quant wrapper,
# which in some edge cases tries to quantize not only the parameters it directly owns
# but also all the parameters of its submodules
assert is_leaf_module(module_to_wrap)
recurse = True

for name, param in module_to_wrap.named_parameters(recurse=recurse):
logger.debug("Adding quantizer for parameter: %s", name)
self.param_quantizers[name] = LazyParamQuantizer(weight_bw,
rounding_mode,
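The recurse flag is the crux of this hunk. A minimal, self-contained illustration of the difference (plain torch behaviour, nothing AIMET-specific):

import torch

block = torch.nn.Sequential(torch.nn.Linear(4, 4))
block.scale = torch.nn.Parameter(torch.ones(1))  # parameter owned directly by `block`

# recurse=False: only parameters registered on the module itself (the v2 qmodule case)
print([name for name, _ in block.named_parameters(recurse=False)])  # ['scale']

# recurse=True: parameters of all submodules as well (the v1 leaf-wrapper edge case)
print([name for name, _ in block.named_parameters(recurse=True)])   # ['scale', '0.weight', '0.bias']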
@@ -117,8 +129,7 @@ def enable_per_channel_quantization(self):
"""
Changes all parameter quantizers (if any) to per-channel mode.
"""
for param_name, _ in self._module_to_wrap.named_parameters():
param_quantizer = self.param_quantizers[param_name]
for param_name, param_quantizer in self.param_quantizers.items():
channel_axis = 0
if isinstance(self._module_to_wrap, (torch.nn.ConvTranspose1d,
torch.nn.ConvTranspose2d,
@@ -199,10 +210,8 @@ def realize_v2_wrapper(self):
from aimet_torch.v2.nn import QuantizationMixin
from aimet_torch.v2.nn.fake_quant import _legacy_impl

if type(self._module_to_wrap) in QuantizationMixin.cls_to_qcls: # pylint: disable=unidiomatic-typecheck
quantized_module = QuantizationMixin.from_module(self._module_to_wrap)
else:
quantized_module = _legacy_impl.FakeQuantizationMixin.from_module(self._module_to_wrap)
assert isinstance(self._module_to_wrap, (QuantizationMixin, _legacy_impl.FakeQuantizationMixin))
quantized_module = self._module_to_wrap

# For unused modules, quantsim assumes # inputs = # outputs = 1
# If this is incorrect, propagate the configuration of the last input/output quantizers to the remaining
2 changes: 1 addition & 1 deletion TrainingExtensions/torch/src/python/aimet_torch/utils.py
@@ -778,7 +778,7 @@ def record_tensor_shape(module, inputs, outputs):

inout_tensor_shape_map[module] = (input_tensor_shape_list, output_tensor_shape_list)

run_hook_for_layers_with_given_input(model, input_tensor, record_tensor_shape)
run_hook_for_layers_with_given_input(model, input_tensor, record_tensor_shape, leaf_node_only=False)
return inout_tensor_shape_map


@@ -315,16 +315,12 @@ def _realize_quant_wrappers_in_model(self, model: torch.nn.Module):
"""
for module_name, module_ref in model.named_children():
if isinstance(module_ref, LazyQuantizeWrapper):
quantized_module = self._realize_quant_wrapper(module_ref)
quantized_module = module_ref.realize_v1_wrapper()
setattr(model, module_name, quantized_module)

elif not utils.is_leaf_module(module_ref):
self._realize_quant_wrappers_in_model(module_ref)

@staticmethod
def _realize_quant_wrapper(module: torch.nn.Module) -> QcQuantizeWrapper:
return module.realize_v1_wrapper()

def get_supported_kernels(self) -> Dict:
"""
Return _supported_kernels parsed from the config file
19 changes: 19 additions & 0 deletions TrainingExtensions/torch/src/python/aimet_torch/v2/nn/base.py
@@ -209,8 +209,27 @@ def implements(cls, module_cls):
"""

def wrapper(quantized_cls):
# pylint: disable=import-outside-toplevel
cls.cls_to_qcls[module_cls] = quantized_cls
cls.qcls_to_cls[quantized_cls] = module_cls

# Update the mapping from torch module to onnx op
# so v1 connected graph and quantsim configurator can properly handle quantized modules.
from aimet_torch.onnx_utils import map_torch_types_to_onnx
onnx_type = map_torch_types_to_onnx.get(module_cls, None)
if onnx_type:
map_torch_types_to_onnx[quantized_cls] = onnx_type

# Update the mapping from torch module to backend op
# so v1 connected graph and quantsim configurator can properly handle quantized modules.
# TODO: This unfortunately relies on the **class name** of the module, not the real type
# of the module due to a limitation of the v1 implementation.
# Should redefine `aimet_op_to_backend_op_name_map` as `Dict[Type[Module], str]`
from aimet_torch.translation_mapping import aimet_op_to_backend_op_name_map
backend_op_name = aimet_op_to_backend_op_name_map.get(module_cls.__name__, None)
if backend_op_name:
aimet_op_to_backend_op_name_map[quantized_cls.__name__] = backend_op_name

return quantized_cls

return wrapper
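For context, the decorator is used roughly as below. The custom op is purely illustrative and not part of this diff; input_quantizers/output_quantizers come from BaseQuantizationMixin. With this change, if the float class already has an ONNX or backend-op mapping, registration mirrors it onto the quantized class (a custom op like this one simply has none).

import torch
from aimet_torch.v2.nn import QuantizationMixin

class Swish(torch.nn.Module):  # hypothetical float module
    def forward(self, x):
        return x * torch.sigmoid(x)

@QuantizationMixin.implements(Swish)  # records Swish -> QuantizedSwish in cls_to_qcls / qcls_to_cls
class QuantizedSwish(QuantizationMixin, Swish):
    def forward(self, x):
        if self.input_quantizers[0]:
            x = self.input_quantizers[0](x)
        out = super().forward(x)
        if self.output_quantizers[0]:
            out = self.output_quantizers[0](out)
        return out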
@@ -36,6 +36,7 @@
# =============================================================================
""" Top level API for performing quantization simulation of a pytorch model """

import copy
from typing import Union, Tuple, Optional
import warnings
import itertools
@@ -44,11 +45,10 @@
import torch

from aimet_common.defs import QuantScheme, QuantizationDataType
from aimet_torch.v1.quantsim import QuantizationSimModel as V1QuantizationSimModel, logger
import aimet_torch.v1.quantsim as quantsim_v1
from aimet_torch.v1.quantsim import QuantizationSimModel as V1QuantizationSimModel, logger, unquantizable_modules, quantized_modules
from aimet_torch.v2 import nn as aimet_nn
from aimet_torch.v2.nn import QuantizationMixin
from aimet_torch.v2.nn import BaseQuantizationMixin
from aimet_torch.v2.nn import BaseQuantizationMixin, QuantizationMixin
from aimet_torch.v2.nn.fake_quant import _legacy_impl
from aimet_torch.quantsim_config.builder import LazyQuantizeWrapper
from aimet_torch.v2.quantization.base import QuantizerBase
from aimet_torch.v2.quantization.affine import AffineQuantizerBase
@@ -59,11 +59,36 @@
from aimet_torch.v2.deepspeed_utils import _register_zero3_forward_hooks


qc_quantize_modules_dict = {
torch.nn.RNN: LazyQuantizeWrapper,
torch.nn.LSTM: LazyQuantizeWrapper,
torch.nn.GRU: LazyQuantizeWrapper,
}
unquantizable_modules = (QuantizerBase, *unquantizable_modules)
quantized_modules = (BaseQuantizationMixin, *quantized_modules)
containers = (
torch.nn.Container,
torch.nn.Sequential,
torch.nn.ModuleList,
torch.nn.ModuleDict,
torch.nn.ParameterList,
torch.nn.ParameterDict,
)


def _convert_to_qmodule(module: torch.nn.Module):
"""
Helper function to convert all modules to quantized aimet.nn modules.
"""
if not isinstance(module, (*quantized_modules, *unquantizable_modules, *containers)):
try:
module = QuantizationMixin.from_module(module)
except RuntimeError as e:
try:
module = _legacy_impl.FakeQuantizationMixin.from_module(module)
except RuntimeError:
if not tuple(module.children()):
raise e # pylint: disable=raise-missing-from

for name, child in module.named_children():
setattr(module, name, _convert_to_qmodule(child))

return module
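
A small usage sketch of _convert_to_qmodule (an internal helper, so treat this as an assumption rather than public API): containers are left untouched, supported leaves become QuantizationMixin qmodules, unsupported leaves fall back to the legacy FakeQuantizationMixin, and a leaf that matches neither re-raises the original RuntimeError.

import torch
from aimet_torch.v2.nn import BaseQuantizationMixin

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
qmodel = _convert_to_qmodule(model)

assert type(qmodel) is torch.nn.Sequential           # containers are skipped
assert isinstance(qmodel[0], BaseQuantizationMixin)  # leaves are converted in place
assert isinstance(qmodel[1], BaseQuantizationMixin)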


class QuantizationSimModel(V1QuantizationSimModel):
@@ -98,6 +123,19 @@ def __init__(self, # pylint: disable=too-many-arguments, too-many-locals
else:
raise TypeError("'rounding_mode' parameter is no longer supported.")

for module in model.modules():
if isinstance(module, BaseQuantizationMixin):
raise RuntimeError

if isinstance(module, QuantizerBase):
raise RuntimeError

if not in_place:
model = copy.deepcopy(model)
in_place = True

model = _convert_to_qmodule(model)

with _register_zero3_forward_hooks(model, use_dummy_params=True):
# NOTE: Registered in case the model is pre-partitioned by deepspeed zero3 or zero3-offload.
# Pre-partitioned models aren't runnable as-is, but need to be initialized
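
The effect of the new guard, sketched under the assumption of a plain float `model` and a matching `dummy_input`: a model can only be handed to QuantizationSimModel in its float form, and a model that already contains qmodules or quantizers is rejected up front.

import pytest

sim = QuantizationSimModel(model, dummy_input)    # float model: converted to qmodules internally
with pytest.raises(RuntimeError):
    QuantizationSimModel(sim.model, dummy_input)  # already contains qmodules/quantizers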
@@ -145,16 +183,6 @@ def __init__(self, # pylint: disable=too-many-arguments, too-many-locals
# Set quantization parameters to the device of the original module
module.to(device=device)

@staticmethod
def _realize_quant_wrapper(module: LazyQuantizeWrapper) -> BaseQuantizationMixin:
"""
Make wrapper builder into v2 quant wrapper
:param module: wrapper builder to realize
:return: realized v2 quant wrapper
"""
return module.realize_v2_wrapper()

def compute_encodings(self, forward_pass_callback, forward_pass_callback_args):
"""
Computes encodings for all quantization sim nodes in the model. It is also used to find initial encodings for
@@ -220,11 +248,6 @@ def concretize_block_size(qtzr, inp):
for h in handles:
h.remove()

def _create_quantizer_module(self, *args, **kwargs): # pylint: disable=arguments-differ
# RNN, LSTM, and GRU don't require special handling in aimet V2
with patch_attr(quantsim_v1, 'qc_quantize_modules_dict', qc_quantize_modules_dict):
return super()._create_quantizer_module(*args, **kwargs)

def set_percentile_value(self, percentile_value: float):
"""
Set the percentile value to be used while computing encodings
@@ -299,12 +322,20 @@ def _apply_qdq_to_model_parameters(cls, model: torch.nn.Module):
def quant_wrappers(self): # pylint: disable=missing-docstring
return super().quant_wrappers()

@classmethod
def _is_quantizable_module(cls, module: torch.nn.Module):
return super()._is_quantizable_module(module) and\
not isinstance(module, QuantizerBase)

@classmethod
def _is_quantized_module(cls, module: torch.nn.Module):
return super()._is_quantized_module(module) or\
isinstance(module, BaseQuantizationMixin)
# Overrides V1QuantizationSimModel._add_quantization_wrappers
def _add_quantization_wrappers(self, module, num_inout_tensors, default_data_type):
# pylint: disable=protected-access
for name, child in module.named_children():
if isinstance(child, BaseQuantizationMixin):
child_wrapper = self._create_quantizer_module(child, num_inout_tensors, default_data_type)
setattr(module, name, child_wrapper)
child = child_wrapper._module_to_wrap
self._add_quantization_wrappers(child, num_inout_tensors, default_data_type)

# Overrides V1QuantizationSimModel._realize_quant_wrappers_in_model
def _realize_quant_wrappers_in_model(self, model: torch.nn.Module):
for name, child in model.named_children():
if isinstance(child, LazyQuantizeWrapper):
child = child.realize_v2_wrapper()
setattr(model, name, child)
self._realize_quant_wrappers_in_model(child)
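
Taken together with the __init__ changes above, the assumed flow is now: __init__ converts every module to a qmodule, _add_quantization_wrappers wraps each qmodule in a LazyQuantizeWrapper, and _realize_quant_wrappers_in_model swaps every LazyQuantizeWrapper back for its realized v2 qmodule, so no lazy wrappers should remain after construction.

# Rough post-condition sketch (assumed, not asserted anywhere in this diff)
sim = QuantizationSimModel(model, dummy_input)
assert not any(isinstance(m, LazyQuantizeWrapper) for m in sim.model.modules())
assert any(isinstance(m, BaseQuantizationMixin) for m in sim.model.modules())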
@@ -656,10 +656,10 @@ def forward(self, *inputs):
return self.layers[2](inputs[0])

model = Net()
sim = QuantizationSimModel(model, dummy_input=torch.rand(1, 1, 12, 12),
quant_scheme=QuantScheme.post_training_tf)
with pytest.raises(RuntimeError):
sim = QuantizationSimModel(model, dummy_input=torch.rand(1, 1, 12, 12),
quant_scheme=QuantScheme.post_training_tf)

self.verify_quantization_wrappers(model, sim.model)

def test_add_quantization_wrappers_with_modulelist_two_deep(self):
"""With a two-deep model using ModuleList"""
@@ -691,10 +691,9 @@ def forward(self, *inputs):
return self.layers[2](inputs[0])

model = Net()
sim = QuantizationSimModel(model, dummy_input=torch.rand(1, 3, 12, 12),
quant_scheme=QuantScheme.post_training_tf)

self.verify_quantization_wrappers(model, sim.model)
with pytest.raises(RuntimeError):
sim = QuantizationSimModel(model, dummy_input=torch.rand(1, 3, 12, 12),
quant_scheme=QuantScheme.post_training_tf)

def test_add_quantization_wrappers_with_modulelist_with_layers_to_ignore(self):
"""With a two-deep model using ModuleList and layers_to_ignore"""