Make lazy quant wrapper handle nested input/output quantizers
Signed-off-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
quic-kyunggeu authored and quic-akhobare committed Feb 7, 2024
1 parent 4f94c88 commit 1163cca
Showing 4 changed files with 66 additions and 35 deletions.
@@ -46,27 +46,13 @@
from torch.nn.utils.rnn import PackedSequence
from torch.utils._pytree import tree_map

from aimet_torch.experimental.v2.nn.quant_base import BaseQuantizationMixin
from aimet_torch.experimental.v2.nn.quant_base import BaseQuantizationMixin, _flatten_nn_module_list
from aimet_torch.experimental.v2.quantization.quantizers import QuantizerBase
from aimet_torch.experimental.v2.utils import patch_attr
import aimet_torch.elementwise_ops as aimet_ops



def _flatten_nn_module_list(module):
"""
Flatten nested list of nn.Modules into a flat list
"""
def flat_iter(mod):
if isinstance(mod, (list, tuple, nn.ModuleList)):
for x in mod:
yield from flat_iter(x)
else:
yield mod

return list(flat_iter(module))


class FakeQuantizationMixin(BaseQuantizationMixin): # pylint: disable=abstract-method
"""
Mixin that implements fake-quantization on top of regular pytorch modules.
@@ -520,7 +506,7 @@ def __quant_init__(self):
self.input_quantizers = nn.ModuleList([None, nn.ModuleList([None, None])])
self.output_quantizers = nn.ModuleList([None, nn.ModuleList([None, None])])

def quantized_forward(self, input, hx: Optional[Tuple[Tensor, Tensor]] = None): # pylint: disable=arguments-differ, too-many-branches
def quantized_forward(self, input, hx: Optional[Tuple[Tensor, Tensor]] = None): # pylint: disable=arguments-differ
"""
Quantized forward impl for nn.LSTM.
"""
@@ -533,12 +519,7 @@ def quantized_forward(self, input, hx: Optional[Tuple[Tensor, Tensor]] = None):

if hx is not None:
h, c = hx
if isinstance(self.input_quantizers[1], QuantizerBase):
# For backward compatibility with V1 quantsim.
# Quantsim V1 uses single input quantizer for h and c
h_quantizer = c_quantizer = self.input_quantizers[1]
else:
h_quantizer, c_quantizer = self.input_quantizers[1]
h_quantizer, c_quantizer = self.input_quantizers[1]

if h_quantizer:
h = h_quantizer(h)
@@ -558,12 +539,7 @@ def quantized_forward(self, input, hx: Optional[Tuple[Tensor, Tensor]] = None):
output = self.output_quantizers[0](output)

h_n, c_n = hidden
if isinstance(self.output_quantizers[1], QuantizerBase):
# For backward compatibility with V1 quantsim.
# Quantsim V1 uses single output quantizer for h_n and c_n
h_quantizer = c_quantizer = self.output_quantizers[1]
else:
h_quantizer, c_quantizer = self.output_quantizers[1]
h_quantizer, c_quantizer = self.output_quantizers[1]

if h_quantizer:
h_n = h_quantizer(h_n)
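For context, the nested layout declared in __quant_init__ above mirrors nn.LSTM's (input, (h_0, c_0)) signature: index 1 of each ModuleList is itself a ModuleList holding the hidden- and cell-state quantizers. A minimal sketch of the unpacking, with nn.Identity standing in as a hypothetical placeholder for realized quantizers (not the real QuantizerBase objects):

import torch
from torch import nn

# Placeholder quantizers: index 0 quantizes the input sequence,
# index 1 holds the (h, c) pair, mirroring FakeQuantizedLSTM.__quant_init__.
input_quantizers = nn.ModuleList([nn.Identity(),
                                  nn.ModuleList([nn.Identity(), nn.Identity()])])

x = torch.randn(10, 1, 3)   # (seq_len, batch, input_size)
h = torch.zeros(1, 1, 5)    # (num_layers, batch, hidden_size); sizes are arbitrary here
c = torch.zeros(1, 1, 5)

if input_quantizers[0]:
    x = input_quantizers[0](x)

h_quantizer, c_quantizer = input_quantizers[1]   # unpack the nested ModuleList
if h_quantizer:
    h = h_quantizer(h)
if c_quantizer:
    c = c_quantizer(c)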
@@ -46,6 +46,20 @@
from aimet_torch.experimental.v2.utils import patch_attr


def _flatten_nn_module_list(module):
"""
Flatten nested list of nn.Modules into a flat list
"""
def flat_iter(mod):
if isinstance(mod, (list, tuple, nn.ModuleList)):
for x in mod:
yield from flat_iter(x)
else:
yield mod

return list(flat_iter(module))
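
For reference, a quick sketch of what the helper above returns for a nested quantizer layout, with nn.Identity as a hypothetical placeholder quantizer:

from torch import nn

q1, q2, q3 = nn.Identity(), nn.Identity(), nn.Identity()
nested = nn.ModuleList([q1, nn.ModuleList([q2, q3])])

flat = _flatten_nn_module_list(nested)
assert flat == [q1, q2, q3]   # leaf modules only; None placeholders would be yielded as-is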


class BaseQuantizationMixin(abc.ABC):
"""
Mixin that implements quantization on top of regular pytorch modules.
@@ -114,7 +128,10 @@ def compute_encodings(self):
self._compute_param_encodings(overwrite=True)

with contextlib.ExitStack() as stack:
for quantizer in itertools.chain(self.input_quantizers, self.output_quantizers):
input_quantizers = _flatten_nn_module_list(self.input_quantizers)
output_quantizers = _flatten_nn_module_list(self.output_quantizers)

for quantizer in itertools.chain(input_quantizers, output_quantizers):
if not quantizer:
continue
ctx = quantizer.compute_encodings()
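The flattening above matters because chaining the top-level ModuleLists directly would yield the nested ModuleList containers themselves rather than the leaf quantizers. A small illustration with placeholder modules:

import itertools
from torch import nn

input_quantizers = nn.ModuleList([nn.Identity(),
                                  nn.ModuleList([nn.Identity(), nn.Identity()])])
output_quantizers = nn.ModuleList([nn.Identity()])

# Unflattened: the second element is a ModuleList, not a quantizer.
print([type(m).__name__ for m in itertools.chain(input_quantizers, output_quantizers)])
# ['Identity', 'ModuleList', 'Identity']

# Flattened: only leaf modules remain (the real code then calls compute_encodings() on each).
flat = _flatten_nn_module_list(input_quantizers) + _flatten_nn_module_list(output_quantizers)
print([type(m).__name__ for m in flat])
# ['Identity', 'Identity', 'Identity', 'Identity']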
@@ -198,12 +198,37 @@ def realize_v2_wrapper(self) -> FakeQuantizationMixin:
"""
quantized_module = FakeQuantizationMixin.from_module(self._module_to_wrap)

quantized_module.input_quantizers = nn.ModuleList([
quant_builder.realize() for quant_builder in self.input_quantizers
])
quantized_module.output_quantizers = nn.ModuleList([
quant_builder.realize() for quant_builder in self.output_quantizers
])
def set_recursive(module_list, i, quantizer):
"""
Set quantizer recursively.
AIMET V1 handles nested input/output tensors with a single input quantizer,
whereas a V2 quantized module allows nested input/output quantizers.
(For reference, see the class definition of FakeQuantizedLSTM in fake_quant.py)
To implement V1 behavior, we set the nested input/output quantizers to
share the same single quantizer, for example as below:
- self.input_quantizers = [q1, q2, q3]
- quant_module.input_quantizers = [None, [None, None], None]
(before set_recursive)
- quant_module.input_quantizers = [q1, [q2, q2], q3]
(after set_recursive)
"""
if module_list[i] is None:
module_list[i] = quantizer
elif isinstance(module_list[i], nn.ModuleList):
for j in range(len(module_list[i])):
set_recursive(module_list[i], j, quantizer)
else:
raise RuntimeError

for i, quant_builder in enumerate(self.input_quantizers):
quantizer = quant_builder.realize()
set_recursive(quantized_module.input_quantizers, i, quantizer)

for i, quant_builder in enumerate(self.output_quantizers):
quantizer = quant_builder.realize()
set_recursive(quantized_module.output_quantizers, i, quantizer)

for param_name, quant_builder in self.param_quantizers.items():
quantized_module.param_quantizers[param_name] = quant_builder.realize()

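A short sketch of the sharing behavior described in the set_recursive docstring, with nn.Identity modules standing in for the realized quantizers:

from torch import nn

q1, q2, q3 = nn.Identity(), nn.Identity(), nn.Identity()

# Nested layout declared by the V2 module (cf. FakeQuantizedLSTM.__quant_init__).
input_quantizers = nn.ModuleList([None, nn.ModuleList([None, None]), None])

for i, quantizer in enumerate([q1, q2, q3]):
    set_recursive(input_quantizers, i, quantizer)

# Result: [q1, [q2, q2], q3] -- the single V1 quantizer q2 is shared by both
# entries of the nested ModuleList.
assert input_quantizers[1][0] is input_quantizers[1][1]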
@@ -1637,6 +1637,19 @@ def test_rnn_quantization(self):
quant_scheme=QuantScheme.post_training_tf)
assert isinstance(sim.model.rnn, aimet_nn.FakeQuantizedRNN)

sim.compute_encodings(lambda model, _: model(dummy_input), None) # Should not throw error

def test_lstm_quantization(self):
""" Test quantizing a model with rnn layer """
model = TwoLayerBidirectionalLSTMModel()
dummy_input = torch.randn(10, 1, 3)

sim = QuantizationSimModel(model, dummy_input,
quant_scheme=QuantScheme.post_training_tf)
assert isinstance(sim.model.recurrent, aimet_nn.FakeQuantizedLSTM)

sim.compute_encodings(lambda model, _: model(dummy_input), None) # Should not throw error
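
TwoLayerBidirectionalLSTMModel is defined elsewhere in the test code; a minimal stand-in consistent with the (10, 1, 3) dummy input and the sim.model.recurrent assertion might look like this (the hidden size is an arbitrary assumption):

import torch
from torch import nn

class TwoLayerBidirectionalLSTMModel(nn.Module):
    """Hypothetical stand-in: two-layer bidirectional LSTM over 3-dim inputs."""
    def __init__(self):
        super().__init__()
        self.recurrent = nn.LSTM(input_size=3, hidden_size=5,
                                 num_layers=2, bidirectional=True)

    def forward(self, x):
        return self.recurrent(x)

model = TwoLayerBidirectionalLSTMModel()
output, (h_n, c_n) = model(torch.randn(10, 1, 3))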

def test_quantizing_qc_quantize_module(self):
""" Test that qc_quantize_module is identified as not quantizable """
q_rnn = aimet_nn.FakeQuantizedRNN(input_size=3, hidden_size=5, num_layers=1)
