diff --git a/TrainingExtensions/torch/src/python/aimet_torch/quantsim.py b/TrainingExtensions/torch/src/python/aimet_torch/quantsim.py
index e960c021576..f4c8c7af91c 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/quantsim.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/quantsim.py
@@ -765,7 +765,12 @@ def replace_wrappers_for_quantize_dequantize(self):
         """
         if self._quant_scheme == QuantScheme.training_range_learning_with_tf_init or self._quant_scheme == \
                 QuantScheme.training_range_learning_with_tf_enhanced_init:
-            device = utils.get_device(self.model)
+            try:
+                device = utils.get_device(self.model)
+            except StopIteration:
+                # Model doesn't have any parameters.
+                # Set device to cpu by default.
+                device = torch.device('cpu')
             self._replace_quantization_wrapper(self.model, device)
diff --git a/TrainingExtensions/torch/test/python/experimental/v2/ab_test/test_quantsim_logits.py b/TrainingExtensions/torch/test/python/experimental/v2/ab_test/test_quantsim_logits.py
index b6aac9e1465..69c4326481e 100644
--- a/TrainingExtensions/torch/test/python/experimental/v2/ab_test/test_quantsim_logits.py
+++ b/TrainingExtensions/torch/test/python/experimental/v2/ab_test/test_quantsim_logits.py
@@ -198,6 +198,7 @@ def set_seed(seed):
 @pytest.mark.parametrize('quant_scheme', [QuantScheme.post_training_tf,
+                                          QuantScheme.training_range_learning_with_tf_init,
                                           # QuantScheme.post_training_percentile, # TODO: not implemented
                                           # QuantScheme.training_range_learning_with_tf_init, # TODO: not implemented
                                           ])
@@ -288,4 +289,3 @@ def test_multi_output(self, quant_scheme, seed):
         model = models_to_test.ModelWith5Output()
         dummy_input = torch.randn(1, 3, 224, 224)
         self.check_qsim_logit_consistency(CONFIG_DEFAULT, quant_scheme, model, dummy_input)
-
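
Note on the fix: `utils.get_device` derives the device from the model's parameters, presumably along the lines of `next(model.parameters()).device`, so a model with no parameters leaves the iterator empty and `next()` raises `StopIteration`. The sketch below reproduces that failure mode and the CPU fallback applied in the patch; `ParamFreeModel` is a hypothetical stand-in, and the inline `next(...)` call is an assumption about how `get_device` resolves the device, not a copy of its implementation.

```python
import torch

class ParamFreeModel(torch.nn.Module):
    """Hypothetical module with no learnable parameters (activation-only)."""
    def forward(self, x):
        return torch.relu(x)

model = ParamFreeModel()

try:
    # Assumed shape of utils.get_device(): take the device of the first
    # parameter. An empty parameter generator makes next() raise StopIteration.
    device = next(model.parameters()).device
except StopIteration:
    # Mirror the patch above: fall back to CPU when the model is parameter-free.
    device = torch.device('cpu')

assert device == torch.device('cpu')
```

Catching `StopIteration` at the call site, rather than changing `get_device` itself, keeps the existing behavior for every model that does have parameters.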