diff --git a/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/nn/__init__.py b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/nn/__init__.py
index c11980c8a44..5b7cbace936 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/nn/__init__.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/nn/__init__.py
@@ -36,4 +36,23 @@
 # =============================================================================
 # pylint: disable=missing-docstring
 
+import contextlib
+import torch
 from .fake_quant import *
+from .quant_base import BaseQuantizationMixin
+
+
+@contextlib.contextmanager
+def compute_encodings(model: torch.nn.Module):
+    """
+    Compute encodings of all quantized modules in the model
+    """
+    # NOTE: model.modules() yields the root module itself first, so a
+    # quantized top-level module is already covered by this loop.
+    with contextlib.ExitStack() as stack:
+        for module in model.modules():
+            if isinstance(module, BaseQuantizationMixin):
+                ctx = module.compute_encodings()
+                stack.enter_context(ctx)
+
+        yield
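For orientation, a minimal usage sketch of the context manager added above. It mirrors the test changes later in this patch; the quantizer configuration (a QuantizeDequantize with a MinMaxEncodingAnalyzer) is one possible setup taken from those tests, not the only one, and it assumes nn.Linear has a registered fake-quantized counterpart as the tests assume for conv/linear layers:

    import torch
    from aimet_torch.experimental.v2 import nn as aimet_nn
    from aimet_torch.experimental.v2.nn.fake_quant import FakeQuantizationMixin
    from aimet_torch.experimental.v2.quantization.modules.quantize import QuantizeDequantize
    from aimet_torch.experimental.v2.quantization.encoding_analyzer import MinMaxEncodingAnalyzer

    # Wrap a float module into its registered fake-quantized counterpart
    quantized_linear = FakeQuantizationMixin.from_module(torch.nn.Linear(10, 10))
    quantized_linear.output_quantizers[0] = QuantizeDequantize(
        (1,), bitwidth=8, symmetric=False,
        encoding_analyzer=MinMaxEncodingAnalyzer((1,)))

    # Quantizers observe calibration statistics while the context is active
    # and compute/freeze their encodings on exit.
    with aimet_nn.compute_encodings(quantized_linear):
        _ = quantized_linear(torch.randn(1, 10))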
""" - # Mapping from a base module class to quantized module class - quantized_classes_map = OrderedDict() + cls_to_qcls = OrderedDict() # ouantized class -> original class + qcls_to_cls = OrderedDict() # original class -> quantized class + + def export_input_encodings(self) -> List[List[Dict]]: + """ + Returns a list of input encodings, each represented as a List of Dicts + """ + return [ + quantizer.get_encodings() if isinstance(quantizer, _QuantizerBase) else None + for quantizer in _flatten_nn_module_list(self.input_quantizers) + ] + + def export_output_encodings(self) -> List[List[Dict]]: + """ + Returns a list of output encodings, each represented as a List of Dicts + """ + return [ + quantizer.get_encodings() if isinstance(quantizer, _QuantizerBase) else None + for quantizer in _flatten_nn_module_list(self.output_quantizers) + ] + + def export_param_encodings(self) -> Dict[str, List[Dict]]: + """ + Returns a dict of {param name: param encodings}, with each encoding represented as a List of Dicts + """ + return { + param_name: quantizer.get_encodings() if isinstance(quantizer, _QuantizerBase) else None + for param_name, quantizer in self.param_quantizers.items() + } + + def get_original_module(self) -> nn.Module: + """ + Returns the floating point version of quantized module + """ + # pylint: disable=protected-access + + qtzn_module_cls = type(self) + orig_module_cls = self.qcls_to_cls.get(qtzn_module_cls) + + orig_module = self.__new__(orig_module_cls) + orig_module.__dict__ = self.__dict__.copy() + del orig_module.__dict__['_forward'] + del orig_module.__dict__['forward'] + + orig_module._parameters = self._parameters.copy() + orig_module._buffers = self._buffers.copy() + orig_module._modules = self._modules.copy() + del orig_module._modules['input_quantizers'] + del orig_module._modules['output_quantizers'] + del orig_module._modules['param_quantizers'] + + return orig_module @classmethod def wrap(cls, module_cls: Type[nn.Module]) -> Type[nn.Module]: @@ -64,8 +130,8 @@ def wrap(cls, module_cls: Type[nn.Module]) -> Type[nn.Module]: if not issubclass(module_cls, nn.Module): raise ValueError("Expected module_cls to be a subclass of torch.nn.Module. " f"Got {module_cls}.") - if module_cls in cls.quantized_classes_map: - return cls.quantized_classes_map[module_cls] + if module_cls in cls.cls_to_qcls: + return cls.cls_to_qcls[module_cls] quantized_cls_name = f"FakeQuantized{module_cls.__name__}" base_classes = (cls, module_cls) @@ -78,7 +144,8 @@ def implements(cls, module_cls): Decorator for registering fake-quantized implementation of the given base class. 
""" def wrapper(quantized_cls): - cls.quantized_classes_map[module_cls] = quantized_cls + cls.cls_to_qcls[module_cls] = quantized_cls + cls.qcls_to_cls[quantized_cls] = module_cls return quantized_cls return wrapper diff --git a/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/nn/quant_base.py b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/nn/quant_base.py index 99d0d897b5d..8e1a0ca4524 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/nn/quant_base.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/nn/quant_base.py @@ -137,7 +137,7 @@ def from_module(cls, module: nn.Module): """ # pylint: disable=protected-access module_cls = type(module) - qtzn_module_cls = cls.quantized_classes_map.get(module_cls, None) + qtzn_module_cls = cls.cls_to_qcls.get(module_cls, None) if not qtzn_module_cls: raise RuntimeError( diff --git a/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/modules/quantize.py b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/modules/quantize.py index 9b60ff6c9a9..02484f9424c 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/modules/quantize.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/modules/quantize.py @@ -38,7 +38,7 @@ """ nn.Modules for quantization operators """ import copy -from typing import Optional, Tuple +from typing import Optional, Tuple, List, Dict import contextlib from collections import OrderedDict import functools @@ -226,6 +226,31 @@ def get_offset(self) -> Optional[torch.Tensor]: return offset + @torch.no_grad() + def get_encodings(self) -> Optional[List[Dict]]: + """ + Returns a list of encodings, each represented as a List of Dicts + """ + # pylint: disable=redefined-builtin + + if not self.is_initialized(): + return None + + min = self.get_min().flatten() + max = self.get_max().flatten() + scale = self.get_scale().flatten() + offset = self.get_offset().flatten() + bitwidth = self.bitwidth + dtype = "int" + is_symmetric = self.symmetric + + return [ + {'min': float(min_), 'max': float(max_), + 'scale': float(scale_), 'offset': float(offset_), + 'bitwidth': bitwidth, 'dtype': dtype, 'is_symmetric': str(is_symmetric)} + for min_, max_, scale_, offset_ in zip(min, max, scale, offset) + ] + @contextlib.contextmanager def compute_encodings(self): """ diff --git a/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/quantsim.py b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/quantsim.py index 9f2aeaac732..142d8f7733a 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/quantsim.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/quantsim.py @@ -36,11 +36,10 @@ # ============================================================================= """ Top level API for performing quantization simulation of a pytorch model """ -import contextlib - import torch from aimet_torch.quantsim import QuantizationSimModel as V1QuantizationSimModel +from aimet_torch.experimental.v2 import nn as aimet_nn from aimet_torch.experimental.v2.nn.fake_quant import FakeQuantizationMixin from aimet_torch.experimental.v2.quantization.wrappers.builder import LazyQuantizeWrapper from aimet_torch import utils @@ -78,10 +77,5 @@ def compute_encodings(self, forward_pass_callback, forward_pass_callback_args): """ # Run forward iterations so we can collect 
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/quantsim.py b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/quantsim.py
index 9f2aeaac732..142d8f7733a 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/quantsim.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/quantsim.py
@@ -36,11 +36,10 @@
 # =============================================================================
 """ Top level API for performing quantization simulation of a pytorch model """
 
-import contextlib
-
 import torch
 
 from aimet_torch.quantsim import QuantizationSimModel as V1QuantizationSimModel
+from aimet_torch.experimental.v2 import nn as aimet_nn
 from aimet_torch.experimental.v2.nn.fake_quant import FakeQuantizationMixin
 from aimet_torch.experimental.v2.quantization.wrappers.builder import LazyQuantizeWrapper
 from aimet_torch import utils
@@ -78,10 +77,5 @@ def compute_encodings(self, forward_pass_callback, forward_pass_callback_args):
         """
         # Run forward iterations so we can collect statistics to compute the appropriate encodings
         with utils.in_eval_mode(self.model), torch.no_grad():
-            with contextlib.ExitStack() as stack:
-                for module in self.model.modules():
-                    if not isinstance(module, FakeQuantizationMixin):
-                        continue
-                    stack.enter_context(module.compute_encodings())
-
+            with aimet_nn.compute_encodings(self.model):
                 _ = forward_pass_callback(self.model, forward_pass_callback_args)
diff --git a/TrainingExtensions/torch/test/python/experimental/v2/nn/test_custom_op.py b/TrainingExtensions/torch/test/python/experimental/v2/nn/test_custom_op.py
index da70f05eaac..dc8b9aedafd 100644
--- a/TrainingExtensions/torch/test/python/experimental/v2/nn/test_custom_op.py
+++ b/TrainingExtensions/torch/test/python/experimental/v2/nn/test_custom_op.py
@@ -67,7 +67,7 @@ def quantized_forward(self, x):
 
         finally:
             # Unregister CustomOp so as not to affect other test functions
-            FakeQuantizationMixin.quantized_classes_map.pop(CustomOp)
+            FakeQuantizationMixin.cls_to_qcls.pop(CustomOp)
 
     def test_custom_op_wrap_registered(self):
         try:
@@ -85,4 +85,4 @@ def quantized_forward(self, x):
 
         finally:
             # Unregister CustomOp so as not to affect other test functions
-            FakeQuantizationMixin.quantized_classes_map.pop(CustomOp)
+            FakeQuantizationMixin.cls_to_qcls.pop(CustomOp)
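The three export hooks defined in fake_quant.py are what the v1 export path consumes. A minimal sketch of their output shapes for a module quantized the way the tests below do it; the `[None]` results assume the default quantizer slots are left unset, which is how `from_module` creates them in these tests:

    import torch
    from aimet_torch.experimental.v2.nn.fake_quant import FakeQuantizationMixin
    from aimet_torch.experimental.v2.quantization.modules.quantize import QuantizeDequantize
    from aimet_torch.experimental.v2.quantization.encoding_analyzer import MinMaxEncodingAnalyzer

    qconv = FakeQuantizationMixin.from_module(torch.nn.Conv2d(3, 8, 3))
    qconv.param_quantizers['weight'] = QuantizeDequantize(
        (1,), bitwidth=4, symmetric=True,
        encoding_analyzer=MinMaxEncodingAnalyzer((1,)))

    with qconv.compute_encodings():
        _ = qconv(torch.randn(1, 3, 16, 16))

    assert qconv.export_input_encodings() == [None]   # no input quantizer attached
    assert qconv.export_output_encodings() == [None]  # no output quantizer attached
    weight_enc = qconv.export_param_encodings()['weight']  # one Dict per encoding element
    assert weight_enc[0]['bitwidth'] == 4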
diff --git a/TrainingExtensions/torch/test/python/experimental/v2/test_quantsim_v1_export.py b/TrainingExtensions/torch/test/python/experimental/v2/test_quantsim_v1_export.py
index 95c714293f4..10149b75c40 100644
--- a/TrainingExtensions/torch/test/python/experimental/v2/test_quantsim_v1_export.py
+++ b/TrainingExtensions/torch/test/python/experimental/v2/test_quantsim_v1_export.py
@@ -36,71 +36,27 @@
 # =============================================================================
 
 import tempfile
-import pytest
 import torch.nn
 import copy
 import os
 import json
 
-from aimet_torch.experimental.v2.quantization.wrappers.quantization_mixin import _QuantizationMixin
+import aimet_torch.experimental.v2.nn as aimet_nn
+from aimet_torch.experimental.v2.nn.fake_quant import FakeQuantizationMixin
+from aimet_torch.experimental.v2.quantization.modules.quantize import QuantizeDequantize
+from aimet_torch.experimental.v2.quantization.encoding_analyzer import MinMaxEncodingAnalyzer
 from aimet_torch.elementwise_ops import Add
 from aimet_torch import onnx_utils
 from aimet_torch.quantsim import QuantizationSimModel, OnnxExportApiArgs
-from models_.models_to_test import SimpleConditional, ModelWithTwoInputs, ModelWith5Output, SoftMaxAvgPoolModel
-
-# Key/values don't matter
-dummy_encoding = {"min": 0,
-                  "max": 2,
-                  "scale": 2/255,
-                  "offset": 0,
-                  "bitwidth": 8,
-                  "dtype": "int",
-                  "is_symmetric": "False"}
-
-
-class DummyMixin(_QuantizationMixin, torch.nn.Module):
-    """ Dummy class for testing QuantSim export logic """
-
-    def __init__(self, module, num_inputs, num_outputs, has_input_encodings, has_output_encodings):
-        super(DummyMixin, self).__init__()
-        # Assign a dummy output quantizer (since a real mixin will have child quantizers)
-        self.output_quantizer = torch.nn.Identity()
-        # Hide module inside list so it doesnt show up as a child (We will not actually have a wrapped module)
-        self.module = [copy.deepcopy(module)]
-        self._parameters = self.module[0]._parameters
-        self.num_inputs = num_inputs
-        self.num_outputs = num_outputs
-        self.has_input_encodings = has_input_encodings
-        self.has_output_encodings = has_output_encodings
-        self.dummy_encoding = copy.deepcopy(dummy_encoding)
-
-    @classmethod
-    def from_module(cls, module: torch.nn.Module, num_inputs=1, num_outputs=1, has_input_encodings=False, has_output_encodings=True):
-        return cls(module, num_inputs, num_outputs, has_input_encodings, has_output_encodings)
-
-    def forward(self, *inputs):
-        return self.output_quantizer(self.module[0](*inputs))
-
-    def export_input_encodings(self):
-        enc = [self.dummy_encoding] if self.has_input_encodings else None
-        return [enc] * self.num_inputs
-
-    def export_output_encodings(self):
-        enc = [self.dummy_encoding] if self.has_output_encodings else None
-        return [enc] * self.num_outputs
-
-    def export_param_encodings(self):
-        enc_dict = {}
-        for name, param in self.module[0].named_parameters():
-            if name == "weight":
-                enc_dict[name] = [self.dummy_encoding] * param.shape[0]
-            else:
-                enc_dict[name] = None
-        return enc_dict
-
-    def get_original_module(self):
-        return copy.deepcopy(self.module[0])
+from models_.models_to_test import (
+    SimpleConditional,
+    ModelWithTwoInputs,
+    ModelWith5Output,
+    ModuleWith5Output,
+    SoftMaxAvgPoolModel,
+)
 
 
 class DummyModel(torch.nn.Module):
@@ -131,10 +87,32 @@ def test_onnx_export(self):
         model = DummyModel(in_channels=input_shape[1])
         sim_model = copy.deepcopy(model)
         for name, module in sim_model.named_children():
-            has_input_encodings = False if name != "conv1" else True
-            has_output_encodings = True if name != "conv1" else False
-            num_inputs = 2 if name == "add" else 1
-            sim_model.__setattr__(name, DummyMixin.from_module(module, num_inputs, 1, has_input_encodings, has_output_encodings))
+            quantized_module = FakeQuantizationMixin.from_module(module)
+
+            if name == "conv1":
+                input_quantizer = QuantizeDequantize((1,),
+                                                     bitwidth=8,
+                                                     symmetric=False,
+                                                     encoding_analyzer=MinMaxEncodingAnalyzer((1,)))
+                quantized_module.input_quantizers[0] = input_quantizer
+            else:
+                output_quantizer = QuantizeDequantize((1,),
+                                                      bitwidth=8,
+                                                      symmetric=False,
+                                                      encoding_analyzer=MinMaxEncodingAnalyzer((1,)))
+                quantized_module.output_quantizers[0] = output_quantizer
+
+            if hasattr(module, 'weight'):
+                weight_quantizer = QuantizeDequantize((1,),
+                                                      bitwidth=4,
+                                                      symmetric=True,
+                                                      encoding_analyzer=MinMaxEncodingAnalyzer((1,)))
+                quantized_module.param_quantizers['weight'] = weight_quantizer
+
+            setattr(sim_model, name, quantized_module)
+
+        with aimet_nn.compute_encodings(sim_model):
+            _ = sim_model(torch.randn(input_shape))
 
         with tempfile.TemporaryDirectory() as path:
@@ -158,20 +136,38 @@ def test_onnx_export(self):
             assert set(encoding_dict["activation_encodings"].keys()) == expected_act_keys
             assert set(encoding_dict["param_encodings"].keys()) == expected_param_keys
 
-            for encoding in encoding_dict["activation_encodings"].values():
-                assert encoding[0] == dummy_encoding
-
-
     # From: https://github.com/quic/aimet/blob/ce3dafe75d81893cdb8b45ba8abf53a672c28187/TrainingExtensions/torch/test/python/test_quantizer.py#L2731
     def test_export_to_onnx_direct(self):
         model = ModelWithTwoInputs()
         sim_model = copy.deepcopy(model)
         dummy_input = (torch.rand(1, 1, 28, 28), torch.rand(1, 1, 28, 28))
-        for name, layer in sim_model.named_children():
-            has_input_encodings = name == "conv1_a"
-            wrapped_layer = DummyMixin.from_module(layer, has_input_encodings=has_input_encodings)
-            setattr(sim_model, name, wrapped_layer)
+        for name, module in sim_model.named_children():
+            quantized_module = FakeQuantizationMixin.from_module(module)
+
+            if name == "conv1_a":
+                input_quantizer = QuantizeDequantize((1,),
+                                                     bitwidth=8,
+                                                     symmetric=False,
+                                                     encoding_analyzer=MinMaxEncodingAnalyzer((1,)))
+                quantized_module.input_quantizers[0] = input_quantizer
+
+            output_quantizer = QuantizeDequantize((1,),
+                                                  bitwidth=8,
+                                                  symmetric=False,
+                                                  encoding_analyzer=MinMaxEncodingAnalyzer((1,)))
+            quantized_module.output_quantizers[0] = output_quantizer
+
+            if hasattr(module, 'weight'):
+                weight_quantizer = QuantizeDequantize((1,),
+                                                      bitwidth=4,
+                                                      symmetric=True,
+                                                      encoding_analyzer=MinMaxEncodingAnalyzer((1,)))
+                quantized_module.param_quantizers['weight'] = weight_quantizer
+
+            setattr(sim_model, name, quantized_module)
+
+        with aimet_nn.compute_encodings(sim_model):
+            _ = sim_model(*dummy_input)
 
         with tempfile.TemporaryDirectory() as temp_dir:
             onnx_utils.EXPORT_TO_ONNX_DIRECT = True
@@ -202,12 +198,24 @@ def test_encodings_propagation(self):
         one onnx node maps to the same torch module
         """
         export_args = OnnxExportApiArgs(opset_version=10, input_names=["input"], output_names=["output"])
-        pixel_shuffel = torch.nn.PixelShuffle(2)
-        model = torch.nn.Sequential(pixel_shuffel)
-        sim_model = torch.nn.Sequential(DummyMixin.from_module(pixel_shuffel, num_inputs=1, has_input_encodings=True,
-                                                               has_output_encodings=True))
+        pixel_shuffle = torch.nn.PixelShuffle(2)
+        model = torch.nn.Sequential(pixel_shuffle)
+
+        quantized_pixel_shuffle = FakeQuantizationMixin.from_module(pixel_shuffle)
+        quantized_pixel_shuffle.input_quantizers[0] = QuantizeDequantize((1,),
+                                                                         bitwidth=8,
+                                                                         symmetric=False,
+                                                                         encoding_analyzer=MinMaxEncodingAnalyzer((1,)))
+        quantized_pixel_shuffle.output_quantizers[0] = QuantizeDequantize((1,),
+                                                                          bitwidth=8,
+                                                                          symmetric=False,
+                                                                          encoding_analyzer=MinMaxEncodingAnalyzer((1,)))
+        sim_model = torch.nn.Sequential(quantized_pixel_shuffle)
         dummy_input = torch.randn(1, 4, 8, 8)
 
+        with aimet_nn.compute_encodings(sim_model):
+            _ = sim_model(dummy_input)
+
         # Save encodings
         with tempfile.TemporaryDirectory() as path:
             fname_no_prop = "encodings_propagation_false"
@@ -238,13 +246,36 @@ def test_multi_output_onnx_op(self):
         model = ModelWith5Output()
         dummy_input = torch.randn(1, 3, 224, 224)
         sim_model = copy.deepcopy(model)
-        class DummyMixinWithDisabledOutput(DummyMixin):
-            def export_output_encodings(self):
-                enc = [self.dummy_encoding]
-                return [None] + ([enc] * (self.num_outputs - 1))
 
-        sim_model.cust = DummyMixinWithDisabledOutput.from_module(sim_model.cust, num_inputs=1, num_outputs=5,
-                                                                  has_input_encodings=True, has_output_encodings=True)
+        @FakeQuantizationMixin.implements(ModuleWith5Output)
+        class FakeQuantizationMixinWithDisabledOutput(FakeQuantizationMixin, ModuleWith5Output):
+            def __quant_init__(self):
+                super().__quant_init__()
+                self.output_quantizers = torch.nn.ModuleList([None, None, None, None, None])
+
+            def quantized_forward(self, input):
+                if self.input_quantizers[0]:
+                    input = self.input_quantizers[0](input)
+                outputs = super().forward(input)
+                # The first output quantizer is left as None, so output '7' stays unquantized
+                return tuple(
+                    quantizer(out) if quantizer else out
+                    for out, quantizer in zip(outputs, self.output_quantizers)
+                )
+
+        sim_model.cust = FakeQuantizationMixin.from_module(sim_model.cust)
+        sim_model.cust.input_quantizers[0] = QuantizeDequantize((1,),
+                                                                bitwidth=8,
+                                                                symmetric=False,
+                                                                encoding_analyzer=MinMaxEncodingAnalyzer((1,)))
+        for i in range(1, 5):
+            sim_model.cust.output_quantizers[i] = QuantizeDequantize((1,),
+                                                                     bitwidth=8,
+                                                                     symmetric=False,
+                                                                     encoding_analyzer=MinMaxEncodingAnalyzer((1,)))
+
+        with aimet_nn.compute_encodings(sim_model):
+            _ = sim_model(dummy_input)
 
         with tempfile.TemporaryDirectory() as path:
             QuantizationSimModel.export_onnx_model_and_encodings(path, 'module_with_5_output', model, sim_model,
@@ -255,8 +286,6 @@ def export_output_encodings(self):
                 activation_encodings = json.load(json_file)['activation_encodings']
                 assert '7' not in activation_encodings
                 assert set(['8', '9', '10', '11', 't.1']).issubset(activation_encodings.keys())
-                for item in activation_encodings.values():
-                    assert item[0] == sim_model.cust.dummy_encoding
 
     # From: https://github.com/quic/aimet/blob/ce3dafe75d81893cdb8b45ba8abf53a672c28187/TrainingExtensions/torch/test/python/test_quantizer.py#L1935
     def test_mapping_encoding_for_torch_module_with_multiple_onnx_ops(self):
@@ -268,8 +297,28 @@ def test_mapping_encoding_for_torch_module_with_multiple_onnx_ops(self):
 
         model = SoftMaxAvgPoolModel()
         sim_model = copy.deepcopy(model)
-        sim_model.sfmax = DummyMixin.from_module(sim_model.sfmax, 1, 1, True, True)
-        sim_model.avgpool = DummyMixin.from_module(sim_model.avgpool, 1, 1, True, True)
+        sim_model.sfmax = FakeQuantizationMixin.from_module(sim_model.sfmax)
+        sim_model.sfmax.input_quantizers[0] = QuantizeDequantize((1,),
+                                                                 bitwidth=8,
+                                                                 symmetric=False,
+                                                                 encoding_analyzer=MinMaxEncodingAnalyzer((1,)))
+        sim_model.sfmax.output_quantizers[0] = QuantizeDequantize((1,),
+                                                                  bitwidth=8,
+                                                                  symmetric=False,
+                                                                  encoding_analyzer=MinMaxEncodingAnalyzer((1,)))
+
+        sim_model.avgpool = FakeQuantizationMixin.from_module(sim_model.avgpool)
+        sim_model.avgpool.input_quantizers[0] = QuantizeDequantize((1,),
+                                                                   bitwidth=8,
+                                                                   symmetric=False,
+                                                                   encoding_analyzer=MinMaxEncodingAnalyzer((1,)))
+        sim_model.avgpool.output_quantizers[0] = QuantizeDequantize((1,),
+                                                                    bitwidth=8,
+                                                                    symmetric=False,
+                                                                    encoding_analyzer=MinMaxEncodingAnalyzer((1,)))
+
+        with aimet_nn.compute_encodings(sim_model):
+            _ = sim_model(dummy_input)
+
         with tempfile.TemporaryDirectory() as path:
             QuantizationSimModel.export_onnx_model_and_encodings(path, "sfmaxavgpool_model", model, sim_model,
                                                                  dummy_input, export_args, propagate_encodings=False)
@@ -278,7 +327,7 @@ def test_mapping_encoding_for_torch_module_with_multiple_onnx_ops(self):
 
             assert len(encoding_data["activation_encodings"]) == 3
 
-
+    @torch.no_grad()
     def test_conditional_export(self):
         """ Test exporting a model with conditional paths """
         model = SimpleConditional()
@@ -297,10 +346,29 @@ def forward_callback(model, _):
         qsim.compute_encodings(forward_callback, None)
 
         for name, module in sim_model.named_children():
+            quantized_module = FakeQuantizationMixin.from_module(module)
             qsim_module = getattr(qsim.model, name)
-            has_input_encodings = qsim_module.input_quantizers[0].enabled
-            has_output_encodings = qsim_module.output_quantizers[0].enabled
-            sim_model.__setattr__(name, DummyMixin.from_module(module, 1, 1, has_input_encodings, has_output_encodings))
+            if qsim_module.input_quantizers[0].enabled:
+                quantized_module.input_quantizers[0] = QuantizeDequantize((1,),
+                                                                          bitwidth=8,
+                                                                          symmetric=False,
+                                                                          encoding_analyzer=MinMaxEncodingAnalyzer((1,)))
+            if qsim_module.output_quantizers[0].enabled:
+                quantized_module.output_quantizers[0] = QuantizeDequantize((1,),
+                                                                           bitwidth=8,
+                                                                           symmetric=False,
+                                                                           encoding_analyzer=MinMaxEncodingAnalyzer((1,)))
+            for param_name, qsim_param_quantizer in qsim_module.param_quantizers.items():
+                if not qsim_param_quantizer.enabled:
+                    continue
+                quantized_module.param_quantizers[param_name] = QuantizeDequantize((1,),
+                                                                                   bitwidth=4,
+                                                                                   symmetric=True,
+                                                                                   encoding_analyzer=MinMaxEncodingAnalyzer((1,)))
+            setattr(sim_model, name, quantized_module)
+
+        with aimet_nn.compute_encodings(sim_model):
+            forward_callback(sim_model, None)
 
         qsim.model = sim_model
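Taken together, the tests above reduce to the following end-to-end sketch, mirroring test_encodings_propagation. The `.encodings` output file name is an assumption based on how these tests read the exported JSON:

    import json, os, tempfile
    import torch
    from aimet_torch.quantsim import QuantizationSimModel, OnnxExportApiArgs
    from aimet_torch.experimental.v2.nn.fake_quant import FakeQuantizationMixin
    from aimet_torch.experimental.v2.quantization.modules.quantize import QuantizeDequantize
    from aimet_torch.experimental.v2.quantization.encoding_analyzer import MinMaxEncodingAnalyzer
    import aimet_torch.experimental.v2.nn as aimet_nn

    pixel_shuffle = torch.nn.PixelShuffle(2)
    model = torch.nn.Sequential(pixel_shuffle)

    quantized = FakeQuantizationMixin.from_module(pixel_shuffle)
    quantized.output_quantizers[0] = QuantizeDequantize(
        (1,), bitwidth=8, symmetric=False,
        encoding_analyzer=MinMaxEncodingAnalyzer((1,)))
    sim_model = torch.nn.Sequential(quantized)

    # Calibrate, then export the float model together with the encodings
    # harvested from the fake-quantized model
    dummy_input = torch.randn(1, 4, 8, 8)
    with aimet_nn.compute_encodings(sim_model):
        _ = sim_model(dummy_input)

    with tempfile.TemporaryDirectory() as path:
        QuantizationSimModel.export_onnx_model_and_encodings(
            path, 'pixel_shuffle', model, sim_model, dummy_input,
            OnnxExportApiArgs(), propagate_encodings=False)
        with open(os.path.join(path, 'pixel_shuffle.encodings')) as f:  # assumed file name
            print(json.load(f)['activation_encodings'])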