diff --git a/TrainingExtensions/torch/test/python/experimental/v2/ab_test/test_quantsim_config_.py b/TrainingExtensions/torch/test/python/experimental/v2/ab_test/test_quantsim_config_.py
index 12921fded74..ed4fd6b2ad7 100644
--- a/TrainingExtensions/torch/test/python/experimental/v2/ab_test/test_quantsim_config_.py
+++ b/TrainingExtensions/torch/test/python/experimental/v2/ab_test/test_quantsim_config_.py
@@ -66,6 +66,15 @@
+@pytest.fixture
+def enforce_target_dtype_bitwidth_config():
+    enforce_target_dtype_bitwidth_config = qsim_config.ENFORCE_TARGET_DTYPE_BITWIDTH_CONFIG
+    try:
+        qsim_config.ENFORCE_TARGET_DTYPE_BITWIDTH_CONFIG = True
+        yield
+    finally:
+        qsim_config.ENFORCE_TARGET_DTYPE_BITWIDTH_CONFIG = enforce_target_dtype_bitwidth_config
+
 # pylint: disable=protected-access
 # From https://github.com/quic/aimet/blob/b9cb122b57f591b8e62bb2bf48bb178151148011/TrainingExtensions/torch/test/python/test_quantsim_config.py#L76
@@ -1945,7 +1954,7 @@ def forward_pass(model, args):
         if os.path.exists('./data/quantsim_config.json'):
             os.remove('./data/quantsim_config.json')
 
-    def test_fp16_back_to_back_overrides(self):
+    def test_fp16_back_to_back_overrides(self, enforce_target_dtype_bitwidth_config):
         """
         Test that activation tensors are set to fp16 as expected in case of standalone vs back to back.
         """
@@ -2049,42 +2058,60 @@ def forward(self, x1, x2, x3):
         random_input = (torch.rand(1, 2), torch.rand(1, 2), torch.rand(1, 2))
 
-        qsim_config.ENFORCE_TARGET_DTYPE_BITWIDTH_CONFIG = True
         sim = QuantizationSimModel(model, quant_scheme=QuantScheme.post_training_tf, dummy_input=random_input,
                                    default_data_type=QuantizationDataType.int, default_output_bw=8,
                                    default_param_bw=8, config_file='./data/quantsim_config.json')
-        for name, module in sim.quant_wrappers():
-            if name == 'prelu2':
-                assert isinstance(module.input_quantizers[0], FloatQuantizeDequantize)
-                assert module.input_quantizers[0].bitwidth == 16
-                assert isinstance(module.output_quantizers[0], QuantizeDequantize)
-                assert module.output_quantizers[0].bitwidth == 8
-            elif name == 'relu1':
-                assert isinstance(module.input_quantizers[0], QuantizeDequantize)
-                assert module.input_quantizers[0].bitwidth == 8
-                assert isinstance(module.output_quantizers[0], QuantizeDequantize)
-                assert module.output_quantizers[0].bitwidth == 8
-            elif name == 'add1':
-                assert isinstance(module.input_quantizers[0], QuantizeDequantize)
-                assert module.input_quantizers[0].bitwidth == 8
-                assert isinstance(module.input_quantizers[1], QuantizeDequantize)
-                assert module.input_quantizers[1].bitwidth == 8
-                assert isinstance(module.output_quantizers[0], FloatQuantizeDequantize)
-                assert module.output_quantizers[0].bitwidth == 16
-            else:
-                for input_q in module.input_quantizers:
-                    assert isinstance(input_q, FloatQuantizeDequantize)
-                    assert input_q.bitwidth == 16
-                for output_q in module.output_quantizers:
-                    assert isinstance(output_q, FloatQuantizeDequantize)
-                    assert output_q.bitwidth == 16
-                for param_q in module.param_quantizers.values():
-                    assert isinstance(param_q, FloatQuantizeDequantize)
-                    assert param_q.bitwidth == 16
+        prelu1 = sim.model.prelu1
+        prelu2 = sim.model.prelu2
+        relu1 = sim.model.relu1
+        add1 = sim.model.add1
+        prelu3 = sim.model.prelu3
+        prelu4 = sim.model.prelu4
+        add2 = sim.model.add2
+
+        assert isinstance(prelu1.input_quantizers[0], FloatQuantizeDequantize)
+        assert prelu1.input_quantizers[0].is_float16()
+        assert isinstance(prelu1.output_quantizers[0], FloatQuantizeDequantize)
+        assert prelu1.output_quantizers[0].is_float16()
+        assert isinstance(prelu1.param_quantizers['weight'], FloatQuantizeDequantize)
+        assert prelu1.param_quantizers['weight'].is_float16()
+
+        assert prelu2.input_quantizers[0] is None
+        assert isinstance(prelu2.output_quantizers[0], QuantizeDequantize)
+        assert prelu2.output_quantizers[0].bitwidth == 8
+        assert isinstance(prelu2.param_quantizers['weight'], FloatQuantizeDequantize)
+        assert prelu2.param_quantizers['weight'].is_float16()
+
+        assert relu1.input_quantizers[0] is None
+        assert isinstance(relu1.output_quantizers[0], QuantizeDequantize)
+        assert relu1.output_quantizers[0].bitwidth == 8
+
+        assert add1.input_quantizers[0] is None
+        assert isinstance(add1.input_quantizers[1], QuantizeDequantize)
+        assert add1.input_quantizers[1].bitwidth == 8
+        assert isinstance(add1.output_quantizers[0], FloatQuantizeDequantize)
+        assert add1.output_quantizers[0].is_float16()
+
+        assert prelu3.input_quantizers[0] is None
+        assert isinstance(prelu3.output_quantizers[0], FloatQuantizeDequantize)
+        assert prelu3.output_quantizers[0].is_float16()
+        assert isinstance(prelu3.param_quantizers['weight'], FloatQuantizeDequantize)
+        assert prelu3.param_quantizers['weight'].is_float16()
+
+        assert isinstance(prelu4.input_quantizers[0], FloatQuantizeDequantize)
+        assert prelu4.input_quantizers[0].is_float16()
+        assert isinstance(prelu4.output_quantizers[0], FloatQuantizeDequantize)
+        assert prelu4.output_quantizers[0].is_float16()
+        assert isinstance(prelu4.param_quantizers['weight'], FloatQuantizeDequantize)
+        assert prelu4.param_quantizers['weight'].is_float16()
+
+        assert add2.input_quantizers[0] is None
+        assert add2.input_quantizers[1] is None
+        assert isinstance(add2.output_quantizers[0], FloatQuantizeDequantize)
+        assert add2.output_quantizers[0].is_float16()
 
         # remove test config created
-        qsim_config.ENFORCE_TARGET_DTYPE_BITWIDTH_CONFIG = False
         if os.path.exists('./data/quantsim_config.json'):
            os.remove('./data/quantsim_config.json')
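
For reference, the fixture introduced in the first hunk follows the usual pytest pattern for toggling a module-level flag: save the current value, override it for the duration of the test, and restore it in a finally block so the change cannot leak into later tests even when an assertion fails (unlike the manual True/False toggling it replaces). Below is a minimal, self-contained sketch of that pattern with hypothetical names; settings.STRICT_MODE stands in for qsim_config.ENFORCE_TARGET_DTYPE_BITWIDTH_CONFIG and is not part of the patch above.

```python
import types

import pytest

# Hypothetical stand-in for the module that owns the global flag.
settings = types.SimpleNamespace(STRICT_MODE=False)


@pytest.fixture
def strict_mode():
    """Temporarily enable the flag and always restore the original value."""
    saved = settings.STRICT_MODE
    try:
        settings.STRICT_MODE = True
        yield
    finally:
        # Runs even if the test body raises, so no state leaks across tests.
        settings.STRICT_MODE = saved


def test_uses_strict_mode(strict_mode):
    # Requesting the fixture by name is enough; the flag is already True here.
    assert settings.STRICT_MODE is True
```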