diff --git a/tests/keras_tests/function_tests/test_hmse_error_method.py b/tests/keras_tests/function_tests/test_hmse_error_method.py index c217b4c95..1b1fed1ce 100644 --- a/tests/keras_tests/function_tests/test_hmse_error_method.py +++ b/tests/keras_tests/function_tests/test_hmse_error_method.py @@ -165,44 +165,11 @@ def test_uniform_threshold_selection_hmse_per_tensor(self): self._verify_params_calculation_execution(RANGE_MAX) def test_threshold_selection_hmse_no_gptq(self): - self._setup_with_args(quant_method=mct.target_platform.QuantizationMethod.SYMMETRIC, per_channel=True, - running_gptq=False) - - def _verify_node_default_mse_error(node_type): - node = [n for n in self.graph.nodes if n.type == node_type] - self.assertTrue(len(node) == 1, f"Expecting exactly 1 {node_type} node in test model.") - node = node[0] - - kernel_attr_error_method = ( - node.candidates_quantization_cfg[0].weights_quantization_cfg.get_attr_config(KERNEL).weights_error_method) - self.assertTrue(kernel_attr_error_method == mct.core.QuantizationErrorMethod.MSE, - f"Expecting {node_type} node quantization parameter error method to be the default " - f"MSE when not running with GPTQ, but is set to {kernel_attr_error_method}.") - - # verifying that the nodes quantization params error method is changed to the default MSE - _verify_node_default_mse_error(layers.Conv2D) - _verify_node_default_mse_error(layers.Dense) - - calculate_quantization_params(self.graph, fw_impl=self.keras_impl, repr_data_gen_fn=representative_dataset, - hessian_info_service=self.his, num_hessian_samples=1) - - def _verify_node_no_hessian_computed(node_type): - node = [n for n in self.graph.nodes if n.type == node_type] - self.assertTrue(len(node) == 1, f"Expecting exactly 1 {node_type} node in test model.") - node = node[0] - - expected_hessian_request = HessianScoresRequest(mode=HessianMode.WEIGHTS, - granularity=HessianScoresGranularity.PER_ELEMENT, - data_loader=None, - n_samples=1, - target_nodes=[node]) - - with self.assertRaises(ValueError, msg='Not enough hessians are cached to fulfill the request') as e: - self.his.fetch_hessian(expected_hessian_request) - - # verifying that no Hessian scores were computed - _verify_node_no_hessian_computed(layers.Conv2D) - _verify_node_no_hessian_computed(layers.Dense) + with self.assertRaises(ValueError) as e: + self._setup_with_args(quant_method=mct.target_platform.QuantizationMethod.SYMMETRIC, per_channel=True, + running_gptq=False) + self.assertTrue('The HSE error method for parameters selection is only supported when running GPTQ ' + 'optimization due to long execution time that is not suitable for basic PTQ.' in e.exception.args[0]) def test_threshold_selection_hmse_no_kernel_attr(self): def _generate_bn_quantization_tpc(quant_method, per_channel): @@ -258,8 +225,9 @@ def _generate_bn_quantization_tpc(quant_method, per_channel): n_samples=1, target_nodes=[node]) - with self.assertRaises(ValueError, msg='Not enough hessians are cached to fulfill the request') as e: + with self.assertRaises(ValueError) as e: self.his.fetch_hessian(expected_hessian_request) + self.assertTrue('Not enough hessians are cached to fulfill the request' in e.exception.args[0]) if __name__ == '__main__':