Skip to content

Commit

Permalink
Add A/B test against V1 range learning quantizer
Browse files Browse the repository at this point in the history
Signed-off-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
  • Loading branch information
quic-kyunggeu authored and quic-akhobare committed Feb 2, 2024
1 parent ae77cf7 commit 1db6857
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
7 changes: 6 additions & 1 deletion TrainingExtensions/torch/src/python/aimet_torch/quantsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,7 +765,12 @@ def replace_wrappers_for_quantize_dequantize(self):
"""
if self._quant_scheme == QuantScheme.training_range_learning_with_tf_init or self._quant_scheme == \
QuantScheme.training_range_learning_with_tf_enhanced_init:
device = utils.get_device(self.model)
try:
device = utils.get_device(self.model)
except StopIteration:
# Model doesn't have any parameter.
# Set device to cpu by default.
device = torch.device('cpu')

self._replace_quantization_wrapper(self.model, device)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def set_seed(seed):


@pytest.mark.parametrize('quant_scheme', [QuantScheme.post_training_tf,
QuantScheme.training_range_learning_with_tf_init,
# QuantScheme.post_training_percentile, # TODO: not implemented
# QuantScheme.training_range_learning_with_tf_init, # TODO: not implemented
])
Expand Down Expand Up @@ -288,4 +289,3 @@ def test_multi_output(self, quant_scheme, seed):
model = models_to_test.ModelWith5Output()
dummy_input = torch.randn(1, 3, 224, 224)
self.check_qsim_logit_consistency(CONFIG_DEFAULT, quant_scheme, model, dummy_input)

0 comments on commit 1db6857

Please sign in to comment.