Add unit tests
Signed-off-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
quic-kyunggeu committed Oct 14, 2024
1 parent ab429b4 commit b3bc4b2
Showing 1 changed file with 106 additions and 2 deletions.
TrainingExtensions/torch/test/python/v2/quantsim/test_quantsim.py
@@ -49,9 +49,9 @@
from aimet_torch.v2.quantsim import QuantizationSimModel
from aimet_torch.v2.quantization.encoding_analyzer import PercentileEncodingAnalyzer
from aimet_torch.v2.quantization.base import QuantizerBase
-from aimet_torch.v2.quantization.affine import AffineQuantizerBase, GroupedBlockQuantizeDequantize
+from aimet_torch.v2.quantization.affine import AffineQuantizerBase, GroupedBlockQuantizeDequantize, QuantizeDequantize
from aimet_torch.v2.experimental import propagate_output_encodings
-from aimet_torch.v2.nn import BaseQuantizationMixin, QuantizedConv2d
+from aimet_torch.v2.nn import BaseQuantizationMixin, QuantizationMixin, QuantizedConv2d
import aimet_torch.v2.nn.modules.custom as custom
from ..models_ import test_models

@@ -938,6 +938,110 @@ def test_export_to_onnx_direct_fixed_param_names(self):
            if 'bias' not in name:
                assert name in param_encodings_set

    def test_non_leaf_qmodule(self):
        """
        Given: A quantized definition of a non-leaf module
        """
        class CustomLinear(torch.nn.Module):
            """ custom linear module """
            def __init__(self, in_features, out_features):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))
                self.bias = torch.nn.Parameter(torch.randn(out_features))
                self.matmul = custom.MatMul()
                self.add = custom.Add()

            def forward(self, x):
                x = self.matmul(x, self.weight.transpose(0, 1))
                return self.add(x, self.bias)

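        # @QuantizationMixin.implements registers the class below as the quantized
        # counterpart of CustomLinear, which is what lets QuantizationSimModel swap
        # CustomLinear instances for QuantizedCustomLinear automatically.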
        @QuantizationMixin.implements(CustomLinear)
        class QuantizedCustomLinear(QuantizationMixin, CustomLinear):
            def __quant_init__(self):
                super().__quant_init__()
                self.input_quantizers = torch.nn.ModuleList([])
                self.output_quantizers = torch.nn.ModuleList([])

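            # _patch_quantized_parameters (from BaseQuantizationMixin) temporarily
            # replaces self.weight/self.bias with their quantize-dequantized values
            # for the duration of the wrapped forward call.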
            def forward(self, x):
                with self._patch_quantized_parameters():
                    return super().forward(x)

"""
When: Create quantsim with the non-leaf module
Then: 1) The non-leaf module should be converted to a quantized module
2) All its submodules should be also converted to quantized modules
"""
        model = torch.nn.Sequential(
            CustomLinear(10, 10),
            torch.nn.Sigmoid(),
        )
        dummy_input = torch.randn(10, 10)

        sim = QuantizationSimModel(model, dummy_input)

        qlinear = sim.model[0]
        assert isinstance(qlinear, QuantizedCustomLinear)
        assert isinstance(qlinear.param_quantizers['weight'], AffineQuantizerBase)
        assert qlinear.param_quantizers['bias'] is None

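        # The weight is already quantized as a parameter, so matmul's second input
        # needs no input quantizer; likewise, add's first input arrives pre-quantized
        # by matmul's output quantizer.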
        assert isinstance(qlinear.matmul, custom.QuantizedMatMul)
        assert isinstance(qlinear.matmul.input_quantizers[0], AffineQuantizerBase)
        assert qlinear.matmul.input_quantizers[1] is None
        assert isinstance(qlinear.matmul.output_quantizers[0], AffineQuantizerBase)

        assert isinstance(qlinear.add, custom.QuantizedAdd)
        assert qlinear.add.input_quantizers[0] is None
        assert isinstance(qlinear.add.input_quantizers[1], AffineQuantizerBase)
        assert isinstance(qlinear.add.output_quantizers[0], AffineQuantizerBase)

    def test_already_quantized_model(self):
        """
        Given: The model already consists of quantized modules
        When: Create quantsim with the model
        Then: Throw RuntimeError
        """
        model = torch.nn.Sequential(
            QuantizedConv2d(3, 3, 3),
            torch.nn.ReLU(),
        )
        dummy_input = torch.randn(1, 3, 224, 224)

        with pytest.raises(RuntimeError):
            _ = QuantizationSimModel(model, dummy_input)

"""
Given: The model already consists of quantizers
When: Create quantsim with the model
Then: Throw runtime error
"""
        model = torch.nn.Sequential(
            torch.nn.Conv2d(3, 3, 3),
            QuantizeDequantize((), 0, 255, False),
        )

        with pytest.raises(RuntimeError):
            _ = QuantizationSimModel(model, dummy_input)

"""
Given: The model itself is a quantized module
When: Create quantsim with the model
Then: Throw runtime error
"""
model = QuantizedConv2d(3, 3, 3)

with pytest.raises(RuntimeError):
_ = QuantizationSimModel(model, dummy_input)

"""
Given: The model itself is a quantizer
When: Create quantsim with the model
Then: Throw runtime error
"""
model = QuantizeDequantize((), 0, 255, False)

with pytest.raises(RuntimeError):
_ = QuantizationSimModel(model, dummy_input)


class TestQuantsimUtilities:

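For context, a minimal sketch of how the non-leaf quantized module exercised above might be calibrated and run end to end. The two-argument compute_encodings(callback, callback_args) signature is an assumption carried over from the classic aimet_torch API; it is not exercised by this commit:

    model = torch.nn.Sequential(CustomLinear(10, 10), torch.nn.Sigmoid())
    dummy_input = torch.randn(10, 10)
    sim = QuantizationSimModel(model, dummy_input)
    # Calibrate quantizer encodings with a representative forward pass (assumed API)
    sim.compute_encodings(lambda m, _: m(dummy_input), None)
    out = sim.model(dummy_input)  # runs QuantizedCustomLinear with quantized weight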
