Skip to content

Commit

Permalink
Fix broken test case
Browse files Browse the repository at this point in the history
Signed-off-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
  • Loading branch information
quic-kyunggeu committed Feb 6, 2024
1 parent 89c07a5 commit 8065b99
Showing 1 changed file with 58 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,15 @@



@pytest.fixture
def enforce_target_dtype_bitwidth_config():
enforce_target_dtype_bitwidth_config = qsim_config.ENFORCE_TARGET_DTYPE_BITWIDTH_CONFIG
try:
qsim_config.ENFORCE_TARGET_DTYPE_BITWIDTH_CONFIG = True
yield
finally:
qsim_config.ENFORCE_TARGET_DTYPE_BITWIDTH_CONFIG = enforce_target_dtype_bitwidth_config


# pylint: disable=protected-access
# From https://github.com/quic/aimet/blob/b9cb122b57f591b8e62bb2bf48bb178151148011/TrainingExtensions/torch/test/python/test_quantsim_config.py#L76
Expand Down Expand Up @@ -1945,7 +1954,7 @@ def forward_pass(model, args):
if os.path.exists('./data/quantsim_config.json'):
os.remove('./data/quantsim_config.json')

def test_fp16_back_to_back_overrides(self):
def test_fp16_back_to_back_overrides(self, enforce_target_dtype_bitwidth_config):
"""
Test that activation tensors are set to fp16 as expected in case of standalone vs back to back.
"""
Expand Down Expand Up @@ -2049,42 +2058,60 @@ def forward(self, x1, x2, x3):

random_input = (torch.rand(1, 2), torch.rand(1, 2), torch.rand(1, 2))

qsim_config.ENFORCE_TARGET_DTYPE_BITWIDTH_CONFIG = True
sim = QuantizationSimModel(model, quant_scheme=QuantScheme.post_training_tf,
dummy_input=random_input, default_data_type=QuantizationDataType.int,
default_output_bw=8, default_param_bw=8,
config_file='./data/quantsim_config.json')
for name, module in sim.quant_wrappers():
if name == 'prelu2':
assert isinstance(module.input_quantizers[0], FloatQuantizeDequantize)
assert module.input_quantizers[0].bitwidth == 16
assert isinstance(module.output_quantizers[0], QuantizeDequantize)
assert module.output_quantizers[0].bitwidth == 8
elif name == 'relu1':
assert isinstance(module.input_quantizers[0], QuantizeDequantize)
assert module.input_quantizers[0].bitwidth == 8
assert isinstance(module.output_quantizers[0], QuantizeDequantize)
assert module.output_quantizers[0].bitwidth == 8
elif name == 'add1':
assert isinstance(module.input_quantizers[0], QuantizeDequantize)
assert module.input_quantizers[0].bitwidth == 8
assert isinstance(module.input_quantizers[1], QuantizeDequantize)
assert module.input_quantizers[1].bitwidth == 8
assert isinstance(module.output_quantizers[0], FloatQuantizeDequantize)
assert module.output_quantizers[0].bitwidth == 16
else:
for input_q in module.input_quantizers:
assert isinstance(input_q, FloatQuantizeDequantize)
assert input_q.bitwidth == 16
for output_q in module.output_quantizers:
assert isinstance(output_q, FloatQuantizeDequantize)
assert output_q.bitwidth == 16
for param_q in module.param_quantizers.values():
assert isinstance(param_q, FloatQuantizeDequantize)
assert param_q.bitwidth == 16
prelu1 = sim.model.prelu1
prelu2 = sim.model.prelu2
relu1 = sim.model.relu1
add1 = sim.model.add1
prelu3 = sim.model.prelu3
prelu4 = sim.model.prelu4
add2 = sim.model.add2

assert isinstance(prelu1.input_quantizers[0], FloatQuantizeDequantize)
assert prelu1.input_quantizers[0].is_float16()
assert isinstance(prelu1.output_quantizers[0], FloatQuantizeDequantize)
assert prelu1.output_quantizers[0].is_float16()
assert isinstance(prelu1.param_quantizers['weight'], FloatQuantizeDequantize)
assert prelu1.param_quantizers['weight'].is_float16()

assert prelu2.input_quantizers[0] is None
assert isinstance(prelu2.output_quantizers[0], QuantizeDequantize)
assert prelu2.output_quantizers[0].bitwidth == 8
assert isinstance(prelu2.param_quantizers['weight'], FloatQuantizeDequantize)
assert prelu2.param_quantizers['weight'].is_float16()

assert relu1.input_quantizers[0] is None
assert isinstance(relu1.output_quantizers[0], QuantizeDequantize)
assert relu1.output_quantizers[0].bitwidth == 8

assert add1.input_quantizers[0] is None
assert isinstance(add1.input_quantizers[1], QuantizeDequantize)
assert add1.input_quantizers[1].bitwidth == 8
assert isinstance(add1.output_quantizers[0], FloatQuantizeDequantize)
assert add1.output_quantizers[0].is_float16()

assert prelu3.input_quantizers[0] is None
assert isinstance(prelu3.output_quantizers[0], FloatQuantizeDequantize)
assert prelu3.output_quantizers[0].is_float16()
assert isinstance(prelu3.param_quantizers['weight'], FloatQuantizeDequantize)
assert prelu3.param_quantizers['weight'].is_float16()

assert isinstance(prelu4.input_quantizers[0], FloatQuantizeDequantize)
assert prelu4.input_quantizers[0].is_float16()
assert isinstance(prelu4.output_quantizers[0], FloatQuantizeDequantize)
assert prelu4.output_quantizers[0].is_float16()
assert isinstance(prelu4.param_quantizers['weight'], FloatQuantizeDequantize)
assert prelu4.param_quantizers['weight'].is_float16()

assert add2.input_quantizers[0] is None
assert add2.input_quantizers[1] is None
assert isinstance(add2.output_quantizers[0], FloatQuantizeDequantize)
assert add2.output_quantizers[0].is_float16()

# remove test config created
qsim_config.ENFORCE_TARGET_DTYPE_BITWIDTH_CONFIG = False
if os.path.exists('./data/quantsim_config.json'):
os.remove('./data/quantsim_config.json')

Expand Down

0 comments on commit 8065b99

Please sign in to comment.