diff --git a/TrainingExtensions/torch/test/python/test_tensor_quantizer.py b/TrainingExtensions/torch/test/python/test_tensor_quantizer.py index 1917944d190..4fa19f7bb0c 100644 --- a/TrainingExtensions/torch/test/python/test_tensor_quantizer.py +++ b/TrainingExtensions/torch/test/python/test_tensor_quantizer.py @@ -402,7 +402,7 @@ def test_learned_grid_encoding_getter(self, dtype): quant_wrapper = LearnedGridQuantWrapper(conv, round_mode=libpymo.RoundingMode.ROUND_NEAREST, quant_scheme=QuantScheme.training_range_learning_with_tf_init, is_output_quantized=True, activation_bw=16, - weight_bw=8, device="cuda:0") + weight_bw=8, device="cpu") enc = libpymo.TfEncoding() enc.bw, enc.max, enc.min = 16, 0.4, -0.98