diff --git a/TrainingExtensions/torch/src/python/aimet_torch/blockwise_quant_tensor_split.py b/TrainingExtensions/torch/src/python/aimet_torch/blockwise_quant_tensor_split.py new file mode 100644 index 00000000000..813d18c8ad1 --- /dev/null +++ b/TrainingExtensions/torch/src/python/aimet_torch/blockwise_quant_tensor_split.py @@ -0,0 +1,117 @@ +# -*- mode: python -*- +# ============================================================================= +# @@-COPYRIGHT-START-@@ +# +# Copyright (c) 2024, Qualcomm Innovation Center, Inc. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. 
""" Utilities for implementing blockwise quantization using tensor splitting approach """

import torch
from aimet_torch import elementwise_ops
from aimet_torch.quantsim import QuantizationSimModel


class BlockwiseLinear(torch.nn.Module):
    """
    Blockwise Linear implementation.

    This module replaces a single nn.Linear module, breaking it into a number of smaller linear modules depending on
    block size. Each separate linear module can operate with per channel quantization independently, and the outputs
    of each linear module are summed up.
    """
    def __init__(self, linear_module: torch.nn.Linear, block_size: int):
        """
        :param linear_module: nn.Linear module to split into blockwise constituent linears
        :param block_size: Maximum input channel dimension for each constituent linear
        """
        super(BlockwiseLinear, self).__init__()
        self.block_size = block_size
        self.linears = torch.nn.ModuleList()
        self.elementwise_adds = torch.nn.ModuleList()
        self.split = elementwise_ops.Split()
        # Split the weight on the input channel dimension (dim 1) into chunks of at most block_size columns.
        # The trailing chunk may be smaller when in_features is not a multiple of block_size.
        split_indices = list(range(block_size, linear_module.weight.shape[1], block_size))
        split_weights = torch.tensor_split(linear_module.weight, split_indices, 1)
        for idx, split_weight in enumerate(split_weights):
            # Only the first constituent linear carries the bias: the outputs of all constituent linears are
            # summed, so adding the bias once is sufficient.
            linear = torch.nn.Linear(split_weight.shape[1],
                                     split_weight.shape[0],
                                     bias=(linear_module.bias is not None and idx == 0))
            linear.weight = torch.nn.Parameter(split_weight)
            if linear.bias is not None:
                linear.bias = linear_module.bias
            self.linears.append(linear)
            self.elementwise_adds.append(elementwise_ops.Add())
        # N constituent linears need only N-1 adds to combine their outputs; drop the extra one.
        self.elementwise_adds = self.elementwise_adds[:-1]
        if not self.elementwise_adds:
            # Single-block case: no splitting/adding is needed at all.
            self.elementwise_adds = None

    def forward(self, inp):
        """
        Forward pass: split the input on its last dimension to mirror the weight split, run each chunk
        through its constituent linear, and sum the partial outputs.

        :param inp: Input tensor whose last dimension matches the original linear's in_features
        :return: Output tensor equal to the original linear's output
        """
        if len(self.linears) == 1:
            # No split occurred; behave exactly like the original nn.Linear.
            return self.linears[0](inp)

        split_inputs = self.split(inp, self.block_size, -1)
        out = None
        for idx, split_input in enumerate(split_inputs):
            linear_out = self.linears[idx](split_input)
            if out is None:
                out = linear_out
            else:
                out = self.elementwise_adds[idx - 1](out, linear_out)
        return out


def replace_linears_for_blockwise_quant(model: torch.nn.Module, block_size: int):
    """
    Replace all instances of torch.nn.Linears in model with equivalent BlockwiseLinear modules.
    The linear weights are split on the input channel dimension such that all constituent linear modules have weights
    with input channel dimension <= block_size.

    :param model: Model to replace nn.Linears for
    :param block_size: Block size to use
    """
    # Iterate over parent modules and replace their immediate Linear children. Using named_modules() with
    # setattr(model, name, ...) is incorrect for nested submodules: names from named_modules() are dotted
    # paths (e.g. 'block.linear1'), and setattr would create a literal attribute of that name on the
    # top-level model instead of replacing the nested child. Snapshot with list() so the newly inserted
    # BlockwiseLinear modules (which themselves contain nn.Linears) are not re-wrapped while iterating.
    for parent in list(model.modules()):
        for child_name, child in list(parent.named_children()):
            if isinstance(child, torch.nn.Linear):
                setattr(parent, child_name, BlockwiseLinear(child, block_size))


def tie_blockwise_linear_quantizers(quantsim: QuantizationSimModel):
    """
    Tie all output quantizers within a BlockwiseLinear block together so they share the same encoding. In other words,
    all output quantizers of constituent linear layers as well as output quantizers of elementwise add modules will
    share the same quantization parameters.

    :param quantsim: Quantsim model containing BlockwiseLinear modules to tie output quantizers for.
    """
    for module in quantsim.model.modules():
        if isinstance(module, BlockwiseLinear):
            # Reuse the first linear's output quantizer object everywhere so all tied modules see (and
            # update) the exact same encoding.
            output_quantizer = module.linears[0].output_quantizers[0]
            for linear in module.linears:
                linear.output_quantizers[0] = output_quantizer
            if module.elementwise_adds is not None:
                for add in module.elementwise_adds:
                    add.output_quantizers[0] = output_quantizer
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. 
""" Tests for blockwise quant tensor split utility """

import json
import pytest
import torch
from aimet_torch.quantsim import QuantizationSimModel
from aimet_torch.qc_quantize_op import QcQuantizeWrapper
from aimet_torch.blockwise_quant_tensor_split import (BlockwiseLinear, replace_linears_for_blockwise_quant,
                                                      tie_blockwise_linear_quantizers)

