diff --git a/model_compression_toolkit/core/keras/back2framework/keras_model_builder.py b/model_compression_toolkit/core/keras/back2framework/keras_model_builder.py index 920b720f9..6c9ee9670 100644 --- a/model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +++ b/model_compression_toolkit/core/keras/back2framework/keras_model_builder.py @@ -302,7 +302,11 @@ def _run_operation(self, # Build a functional node using its args if isinstance(n, FunctionalNode): if n.inputs_as_list: # If the first argument should be a list of tensors: - out_tensors_of_n_float = op_func(input_tensors, *n.op_call_args, **op_call_kwargs) + if isinstance(op_func, KerasQuantizationWrapper): + # in wrapped nodes, the op args & kwargs are already in the KerasQuantizationWrapper. + out_tensors_of_n_float = op_func(input_tensors) + else: + out_tensors_of_n_float = op_func(input_tensors, *n.op_call_args, **op_call_kwargs) else: # If the input tensors should not be a list but iterated: out_tensors_of_n_float = op_func(*input_tensors, *n.op_call_args, **op_call_kwargs) else: diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py index 424bccd17..e92c51fd7 100644 --- a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py @@ -207,7 +207,7 @@ def generate_tp_model(default_config: OpQuantizationConfig, const_config_input16_per_tensor = const_config.clone_and_edit( supported_input_activation_n_bits=(8, 16), default_weight_attr_config=default_config.default_weight_attr_config.clone_and_edit( - enable_weights_quantization=True, weights_per_channel_threshold=True, + enable_weights_quantization=True, weights_per_channel_threshold=False, weights_quantization_method=tp.QuantizationMethod.POWER_OF_TWO) ) const_config_input16_output16_per_tensor = const_config_input16_per_tensor.clone_and_edit(