diff --git a/model_compression_toolkit/target_platform_capabilities/target_platform/op_quantization_config.py b/model_compression_toolkit/target_platform_capabilities/target_platform/op_quantization_config.py
index 8b29f111e..e97b50531 100644
--- a/model_compression_toolkit/target_platform_capabilities/target_platform/op_quantization_config.py
+++ b/model_compression_toolkit/target_platform_capabilities/target_platform/op_quantization_config.py
@@ -197,7 +197,7 @@ def __eq__(self, other):
                self.simd_size == other.simd_size
 
 
-class QuantizationConfigOptions(object):
+class QuantizationConfigOptions:
     """
 
     Wrap a set of quantization configurations to consider during the quantization
@@ -215,19 +215,24 @@ def __init__(self,
         """
 
         assert isinstance(quantization_config_list,
-                          list), f'\'QuantizationConfigOptions\' options list must be a list, but received: {type(quantization_config_list)}.'
-        assert len(quantization_config_list) > 0, f'Options list can not be empty.'
+                          list), f"'QuantizationConfigOptions' options list must be a list, but received: {type(quantization_config_list)}."
         for cfg in quantization_config_list:
-            assert isinstance(cfg, OpQuantizationConfig), f'Each option must be an instance of \'OpQuantizationConfig\', but found an object of type: {type(cfg)}.'
+            assert isinstance(cfg, OpQuantizationConfig),\
+                f"Each option must be an instance of 'OpQuantizationConfig', but found an object of type: {type(cfg)}."
         self.quantization_config_list = quantization_config_list
         if len(quantization_config_list) > 1:
-            assert base_config is not None, f'For multiple configurations, a \'base_config\' is required for non-mixed-precision optimization.'
-            assert base_config in quantization_config_list, f"\'base_config\' must be included in the quantization config options list."
+            assert base_config is not None, \
+                f"For multiple configurations, a 'base_config' is required for non-mixed-precision optimization."
+            assert any([base_config is cfg for cfg in quantization_config_list]), \
+                f"'base_config' must be included in the quantization config options list."
+            # Enforce base_config to be a reference to an instance in quantization_config_list.
             self.base_config = base_config
         elif len(quantization_config_list) == 1:
+            assert base_config is None or base_config == quantization_config_list[0], "'base_config' should be included in 'quantization_config_list'"
+            # Set base_config to be a reference to the first instance in quantization_config_list.
             self.base_config = quantization_config_list[0]
         else:
-            Logger.critical("\'QuantizationConfigOptions\' requires at least one \'OpQuantizationConfig\'; the provided list is empty.")
+            raise AssertionError("'QuantizationConfigOptions' requires at least one 'OpQuantizationConfig'. The provided list is empty.")
 
     def __eq__(self, other):
         """
diff --git a/tests/common_tests/helpers/generate_test_tp_model.py b/tests/common_tests/helpers/generate_test_tp_model.py
index 1cb0fd9d7..b1bd89e66 100644
--- a/tests/common_tests/helpers/generate_test_tp_model.py
+++ b/tests/common_tests/helpers/generate_test_tp_model.py
@@ -66,7 +66,11 @@ def generate_mixed_precision_test_tp_model(base_cfg, default_config, mp_bitwidth
         candidate_cfg = base_cfg.clone_and_edit(attr_to_edit={KERNEL_ATTR: {WEIGHTS_N_BITS: weights_n_bits}},
                                                 activation_n_bits=activation_n_bits)
 
-        mp_op_cfg_list.append(candidate_cfg)
+        if candidate_cfg == base_cfg:
+            # the base config must be a reference of an instance in the cfg_list, so we put it and not the clone in the list.
+            mp_op_cfg_list.append(base_cfg)
+        else:
+            mp_op_cfg_list.append(candidate_cfg)
 
     return generate_tp_model(default_config=default_config,
                              base_config=base_cfg,
@@ -85,8 +89,11 @@ def generate_tp_model_with_activation_mp(base_cfg, default_config, mp_bitwidth_c
                                                 **{k: v for k, v in base_cfg.attr_weights_configs_mapping.items() if k != KERNEL_ATTR}},
                                                 activation_n_bits=activation_n_bits)
-
-        mp_op_cfg_list.append(candidate_cfg)
+        if candidate_cfg == base_cfg:
+            # the base config must be a reference of an instance in the cfg_list, so we put it and not the clone in the list.
+            mp_op_cfg_list.append(base_cfg)
+        else:
+            mp_op_cfg_list.append(candidate_cfg)
 
 
     base_tp_model = generate_tp_model(default_config=default_config,
                                       base_config=base_cfg,
diff --git a/tests/common_tests/test_tp_model.py b/tests/common_tests/test_tp_model.py
index cf4a1a510..3a5683a55 100644
--- a/tests/common_tests/test_tp_model.py
+++ b/tests/common_tests/test_tp_model.py
@@ -96,7 +96,8 @@ class QCOptionsTest(unittest.TestCase):
     def test_empty_qc_options(self):
         with self.assertRaises(AssertionError) as e:
             tp.QuantizationConfigOptions([])
-        self.assertEqual('Options list can not be empty.', str(e.exception))
+        self.assertEqual("'QuantizationConfigOptions' requires at least one 'OpQuantizationConfig'. The provided list is empty.",
+                         str(e.exception))
 
     def test_list_of_no_qc(self):
         with self.assertRaises(AssertionError) as e:
diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/bias_correction_dw_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/bias_correction_dw_test.py
index f7f142058..88f678b2f 100644
--- a/tests/keras_tests/feature_networks_tests/feature_networks/bias_correction_dw_test.py
+++ b/tests/keras_tests/feature_networks_tests/feature_networks/bias_correction_dw_test.py
@@ -56,6 +56,4 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
         error = np.sum(error, axis=(0,1)).flatten()
         bias = dw_layer.weights[2]
         # Input mean is 1 so correction_term = quant_error * 1
-        # TODO:
-        # Increase atol due to a minor difference in Symmetric quantizer
-        self.unit_test.assertTrue(np.isclose(error, bias, atol=1e-7).all())
+        self.unit_test.assertTrue(np.isclose(error, bias.numpy(), atol=3e-7).all())