Add blockwise quantization with tensor splitting utility
Signed-off-by: Kevin Hsieh <quic_klhsieh@quicinc.com>
1 parent 2444ab5, commit b9cb122
Showing 2 changed files with 305 additions and 0 deletions.
TrainingExtensions/torch/src/python/aimet_torch/blockwise_quant_tensor_split.py (117 additions, 0 deletions)
# -*- mode: python -*-
# =============================================================================
# @@-COPYRIGHT-START-@@
#
# Copyright (c) 2024, Qualcomm Innovation Center, Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
#    this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors
#    may be used to endorse or promote products derived from this software
#    without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
# SPDX-License-Identifier: BSD-3-Clause
#
# @@-COPYRIGHT-END-@@
# =============================================================================
""" Utilities for implementing blockwise quantization using tensor splitting approach """

import torch
from aimet_torch import elementwise_ops
from aimet_torch.quantsim import QuantizationSimModel


class BlockwiseLinear(torch.nn.Module):
    """
    Blockwise Linear implementation.

    This module replaces a single nn.Linear module, splitting it into several smaller linear modules according to
    the block size. Each smaller linear module can use per-channel quantization independently, and their outputs are
    summed to reproduce the output of the original linear.
    """
    def __init__(self, linear_module: torch.nn.Linear, block_size: int):
        super(BlockwiseLinear, self).__init__()
        self.block_size = block_size
        self.linears = torch.nn.ModuleList()
        self.elementwise_adds = torch.nn.ModuleList()
        split_indices = list(range(block_size, linear_module.weight.shape[1], block_size))
        self.split = elementwise_ops.Split()
        split_weights = torch.tensor_split(linear_module.weight, split_indices, 1)
        for idx, split_weight in enumerate(split_weights):
            # Only the first constituent linear keeps the original bias; adding it once is sufficient
            linear = torch.nn.Linear(split_weight.shape[1],
                                     split_weight.shape[0],
                                     bias=(linear_module.bias is not None and idx == 0))
            linear.weight = torch.nn.Parameter(split_weight)
            if linear.bias is not None:
                linear.bias = linear_module.bias
            self.linears.append(linear)
            self.elementwise_adds.append(elementwise_ops.Add())
        # n splits need only n - 1 adds; drop the extra one, and use None when there is a single split
        self.elementwise_adds = self.elementwise_adds[:-1]
        if not self.elementwise_adds:
            self.elementwise_adds = None

    def forward(self, inp):
        """ Forward pass """
        if len(self.linears) == 1:
            return self.linears[0](inp)

        split_inputs = self.split(inp, self.block_size, -1)
        out = None
        # Accumulate each block's partial output through dedicated Add modules so quantsim can
        # observe and quantize every intermediate sum
        for idx, split_input in enumerate(split_inputs):
            linear_out = self.linears[idx](split_input)
            if out is None:
                out = linear_out
            else:
                out = self.elementwise_adds[idx - 1](out, linear_out)
        return out


def replace_linears_for_blockwise_quant(model: torch.nn.Module, block_size: int):
    """
    Replace all instances of torch.nn.Linear in the model with equivalent BlockwiseLinear modules.
    The linear weights are split on the input channel dimension such that all constituent linear modules have weights
    with input channel dimension <= block_size.

    :param model: Model to replace nn.Linears for
    :param block_size: Block size to use
    """
    # Materialize the module list up front since modules are replaced while iterating, and resolve each linear's
    # parent module so nested linears (e.g. 'block.linear1') are replaced correctly rather than setting a dotted
    # attribute name on the top-level model.
    for name, module in list(model.named_modules()):
        if isinstance(module, torch.nn.Linear):
            parent_name, _, attr_name = name.rpartition('.')
            parent = model.get_submodule(parent_name) if parent_name else model
            setattr(parent, attr_name, BlockwiseLinear(module, block_size))


def tie_blockwise_linear_quantizers(quantsim: QuantizationSimModel):
    """
    Tie all output quantizers within a BlockwiseLinear block together so they share the same encoding. In other words,
    all output quantizers of constituent linear layers, as well as output quantizers of elementwise add modules, will
    share the same quantization parameters.

    :param quantsim: Quantsim model containing BlockwiseLinear modules to tie output quantizers for
    """
    for module in quantsim.model.modules():
        if isinstance(module, BlockwiseLinear):
            # Reuse the first linear's output quantizer object everywhere so all encodings stay identical
            output_quantizer = module.linears[0].output_quantizers[0]
            for linear in module.linears:
                linear.output_quantizers[0] = output_quantizer
            if module.elementwise_adds is not None:
                for add in module.elementwise_adds:
                    add.output_quantizers[0] = output_quantizer
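
A rough usage sketch tying these utilities together (the toy model and block size are illustrative; the QuantizationSimModel and compute_encodings calls mirror the tests below, omitting the temporary Split-op quantizer adjustments shown there):

import torch
from aimet_torch.quantsim import QuantizationSimModel
from aimet_torch.blockwise_quant_tensor_split import (replace_linears_for_blockwise_quant,
                                                      tie_blockwise_linear_quantizers)

model = torch.nn.Sequential(torch.nn.Linear(8, 3))      # toy single-linear model
dummy_input = torch.randn(1, 8)

# Split every nn.Linear into block_size-wide pieces wired together with Split/Add ops
replace_linears_for_blockwise_quant(model, block_size=3)

qsim = QuantizationSimModel(model, dummy_input=dummy_input)
tie_blockwise_linear_quantizers(qsim)
qsim.compute_encodings(lambda m, _: m(dummy_input), None)
output = qsim.model(dummy_input)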
TrainingExtensions/torch/test/python/test_blockwise_quant_tensor_split.py (188 additions, 0 deletions)
# -*- mode: python -*-
# =============================================================================
# @@-COPYRIGHT-START-@@
#
# Copyright (c) 2024, Qualcomm Innovation Center, Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
#    this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors
#    may be used to endorse or promote products derived from this software
#    without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
# SPDX-License-Identifier: BSD-3-Clause
#
# @@-COPYRIGHT-END-@@
# =============================================================================
""" Tests for blockwise quant tensor split utility """

import json
import pytest
import torch
from aimet_torch.quantsim import QuantizationSimModel
from aimet_torch.qc_quantize_op import QcQuantizeWrapper
from aimet_torch.blockwise_quant_tensor_split import (BlockwiseLinear, replace_linears_for_blockwise_quant,
                                                      tie_blockwise_linear_quantizers)


class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        self.linear1 = torch.nn.Linear(8, 3)
        self.relu1 = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(3, 2, bias=False)
        # dim is given explicitly; torch.nn.Softmax without dim is deprecated
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, inp):
        x = self.linear1(inp)
        x = self.relu1(x)
        x = self.linear2(x)
        x = self.softmax(x)
        return x


@pytest.mark.parametrize('model, dummy_input, block_size', [(torch.nn.Linear(8, 3), torch.randn(1, 8), 3),
                                                            (torch.nn.Linear(8, 3, bias=False), torch.randn(1, 8), 3),
                                                            (torch.nn.Linear(3, 2), torch.randn(1, 3), 3),
                                                            (torch.nn.Linear(3, 2), torch.randn(1, 3), 4)])
def test_blockwise_linears(model, dummy_input, block_size):
    blockwise_linear = BlockwiseLinear(model, block_size=block_size)
    orig_out = model(dummy_input)
    new_out = blockwise_linear(dummy_input)
    assert torch.allclose(orig_out, new_out, atol=1e-6)


def test_replace_linears_for_blockwise_quant():
    dummy_input = torch.randn(1, 8)
    model = LinearModel()
    linear1 = model.linear1
    orig_out = model(dummy_input)
    replace_linears_for_blockwise_quant(model, 3)

    # linear1 (in_features=8) splits into ceil(8 / 3) = 3 blocks; only the first block keeps the bias
    assert len(model.linear1.linears) == 3
    assert torch.equal(model.linear1.linears[0].bias, linear1.bias)
    for linear in model.linear1.linears[1:]:
        assert linear.bias is None

    # linear2 (in_features=3) fits in a single block, so it needs no elementwise adds
    assert len(model.linear1.elementwise_adds) == 2
    assert len(model.linear2.linears) == 1
    assert model.linear2.elementwise_adds is None
    new_out = model(dummy_input)
    assert torch.allclose(orig_out, new_out, atol=1e-6)


def test_quantize_blockwise_linear():
    quantsim_config = {
        "defaults": {
            "ops": {
                "is_output_quantized": "True",
                "is_symmetric": "False"
            },
            "params": {
                "is_quantized": "False",
                "is_symmetric": "True"
            },
            "per_channel_quantization": "True",
        },
        "params": {},
        "op_type": {
            "Split": {
                "is_output_quantized": False
            }
        },
        "supergroups": [],
        "model_input": {},
        "model_output": {}
    }
    with open('./data/quantsim_config.json', 'w') as f:
        json.dump(quantsim_config, f)
    dummy_input = torch.randn(1, 8)
    model = BlockwiseLinear(torch.nn.Linear(8, 3), 3)
    # Pass the dumped config so per-channel quantization and the Split op settings take effect
    qsim = QuantizationSimModel(model, dummy_input=dummy_input, config_file='./data/quantsim_config.json')

    assert isinstance(qsim.model.split, QcQuantizeWrapper)
    assert isinstance(qsim.model.linears[0], QcQuantizeWrapper)
    assert isinstance(qsim.model.linears[1], QcQuantizeWrapper)
    assert isinstance(qsim.model.linears[2], QcQuantizeWrapper)
    assert isinstance(qsim.model.elementwise_adds[0], QcQuantizeWrapper)
    assert isinstance(qsim.model.elementwise_adds[1], QcQuantizeWrapper)

    # Temporary hack to disable split op output quantizers while handling for CG split op is reworked
    for output_quantizer in qsim.model.split.output_quantizers:
        output_quantizer.enabled = False

    # Temporary hack to enable model input split op input quantizer while handling for CG split op is reworked
    qsim.model.split.input_quantizers[0].enabled = True

    tie_blockwise_linear_quantizers(qsim)
    qsim.compute_encodings(lambda m, _: m(dummy_input), None)
    _ = qsim.model(dummy_input)

    # All tied output quantizers must end up with identical encodings
    assert (qsim.model.linears[0].output_quantizers[0].encoding.max ==
            qsim.model.linears[1].output_quantizers[0].encoding.max)
    assert (qsim.model.linears[0].output_quantizers[0].encoding.max ==
            qsim.model.linears[2].output_quantizers[0].encoding.max)
    assert (qsim.model.linears[0].output_quantizers[0].encoding.max ==
            qsim.model.elementwise_adds[0].output_quantizers[0].encoding.max)
    assert (qsim.model.linears[0].output_quantizers[0].encoding.max ==
            qsim.model.elementwise_adds[1].output_quantizers[0].encoding.max)


def test_blockwise_quant_with_small_linear():
    quantsim_config = {
        "defaults": {
            "ops": {
                "is_output_quantized": "True",
                "is_symmetric": "False"
            },
            "params": {
                "is_quantized": "False",
                "is_symmetric": "True"
            },
            "per_channel_quantization": "True",
        },
        "params": {},
        "op_type": {
            "Split": {
                "is_output_quantized": False
            }
        },
        "supergroups": [],
        "model_input": {},
        "model_output": {}
    }
    with open('./data/quantsim_config.json', 'w') as f:
        json.dump(quantsim_config, f)
    dummy_input = torch.randn(1, 3)
    model = BlockwiseLinear(torch.nn.Linear(3, 2), 3)
    # Pass the dumped config so per-channel quantization and the Split op settings take effect
    qsim = QuantizationSimModel(model, dummy_input=dummy_input, config_file='./data/quantsim_config.json')
    # Temporary hack to disable split op output quantizers while handling for CG split op is reworked
    for output_quantizer in qsim.model.split.output_quantizers:
        output_quantizer.enabled = False

    # Temporary hack to enable model input split op input quantizer while handling for CG split op is reworked
    qsim.model.split.input_quantizers[0].enabled = True

    tie_blockwise_linear_quantizers(qsim)
    qsim.compute_encodings(lambda m, _: m(dummy_input), None)
    # in_features <= block_size leaves a single constituent linear and no adds
    assert len(qsim.model.linears) == 1
    assert qsim.model.elementwise_adds is None

    _ = qsim.model(dummy_input)
    assert len(qsim.connected_graph.get_all_ops()) == 1
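
For reference, a minimal standalone sketch (pure PyTorch, illustrative shapes) of the split-and-sum equivalence that BlockwiseLinear relies on:

import torch

weight = torch.arange(24.).reshape(3, 8)    # out_features=3, in_features=8
block_size = 3
split_indices = list(range(block_size, weight.shape[1], block_size))    # [3, 6]
blocks = torch.tensor_split(weight, split_indices, dim=1)    # column widths 3, 3, 2

# Summing blockwise partial products reproduces the full linear output
x = torch.randn(1, 8)
chunks = torch.split(x, block_size, dim=-1)
blockwise_out = sum(c @ b.t() for c, b in zip(chunks, blocks))
assert torch.allclose(blockwise_out, x @ weight.t(), atol=1e-6)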