Finalize integration between V1 export and V2 fake-quantized modules #2686

Merged
@@ -36,4 +36,24 @@
# =============================================================================

# pylint: disable=missing-docstring
import contextlib
import torch
from .fake_quant import *


@contextlib.contextmanager
def compute_encodings(model: torch.nn.Module):
    """
    Compute encodings of all quantized modules in the model
    """
    with contextlib.ExitStack() as stack:
        for module in model.modules():
            if isinstance(module, BaseQuantizationMixin):
                ctx = module.compute_encodings()
                stack.enter_context(ctx)

        if isinstance(model, BaseQuantizationMixin):
            ctx = model.compute_encodings()
            stack.enter_context(ctx)

        yield
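For reference, a minimal usage sketch of this context manager, assuming a model whose layers are already fake-quantized; the model and data-loader names below are placeholders, not part of this change:

import torch
from aimet_torch.experimental.v2 import nn as aimet_nn

# Calibrate every quantizer in the model, then freeze the resulting encodings.
with aimet_nn.compute_encodings(quantized_model), torch.no_grad():
    for batch in calibration_loader:      # placeholder calibration data
        quantized_model(batch)            # quantizers observe activation statistics
# On exit, each quantizer derives its encodings from the observed statistics.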
@@ -37,24 +37,90 @@
"""Fake-quantized modules"""

from collections import OrderedDict
-from typing import Type, Optional, Tuple
+from typing import Type, Optional, Tuple, List, Dict

from torch import Tensor
import torch.nn as nn
from torch.nn.utils.rnn import PackedSequence
from torch.utils._pytree import tree_map

from aimet_torch.experimental.v2.nn.quant_base import BaseQuantizationMixin
from aimet_torch.experimental.v2.quantization.modules.quantize import _QuantizerBase
import aimet_torch.elementwise_ops as aimet_ops



def _flatten_nn_module_list(module):
    """
    Flatten nested list of nn.Modules into a flat list
    """
    def flat_iter(mod):
        if isinstance(mod, (list, tuple, nn.ModuleList)):
            for x in mod:
                yield from flat_iter(x)
        else:
            yield mod

    return list(flat_iter(module))
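# Illustrative aside (not part of the diff): the helper above flattens arbitrarily
# nested ModuleLists into a flat Python list, e.g. with placeholder modules:
nested = nn.ModuleList([nn.ModuleList([nn.ReLU(), nn.Sigmoid()]), nn.Tanh()])
flat = _flatten_nn_module_list(nested)
assert [type(m).__name__ for m in flat] == ['ReLU', 'Sigmoid', 'Tanh']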


class FakeQuantizationMixin(BaseQuantizationMixin): # pylint: disable=abstract-method
    """
    Mixin that implements fake-quantization on top of regular pytorch modules.
    """

-    # Mapping from a base module class to quantized module class
-    quantized_classes_map = OrderedDict()
+    cls_to_qcls = OrderedDict()  # original class -> quantized class
+    qcls_to_cls = OrderedDict()  # quantized class -> original class

    def export_input_encodings(self) -> List[List[Dict]]:
        """
        Returns a list of input encodings, each represented as a List of Dicts
        """
        return [
            quantizer.get_encodings() if isinstance(quantizer, _QuantizerBase) else None
            for quantizer in _flatten_nn_module_list(self.input_quantizers)
        ]

    def export_output_encodings(self) -> List[List[Dict]]:
        """
        Returns a list of output encodings, each represented as a List of Dicts
        """
        return [
            quantizer.get_encodings() if isinstance(quantizer, _QuantizerBase) else None
            for quantizer in _flatten_nn_module_list(self.output_quantizers)
        ]

    def export_param_encodings(self) -> Dict[str, List[Dict]]:
        """
        Returns a dict of {param name: param encodings}, with each encoding represented as a List of Dicts
        """
        return {
            param_name: quantizer.get_encodings() if isinstance(quantizer, _QuantizerBase) else None
            for param_name, quantizer in self.param_quantizers.items()
        }
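# Illustrative aside (not part of the diff): the three export_* methods above return
# plain Python structures; invented example values for a fake-quantized layer `qlinear`:
param_enc = qlinear.export_param_encodings()
# e.g. {'weight': [{'min': -0.5, 'max': 0.496, 'scale': 0.0039, 'offset': -128.0,
#                   'bitwidth': 8, 'dtype': 'int', 'is_symmetric': 'True'}],
#       'bias': None}            # disabled quantizer -> None
input_enc = qlinear.export_input_encodings()    # [[{...}]] -- one inner list per input, None if unquantized
output_enc = qlinear.export_output_encodings()  # [[{...}]] -- one inner list per output, None if unquantized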

    def get_original_module(self) -> nn.Module:
        """
        Returns the floating point version of quantized module
        """
        # pylint: disable=protected-access

        qtzn_module_cls = type(self)
        orig_module_cls = self.qcls_to_cls.get(qtzn_module_cls)

        orig_module = self.__new__(orig_module_cls)
        orig_module.__dict__ = self.__dict__.copy()
        del orig_module.__dict__['_forward']
        del orig_module.__dict__['forward']

        orig_module._parameters = self._parameters.copy()
        orig_module._buffers = self._buffers.copy()
        orig_module._modules = self._modules.copy()
        del orig_module._modules['input_quantizers']
        del orig_module._modules['output_quantizers']
        del orig_module._modules['param_quantizers']

        return orig_module
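# Illustrative aside (not part of the diff): a sketch of using the method above to
# peel the quantization wrappers off a layer. `qlinear` is a placeholder for a
# FakeQuantizedLinear instance.
orig = qlinear.get_original_module()
assert type(orig) is nn.Linear            # class looked up through qcls_to_cls
assert orig.weight is qlinear.weight      # parameters are shared, not cloned
# orig carries no input/output/param quantizers, so it runs a plain FP32 forward.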

    @classmethod
    def wrap(cls, module_cls: Type[nn.Module]) -> Type[nn.Module]:
@@ -64,8 +130,8 @@ def wrap(cls, module_cls: Type[nn.Module]) -> Type[nn.Module]:
        if not issubclass(module_cls, nn.Module):
            raise ValueError("Expected module_cls to be a subclass of torch.nn.Module. "
                             f"Got {module_cls}.")
-        if module_cls in cls.quantized_classes_map:
-            return cls.quantized_classes_map[module_cls]
+        if module_cls in cls.cls_to_qcls:
+            return cls.cls_to_qcls[module_cls]

quantized_cls_name = f"FakeQuantized{module_cls.__name__}"
base_classes = (cls, module_cls)
@@ -78,7 +144,8 @@ def implements(cls, module_cls):
        Decorator for registering fake-quantized implementation of the given base class.
        """
        def wrapper(quantized_cls):
-            cls.quantized_classes_map[module_cls] = quantized_cls
+            cls.cls_to_qcls[module_cls] = quantized_cls
+            cls.qcls_to_cls[quantized_cls] = module_cls
            return quantized_cls
        return wrapper
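# Illustrative aside (not part of the diff): the registration pattern the decorator
# above enables, mirroring the tests further down. CustomOp and its quantized forward
# body are stand-ins, not real AIMET classes.
class CustomOp(nn.Module):
    def forward(self, x):
        return x * 2

@FakeQuantizationMixin.implements(CustomOp)
class FakeQuantizedCustomOp(FakeQuantizationMixin, CustomOp):
    def quantized_forward(self, x):
        ...  # apply input/output quantizers around CustomOp.forward as needed

assert FakeQuantizationMixin.cls_to_qcls[CustomOp] is FakeQuantizedCustomOp
assert FakeQuantizationMixin.qcls_to_cls[FakeQuantizedCustomOp] is CustomOp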

@@ -137,7 +137,7 @@ def from_module(cls, module: nn.Module):
"""
# pylint: disable=protected-access
module_cls = type(module)
qtzn_module_cls = cls.quantized_classes_map.get(module_cls, None)
qtzn_module_cls = cls.cls_to_qcls.get(module_cls, None)

if not qtzn_module_cls:
raise RuntimeError(
@@ -38,7 +38,7 @@
""" nn.Modules for quantization operators """

import copy
-from typing import Optional, Tuple
+from typing import Optional, Tuple, List, Dict
import contextlib
from collections import OrderedDict
import functools
@@ -226,6 +226,31 @@ def get_offset(self) -> Optional[torch.Tensor]:

        return offset

    @torch.no_grad()
    def get_encodings(self) -> Optional[List[Dict]]:
        """
        Returns a list of encodings, each represented as a List of Dicts
        """
        # pylint: disable=redefined-builtin

        if not self.is_initialized():
            return None

        min = self.get_min().flatten()
        max = self.get_max().flatten()
        scale = self.get_scale().flatten()
        offset = self.get_offset().flatten()
        bitwidth = self.bitwidth
        dtype = "int"
        is_symmetric = self.symmetric

        return [
            {'min': float(min_), 'max': float(max_),
             'scale': float(scale_), 'offset': float(offset_),
             'bitwidth': bitwidth, 'dtype': dtype, 'is_symmetric': str(is_symmetric)}
            for min_, max_, scale_, offset_ in zip(min, max, scale, offset)
        ]
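# Illustrative aside (not part of the diff): for a per-channel quantizer the flattened
# min/max/scale/offset carry one element per channel, so the returned list has one dict
# per channel. `quantizer` below is a placeholder for any initialized quantizer.
encodings = quantizer.get_encodings()
if encodings is not None:
    for enc in encodings:        # one dict per (flattened) channel
        print(enc['bitwidth'], enc['scale'], enc['offset'], enc['is_symmetric'])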

    @contextlib.contextmanager
    def compute_encodings(self):
        """
@@ -36,11 +36,10 @@
# =============================================================================
""" Top level API for performing quantization simulation of a pytorch model """

-import contextlib
-
import torch

from aimet_torch.quantsim import QuantizationSimModel as V1QuantizationSimModel
+from aimet_torch.experimental.v2 import nn as aimet_nn
from aimet_torch.experimental.v2.nn.fake_quant import FakeQuantizationMixin
from aimet_torch.experimental.v2.quantization.wrappers.builder import LazyQuantizeWrapper
from aimet_torch import utils
@@ -78,10 +77,5 @@ def compute_encodings(self, forward_pass_callback, forward_pass_callback_args):
"""
# Run forward iterations so we can collect statistics to compute the appropriate encodings
with utils.in_eval_mode(self.model), torch.no_grad():
with contextlib.ExitStack() as stack:
for module in self.model.modules():
if not isinstance(module, FakeQuantizationMixin):
continue
stack.enter_context(module.compute_encodings())

with aimet_nn.compute_encodings(self.model):
_ = forward_pass_callback(self.model, forward_pass_callback_args)
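Taken together with the aimet_nn.compute_encodings helper above, the caller-side flow might look like the following sketch. The constructor arguments and the data-loader name are placeholders, and the V2 simulation class is assumed to keep the QuantizationSimModel name it shadows:

import torch

def calibrate(model, data_loader):
    # forward_pass_callback: run representative data through the model
    with torch.no_grad():
        for batch in data_loader:
            model(batch)

sim = QuantizationSimModel(model, dummy_input)   # V2 sim class defined in this file (assumed name)
sim.compute_encodings(calibrate, data_loader)    # runs calibrate inside aimet_nn.compute_encodings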
@@ -67,7 +67,7 @@ def quantized_forward(self, x):

        finally:
            # Unregister CustomOp so as not to affect other test functions
-            FakeQuantizationMixin.quantized_classes_map.pop(CustomOp)
+            FakeQuantizationMixin.cls_to_qcls.pop(CustomOp)

    def test_custom_op_wrap_registered(self):
        try:
@@ -85,4 +85,4 @@ def quantized_forward(self, x):

        finally:
            # Unregister CustomOp so as not to affect other test functions
-            FakeQuantizationMixin.quantized_classes_map.pop(CustomOp)
+            FakeQuantizationMixin.cls_to_qcls.pop(CustomOp)