Skip to content

Commit

Permalink
Fix tie_quantizer errors with unquantizable ops
Browse files Browse the repository at this point in the history
Signed-off-by: Michael Tuttle <quic_mtuttle@quicinc.com>
  • Loading branch information
quic-mtuttle authored Nov 18, 2024
1 parent 85f0b13 commit 29fcb76
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 3 deletions.
4 changes: 1 addition & 3 deletions TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -895,9 +895,7 @@ def _set_src_qtzr(x: Product, consumer: Op, src_qtzr):
_, out_qtzr, __ = self.get_op_quantizers(op)

if not out_qtzr:
msg = 'Encoding propagation is only supported for ops with exactly ' \
'1 output quantizer, but found output_quantizers[0] == []'
raise RuntimeError(msg)
continue

if len(out_qtzr) != 1:
msg = 'Encoding propagation is only supported for ops with exactly ' \
Expand Down
34 changes: 34 additions & 0 deletions TrainingExtensions/onnx/test/python/models/models_for_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2543,3 +2543,37 @@ def batchnorm_model_constants():
onnx.checker.check_model(model, True)
return model

def integer_concat_model():
model = helper.make_model(
graph=helper.make_graph(
name='IntConcatModel',
inputs=[helper.make_tensor_value_info('model_input', TensorProto.FLOAT, shape=[10, 10])],
outputs=[helper.make_tensor_value_info('model_output', TensorProto.FLOAT, shape=[10, 10, 1])],
initializer=[],
nodes=[
helper.make_node('Constant', inputs=[], outputs=["constant_one"],
value=numpy_helper.from_array(np.array([1]).astype('int64')), name='one'),
helper.make_node(
'Shape',
inputs=['model_input'],
outputs=['input_shape'],
name='shape'
),
helper.make_node(
'Concat',
inputs=['input_shape', 'constant_one'],
outputs=['output_shape'],
name='out_shape',
axis=0
),
helper.make_node(
"Reshape",
inputs=["model_input", "output_shape"],
outputs=["model_output"]
)
]
)
)
onnx.checker.check_model(model, True)
return model

9 changes: 9 additions & 0 deletions TrainingExtensions/onnx/test/python/test_quantsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -1579,6 +1579,15 @@ def forward(self, x):
_, out_qtzr, __ = sim.get_op_quantizers(cg_op)
assert _compare_encodings(out_qtzr[0].encodings[0], sim.qc_quantize_op_dict['output'].encodings[0])

def test_integer_concat(self):
"""
When: Model contains unquantizable layers with op_type in quantsim.op_types_to_tie_qtzrs
Then: Error should not be thrown during quantsim init
"""
model = models_for_tests.integer_concat_model()
with _apply_constraints(True):
sim = QuantizationSimModel(model)

def test_clamp_activation_encodings(self):
model = models_for_tests.matmul_add_model()
dummy_input = {'model_input': np.expand_dims(np.identity(8, np.float32), axis=(0, 1))}
Expand Down

0 comments on commit 29fcb76

Please sign in to comment.