Add blockwise quantization with tensor splitting utility
Signed-off-by: Kevin Hsieh <quic_klhsieh@quicinc.com>
quic-klhsieh authored Jan 17, 2024
1 parent 2444ab5 commit b9cb122
Showing 2 changed files with 305 additions and 0 deletions.
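In brief, the commit adds a BlockwiseLinear module that splits an nn.Linear along its input channels into block-size-wide sub-linears whose outputs are summed, plus helpers to apply the replacement model-wide and to tie the resulting output quantizers. A minimal usage sketch, assuming the module path aimet_torch.blockwise_quant_tensor_split used by the test imports below; the toy model and shapes are illustrative, and the tests additionally tweak the Split op quantizers before computing encodings:

import torch
from aimet_torch.quantsim import QuantizationSimModel
from aimet_torch.blockwise_quant_tensor_split import (replace_linears_for_blockwise_quant,
                                                      tie_blockwise_linear_quantizers)

model = torch.nn.Sequential(torch.nn.Linear(8, 3), torch.nn.ReLU())
dummy_input = torch.randn(1, 8)

# Replace each nn.Linear with a BlockwiseLinear made of block_size-wide sub-linears
replace_linears_for_blockwise_quant(model, block_size=3)

qsim = QuantizationSimModel(model, dummy_input=dummy_input)
tie_blockwise_linear_quantizers(qsim)  # share one output encoding within each block
qsim.compute_encodings(lambda m, _: m(dummy_input), None)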
@@ -0,0 +1,117 @@
# -*- mode: python -*-
# =============================================================================
# @@-COPYRIGHT-START-@@
#
# Copyright (c) 2024, Qualcomm Innovation Center, Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
# SPDX-License-Identifier: BSD-3-Clause
#
# @@-COPYRIGHT-END-@@
# =============================================================================
""" Utilities for implementing blockwise quantization using tensor splitting approach """

import torch
from aimet_torch import elementwise_ops
from aimet_torch.quantsim import QuantizationSimModel


class BlockwiseLinear(torch.nn.Module):
    """
    Blockwise Linear implementation.
    This module replaces a single nn.Linear module, breaking it into a number of smaller linear modules depending on
    the block size. Each smaller linear module can use per-channel quantization independently, and the outputs of the
    linear modules are summed to reproduce the original output.
    """
    def __init__(self, linear_module: torch.nn.Linear, block_size: int):
        super(BlockwiseLinear, self).__init__()
        self.block_size = block_size
        self.linears = torch.nn.ModuleList()
        self.elementwise_adds = torch.nn.ModuleList()
        # Split points along the input channel dimension, so each chunk has at most block_size input channels
        split_indices = list(range(block_size, linear_module.weight.shape[1], block_size))
        self.split = elementwise_ops.Split()
        split_weights = torch.tensor_split(linear_module.weight, split_indices, 1)
        for idx, split_weight in enumerate(split_weights):
            # Only the first sub-linear keeps the bias; since the outputs are summed, adding it once is sufficient
            linear = torch.nn.Linear(split_weight.shape[1],
                                     split_weight.shape[0],
                                     bias=(linear_module.bias is not None and idx == 0))
            linear.weight = torch.nn.Parameter(split_weight)
            if linear.bias is not None:
                linear.bias = linear_module.bias
            self.linears.append(linear)
            self.elementwise_adds.append(elementwise_ops.Add())
        # N sub-linears need only N - 1 adds; with a single sub-linear, no adds are needed
        self.elementwise_adds = self.elementwise_adds[:-1]
        if not self.elementwise_adds:
            self.elementwise_adds = None

    def forward(self, inp):
        """ Forward pass """
        if len(self.linears) == 1:
            return self.linears[0](inp)

        split_inputs = self.split(inp, self.block_size, -1)
        out = None
        for idx, split_input in enumerate(split_inputs):
            linear_out = self.linears[idx](split_input)
            if out is None:
                out = linear_out
            else:
                out = self.elementwise_adds[idx - 1](out, linear_out)
        return out


def replace_linears_for_blockwise_quant(model: torch.nn.Module, block_size: int):
    """
    Replace all instances of torch.nn.Linear in the model with equivalent BlockwiseLinear modules.
    The linear weights are split on the input channel dimension such that all constituent linear modules have weights
    with input channel dimension <= block_size.
    :param model: Model to replace nn.Linear modules for
    :param block_size: Block size to use
    """
    for name, module in model.named_children():
        if isinstance(module, torch.nn.Linear):
            setattr(model, name, BlockwiseLinear(module, block_size))
        else:
            # Recurse so nn.Linear modules nested below the top level are replaced in their parent modules
            # (a flat setattr with a dotted name from named_modules() would not replace nested modules)
            replace_linears_for_blockwise_quant(module, block_size)


def tie_blockwise_linear_quantizers(quantsim: QuantizationSimModel):
    """
    Tie all output quantizers within a BlockwiseLinear block together so they share the same encoding. In other words,
    all output quantizers of constituent linear layers as well as output quantizers of elementwise add modules will
    share the same quantization parameters.
    :param quantsim: Quantsim model containing BlockwiseLinear modules to tie output quantizers for.
    """
    for module in quantsim.model.modules():
        if isinstance(module, BlockwiseLinear):
            output_quantizer = module.linears[0].output_quantizers[0]
            for linear in module.linears:
                linear.output_quantizers[0] = output_quantizer
            if module.elementwise_adds is not None:
                for add in module.elementwise_adds:
                    add.output_quantizers[0] = output_quantizer
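
For reference, the splitting approach used by BlockwiseLinear above is numerically faithful because a column-wise split of the weight turns one matmul into a sum of partial matmuls: splitting W along its input dimension into blocks W_k, and the input into matching blocks x_k, gives W x = sum_k W_k x_k. A minimal standalone sketch of this identity in plain torch (the shapes here are illustrative only):

import torch

# Column-wise weight split: summing the partial products reproduces the full linear output
weight = torch.randn(4, 8)
inp = torch.randn(1, 8)
block_size = 3

weight_blocks = torch.tensor_split(weight, list(range(block_size, weight.shape[1], block_size)), 1)
input_blocks = torch.split(inp, block_size, -1)
out = sum(x @ w.t() for w, x in zip(weight_blocks, input_blocks))

assert torch.allclose(out, inp @ weight.t(), atol=1e-6)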
@@ -0,0 +1,188 @@
# -*- mode: python -*-
# =============================================================================
# @@-COPYRIGHT-START-@@
#
# Copyright (c) 2024, Qualcomm Innovation Center, Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
# SPDX-License-Identifier: BSD-3-Clause
#
# @@-COPYRIGHT-END-@@
# =============================================================================
""" Tests for blockwise quant tensor split utility """

import json
import os
import pytest
import torch
from aimet_torch.quantsim import QuantizationSimModel
from aimet_torch.qc_quantize_op import QcQuantizeWrapper
from aimet_torch.blockwise_quant_tensor_split import (BlockwiseLinear, replace_linears_for_blockwise_quant,
                                                      tie_blockwise_linear_quantizers)

class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        self.linear1 = torch.nn.Linear(8, 3)
        self.relu1 = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(3, 2, bias=False)
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, inp):
        x = self.linear1(inp)
        x = self.relu1(x)
        x = self.linear2(x)
        x = self.softmax(x)
        return x


@pytest.mark.parametrize('model, dummy_input, block_size', [(torch.nn.Linear(8, 3), torch.randn(1, 8), 3),
                                                            (torch.nn.Linear(8, 3, bias=False), torch.randn(1, 8), 3),
                                                            (torch.nn.Linear(3, 2), torch.randn(1, 3), 3),
                                                            (torch.nn.Linear(3, 2), torch.randn(1, 3), 4)])
def test_blockwise_linears(model, dummy_input, block_size):
    blockwise_linear = BlockwiseLinear(model, block_size=block_size)
    orig_out = model(dummy_input)
    new_out = blockwise_linear(dummy_input)
    assert torch.allclose(orig_out, new_out, atol=1e-6)

def test_replace_linears_for_blockwise_quant():
    dummy_input = torch.randn(1, 8)
    model = LinearModel()
    linear1 = model.linear1
    orig_out = model(dummy_input)
    replace_linears_for_blockwise_quant(model, 3)

    assert len(model.linear1.linears) == 3
    assert torch.equal(model.linear1.linears[0].bias, linear1.bias)
    for linear in model.linear1.linears[1:]:
        assert linear.bias is None

    assert len(model.linear1.elementwise_adds) == 2
    assert len(model.linear2.linears) == 1
    assert model.linear2.elementwise_adds is None
    new_out = model(dummy_input)
    assert torch.allclose(orig_out, new_out, atol=1e-6)

def test_quantize_blockwise_linear():
    quantsim_config = {
        "defaults": {
            "ops": {
                "is_output_quantized": "True",
                "is_symmetric": "False"
            },
            "params": {
                "is_quantized": "False",
                "is_symmetric": "True"
            },
            "per_channel_quantization": "True",
        },
        "params": {},
        "op_type": {
            "Split": {
                "is_output_quantized": False
            }
        },
        "supergroups": [],
        "model_input": {},
        "model_output": {}
    }
    os.makedirs('./data', exist_ok=True)
    with open('./data/quantsim_config.json', 'w') as f:
        json.dump(quantsim_config, f)
    dummy_input = torch.randn(1, 8)
    model = BlockwiseLinear(torch.nn.Linear(8, 3), 3)
    # Pass the written config so the per-channel and Split op settings above take effect
    qsim = QuantizationSimModel(model, dummy_input=dummy_input, config_file='./data/quantsim_config.json')

    assert isinstance(qsim.model.split, QcQuantizeWrapper)
    assert isinstance(qsim.model.linears[0], QcQuantizeWrapper)
    assert isinstance(qsim.model.linears[1], QcQuantizeWrapper)
    assert isinstance(qsim.model.linears[2], QcQuantizeWrapper)
    assert isinstance(qsim.model.elementwise_adds[0], QcQuantizeWrapper)
    assert isinstance(qsim.model.elementwise_adds[1], QcQuantizeWrapper)

    # Temporary hack to disable split op output quantizers while handling for CG split op is reworked
    for output_quantizer in qsim.model.split.output_quantizers:
        output_quantizer.enabled = False

    # Temporary hack to enable model input split op input quantizer while handling for CG split op is reworked
    qsim.model.split.input_quantizers[0].enabled = True

    tie_blockwise_linear_quantizers(qsim)
    qsim.compute_encodings(lambda m, _: m(dummy_input), None)
    _ = qsim.model(dummy_input)

    assert (qsim.model.linears[0].output_quantizers[0].encoding.max ==
            qsim.model.linears[1].output_quantizers[0].encoding.max)
    assert (qsim.model.linears[0].output_quantizers[0].encoding.max ==
            qsim.model.linears[2].output_quantizers[0].encoding.max)
    assert (qsim.model.linears[0].output_quantizers[0].encoding.max ==
            qsim.model.elementwise_adds[0].output_quantizers[0].encoding.max)
    assert (qsim.model.linears[0].output_quantizers[0].encoding.max ==
            qsim.model.elementwise_adds[1].output_quantizers[0].encoding.max)

def test_blockwise_quant_with_small_linear():
    quantsim_config = {
        "defaults": {
            "ops": {
                "is_output_quantized": "True",
                "is_symmetric": "False"
            },
            "params": {
                "is_quantized": "False",
                "is_symmetric": "True"
            },
            "per_channel_quantization": "True",
        },
        "params": {},
        "op_type": {
            "Split": {
                "is_output_quantized": False
            }
        },
        "supergroups": [],
        "model_input": {},
        "model_output": {}
    }
    os.makedirs('./data', exist_ok=True)
    with open('./data/quantsim_config.json', 'w') as f:
        json.dump(quantsim_config, f)
    dummy_input = torch.randn(1, 3)
    # Input channel dimension (3) is <= block size (3), so no splitting should occur
    model = BlockwiseLinear(torch.nn.Linear(3, 2), 3)
    qsim = QuantizationSimModel(model, dummy_input=dummy_input, config_file='./data/quantsim_config.json')
    # Temporary hack to disable split op output quantizers while handling for CG split op is reworked
    for output_quantizer in qsim.model.split.output_quantizers:
        output_quantizer.enabled = False

    # Temporary hack to enable model input split op input quantizer while handling for CG split op is reworked
    qsim.model.split.input_quantizers[0].enabled = True

    tie_blockwise_linear_quantizers(qsim)
    qsim.compute_encodings(lambda m, _: m(dummy_input), None)
    assert len(qsim.model.linears) == 1
    assert qsim.model.elementwise_adds is None

    _ = qsim.model(dummy_input)
    assert len(qsim.connected_graph.get_all_ops()) == 1