# Quantsim config shared by the quantization tests below: per channel quantization enabled, params left
# unquantized, and Split op outputs excluded from quantization.
_QUANTSIM_CONFIG = {
    "defaults": {
        "ops": {
            "is_output_quantized": "True",
            "is_symmetric": "False"
        },
        "params": {
            "is_quantized": "False",
            "is_symmetric": "True"
        },
        "per_channel_quantization": "True",
    },
    "params": {},
    "op_type": {
        'Split': {
            'is_output_quantized': False
        }
    },
    "supergroups": [],
    "model_input": {},
    "model_output": {}
}


def _dump_quantsim_config():
    """ Write the shared quantsim config to the location the tests use. """
    # NOTE(review): the config is written to disk but never passed to QuantizationSimModel via
    # config_file — confirm whether that is intentional.
    with open('./data/quantsim_config.json', 'w') as f:
        json.dump(_QUANTSIM_CONFIG, f)


class LinearModel(torch.nn.Module):
    """ Small two-linear model used to exercise linear replacement. """
    def __init__(self):
        super(LinearModel, self).__init__()
        self.linear1 = torch.nn.Linear(8, 3)
        self.relu1 = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(3, 2, bias=False)
        self.softmax = torch.nn.Softmax()

    def forward(self, inp):
        out = self.linear1(inp)
        out = self.relu1(out)
        out = self.linear2(out)
        return self.softmax(out)


@pytest.mark.parametrize('model, dummy_input, block_size', [(torch.nn.Linear(8, 3), torch.randn(1, 8), 3),
                                                            (torch.nn.Linear(8, 3, bias=False), torch.randn(1, 8), 3),
                                                            (torch.nn.Linear(3, 2), torch.randn(1, 3), 3),
                                                            (torch.nn.Linear(3, 2), torch.randn(1, 3), 4)])
def test_blockwise_linears(model, dummy_input, block_size):
    """ BlockwiseLinear must reproduce the original linear's output, with and without bias,
    for block sizes both below and above in_features. """
    expected = model(dummy_input)
    actual = BlockwiseLinear(model, block_size=block_size)(dummy_input)
    assert torch.allclose(expected, actual, atol=1e-6)

def test_replace_linears_for_blockwise_quant():
    """ Replacing linears must preserve model outputs and place the bias only on the first block. """
    inp = torch.randn(1, 8)
    model = LinearModel()
    original_linear1 = model.linear1
    baseline = model(inp)

    replace_linears_for_blockwise_quant(model, 3)

    # linear1 (8 inputs, block size 3) splits into 3 blocks; linear2 (3 inputs) stays a single block.
    assert len(model.linear1.linears) == 3
    assert torch.equal(model.linear1.linears[0].bias, original_linear1.bias)
    assert all(linear.bias is None for linear in model.linear1.linears[1:])

    assert len(model.linear1.elementwise_adds) == 2
    assert len(model.linear2.linears) == 1
    assert model.linear2.elementwise_adds is None
    assert torch.allclose(baseline, model(inp), atol=1e-6)

def test_quantize_blockwise_linear():
    """ Tied output quantizers of a quantized BlockwiseLinear must share one encoding. """
    _dump_quantsim_config()
    inp = torch.randn(1, 8)
    qsim = QuantizationSimModel(BlockwiseLinear(torch.nn.Linear(8, 3), 3), dummy_input=inp)

    wrapped = [qsim.model.split,
               qsim.model.linears[0], qsim.model.linears[1], qsim.model.linears[2],
               qsim.model.elementwise_adds[0], qsim.model.elementwise_adds[1]]
    assert all(isinstance(module, QcQuantizeWrapper) for module in wrapped)

    # Temporary hack to disable split op output quantizers while handling for CG split op is reworked
    for output_quantizer in qsim.model.split.output_quantizers:
        output_quantizer.enabled = False

    # Temporary hack to enable model input split op input quantizer while handling for CG split op is reworked
    qsim.model.split.input_quantizers[0].enabled = True

    tie_blockwise_linear_quantizers(qsim)
    qsim.compute_encodings(lambda m, _: m(inp), None)
    _ = qsim.model(inp)

    shared_max = qsim.model.linears[0].output_quantizers[0].encoding.max
    for module in (qsim.model.linears[1], qsim.model.linears[2],
                   qsim.model.elementwise_adds[0], qsim.model.elementwise_adds[1]):
        assert module.output_quantizers[0].encoding.max == shared_max

def test_blockwise_quant_with_small_linear():
    """ A linear smaller than the block size must quantize as a single linear with no adds. """
    _dump_quantsim_config()
    inp = torch.randn(1, 3)
    qsim = QuantizationSimModel(BlockwiseLinear(torch.nn.Linear(3, 2), 3), dummy_input=inp)

    # Temporary hack to disable split op output quantizers while handling for CG split op is reworked
    for output_quantizer in qsim.model.split.output_quantizers:
        output_quantizer.enabled = False

    # Temporary hack to enable model input split op input quantizer while handling for CG split op is reworked
    qsim.model.split.input_quantizers[0].enabled = True

    tie_blockwise_linear_quantizers(qsim)
    qsim.compute_encodings(lambda m, _: m(inp), None)
    assert len(qsim.model.linears) == 1
    assert qsim.model.elementwise_adds is None

    _ = qsim.model(inp)
    assert len(qsim.connected_graph.get_all_ops()) == 1