Prevent exception rule from applying to float quantization
Signed-off-by: Geunho Lee <quic_geunlee@quicinc.com>
quic-geunlee authored Nov 20, 2024
1 parent 32e54d5 commit 7870746
Showing 2 changed files with 170 additions and 0 deletions.
@@ -2050,6 +2050,15 @@ def _apply_exception_rules(self):
                target_quantizer_for_first_input = self._get_target_quantizer(first_input_quantizer, first_input_op, module_to_quant_wrapper)
                target_quantizer_for_second_input = self._get_target_quantizer(second_input_quantizer, second_input_op, module_to_quant_wrapper)

                # We don't need to apply exception rule when both first and second inputs are FP quantization
                if (
                    target_quantizer_for_first_input
                    and target_quantizer_for_first_input.data_type == QuantizationDataType.float
                    and target_quantizer_for_second_input
                    and target_quantizer_for_second_input.data_type == QuantizationDataType.float
                ):
                    continue

                # According to opdef for Matmul in HTP:
                # 16bit Weight(second input for dynamic MatMul) must have 16bit Activation(first input for dynamic MatMul).
                # 16bit Activation and 16bit Weight require minimum arch V73.
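The part of the hunk that actually applies the HTP MatMul rule is collapsed in this view. As a minimal sketch only, assuming a hypothetical helper name, an assumed list of pre-V73 targets, and assumed writable bitwidth/symmetric attributes on the quantizers, the integer path that the new early continue now bypasses for float quantization behaves roughly like this:

def _matmul_exception_rule_sketch(hw_version,
                                  target_quantizer_for_first_input,
                                  target_quantizer_for_second_input):
    # Sketch of the HTP MatMul opdef constraint for integer quantization:
    # a 16-bit second input (weight of a dynamic MatMul) needs a 16-bit first
    # input (activation), and a 16-bit x 16-bit MatMul needs arch V73 or newer.
    if hw_version in {"V66", "V68", "V69"}:  # assumed pre-V73 targets
        if target_quantizer_for_second_input is not None:
            # Drop the second input to 8-bit symmetric, which is what the new
            # test below checks for QuantizationDataType.int on hw_version V69.
            target_quantizer_for_second_input.bitwidth = 8
            target_quantizer_for_second_input.symmetric = True

With the new guard, this downgrade is only applied to integer quantizers; 16-bit float quantizers are left untouched, which the test below verifies.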
TrainingExtensions/torch/test/python/v2/ab_test/test_quantizer_.py (161 changes: 161 additions & 0 deletions)
@@ -50,6 +50,7 @@
from torchvision import models

from aimet_common.defs import QuantScheme, QuantizationDataType, MAP_ROUND_MODE_TO_PYMO
from aimet_common import quantsim as common_quantsim
from aimet_common.quantsim_config.utils import get_path_for_per_channel_config
from aimet_common.utils import AimetLogger
from aimet_torch import onnx_utils
@@ -3895,3 +3896,163 @@ def test_exception_for_matmul_edge_case(
            assert closest_output_quantizer_of_second_input.symmetric
        else:
            assert not closest_output_quantizer_of_second_input.symmetric

    @pytest.mark.parametrize(
        'default_data_type', [QuantizationDataType.int, QuantizationDataType.float]
    )
    def test_exception_rule_for_matmul_quantization_data_types(self, default_data_type):
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

        model = test_models.ModelWithMatMul2().to(device)
        dummy_input = (
            torch.randn(10, 3, 4, device=device),
            torch.randn(10, 5, 4, device=device),
        )

        quantsim_config = {
            'defaults': {
                'hw_version': 'V69',
                'ops': {'is_output_quantized': 'True'},
                'params': {},
            },
            'params': {},
            'op_type': {
                'Relu': {'is_output_quantized': 'False'},
            },
            'supergroups': [],
            'model_input': {},
            'model_output': {},
        }

        with tempfile.TemporaryDirectory() as temp_dir:
            config_path = os.path.join(temp_dir, 'quantsim_config.json')

            with open(config_path, 'w') as f:
                json.dump(quantsim_config, f)

            sim = QuantizationSimModel(
                model,
                dummy_input,
                config_file=config_path,
                default_output_bw=16,
                default_param_bw=16,
                default_data_type=default_data_type,
            )

            if default_data_type == QuantizationDataType.float:
                assert sim.model.act3.output_quantizers[0].bitwidth == 16
            else:  # default_data_type == QuantizationDataType.int
                # Second input of MatMul should be symmetric and 8bit if hw_version < V73
                assert sim.model.act3.output_quantizers[0].bitwidth == 8
                assert sim.model.act3.output_quantizers[0].symmetric


    def test_exception_rule_for_group_norm_float_quantization(self):
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

        model = test_models.ModelWithGroupNorm().to(device)
        dummy_input = torch.randn((1, 6, 2, 2), device=device)

        quantsim_config = {
            'defaults': {
                'hw_version': 'V79',
                'ops': {'is_output_quantized': 'True'},
                'params': {'is_symmetric': 'True', 'is_quantized': 'True'},
            },
            'params': {},
            'op_type': {
                'GroupNorm': {
                    'per_channel_quantization': 'False',
                    'params': {'bias': {'is_quantized': 'True'}},
                },
            },
            'supergroups': [],
            'model_input': {},
            'model_output': {},
        }

        common_quantsim.ALLOW_EXPERIMENTAL = True
        with tempfile.TemporaryDirectory() as temp_dir:
            config_path = os.path.join(temp_dir, 'quantsim_config.json')
            with open(config_path, 'w') as f:
                json.dump(quantsim_config, f)

            sim = QuantizationSimModel(
                model,
                dummy_input,
                default_param_bw=8,
                default_output_bw=16,
                config_file=config_path,
                default_data_type=QuantizationDataType.float,
            )

            sim.compute_encodings(
                lambda sim_model, _: sim_model(dummy_input), forward_pass_callback_args=None
            )

            wrapper = sim.model.gn
            weight_quantizer = wrapper.param_quantizers['weight']
            bias_quantizer = wrapper.param_quantizers['bias']
            assert weight_quantizer
            assert bias_quantizer

            # In QNN, the weight tensor is treated as an activation tensor in the graph and should be matched with the activation quantizer setting
            assert isinstance(weight_quantizer, FloatQuantizeDequantize)
            assert weight_quantizer.bitwidth == 16
            assert isinstance(bias_quantizer, FloatQuantizeDequantize)
            assert bias_quantizer.bitwidth == 16
        common_quantsim.ALLOW_EXPERIMENTAL = False


    def test_exception_for_embedding_float_quantization(self):
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

        model = test_models.ModelWithEmbedding().to(device)
        dummy_input = torch.tensor(
            [[1, 4, 2, 5], [4, 3, 2, 7]], dtype=torch.int64, device=device
        )

        quantsim_config = {
            'defaults': {
                'hw_version': 'V79',
                'ops': {'is_output_quantized': 'True'},
                'params': {'is_symmetric': 'True', 'is_quantized': 'True'},
            },
            'params': {},
            'op_type': {
                'Gather': {
                    'is_output_quantized': 'False',
                    'per_channel_quantization': 'False',
                },
            },
            'supergroups': [],
            'model_input': {},
            'model_output': {},
        }

        common_quantsim.ALLOW_EXPERIMENTAL = True
        with tempfile.TemporaryDirectory() as temp_dir:
            config_path = os.path.join(temp_dir, 'quantsim_config.json')
            with open(config_path, 'w') as f:
                json.dump(quantsim_config, f)

            sim = QuantizationSimModel(
                model,
                dummy_input,
                default_param_bw=8,
                default_output_bw=16,
                config_file=config_path,
                default_data_type=QuantizationDataType.float,
            )

            sim.compute_encodings(
                lambda sim_model, _: sim_model(dummy_input), forward_pass_callback_args=None
            )

            qembedding = sim.model.embedding
            weight_quantizer = qembedding.param_quantizers['weight']

            # In QNN, the weight tensor is treated as an activation tensor in the graph and should be matched with the activation quantizer setting
            assert isinstance(weight_quantizer, FloatQuantizeDequantize)
            assert weight_quantizer.bitwidth == 16
        common_quantsim.ALLOW_EXPERIMENTAL = False
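Both float-quantization tests above enable the module-level common_quantsim.ALLOW_EXPERIMENTAL flag and reset it manually after the assertions. As a minimal alternative sketch (not part of this commit), wrapping the body in try/finally would guarantee the reset even when an assertion fails:

common_quantsim.ALLOW_EXPERIMENTAL = True
try:
    ...  # build the QuantizationSimModel, compute encodings, run the assertions
finally:
    # Restore the flag even if an assertion raises, so later tests are unaffected
    common_quantsim.ALLOW_EXPERIMENTAL = False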
