diff --git a/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py b/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py index dae142b0118..7d2880cf09a 100644 --- a/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py +++ b/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py @@ -895,9 +895,7 @@ def _set_src_qtzr(x: Product, consumer: Op, src_qtzr): _, out_qtzr, __ = self.get_op_quantizers(op) if not out_qtzr: - msg = 'Encoding propagation is only supported for ops with exactly ' \ - '1 output quantizer, but found output_quantizers[0] == []' - raise RuntimeError(msg) + continue if len(out_qtzr) != 1: msg = 'Encoding propagation is only supported for ops with exactly ' \ diff --git a/TrainingExtensions/onnx/test/python/models/models_for_tests.py b/TrainingExtensions/onnx/test/python/models/models_for_tests.py index 17f05baa7dc..3f0209039c0 100644 --- a/TrainingExtensions/onnx/test/python/models/models_for_tests.py +++ b/TrainingExtensions/onnx/test/python/models/models_for_tests.py @@ -2543,3 +2543,37 @@ def batchnorm_model_constants(): onnx.checker.check_model(model, True) return model +def integer_concat_model(): + model = helper.make_model( + graph=helper.make_graph( + name='IntConcatModel', + inputs=[helper.make_tensor_value_info('model_input', TensorProto.FLOAT, shape=[10, 10])], + outputs=[helper.make_tensor_value_info('model_output', TensorProto.FLOAT, shape=[10, 10, 1])], + initializer=[], + nodes=[ + helper.make_node('Constant', inputs=[], outputs=["constant_one"], + value=numpy_helper.from_array(np.array([1]).astype('int64')), name='one'), + helper.make_node( + 'Shape', + inputs=['model_input'], + outputs=['input_shape'], + name='shape' + ), + helper.make_node( + 'Concat', + inputs=['input_shape', 'constant_one'], + outputs=['output_shape'], + name='out_shape', + axis=0 + ), + helper.make_node( + "Reshape", + inputs=["model_input", "output_shape"], + outputs=["model_output"] + ) + ] + ) + ) + onnx.checker.check_model(model, True) + return model + diff --git a/TrainingExtensions/onnx/test/python/test_quantsim.py b/TrainingExtensions/onnx/test/python/test_quantsim.py index f5801d4abbe..3f67dc39593 100644 --- a/TrainingExtensions/onnx/test/python/test_quantsim.py +++ b/TrainingExtensions/onnx/test/python/test_quantsim.py @@ -1579,6 +1579,15 @@ def forward(self, x): _, out_qtzr, __ = sim.get_op_quantizers(cg_op) assert _compare_encodings(out_qtzr[0].encodings[0], sim.qc_quantize_op_dict['output'].encodings[0]) + def test_integer_concat(self): + """ + When: Model contains unquantizable layers with op_type in quantsim.op_types_to_tie_qtzrs + Then: Error should not be thrown during quantsim init + """ + model = models_for_tests.integer_concat_model() + with _apply_constraints(True): + sim = QuantizationSimModel(model) + def test_clamp_activation_encodings(self): model = models_for_tests.matmul_add_model() dummy_input = {'model_input': np.expand_dims(np.identity(8, np.float32), axis=(0, 1))}