Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix ignored hmse after resource utilization computation #1253

Merged
merged 4 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,6 @@ class MpDistanceWeighting(Enum):

def __call__(self, distance_matrix: np.ndarray) -> np.ndarray:
return self.value(distance_matrix)

def __deepcopy__(self, memo):
return self
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import copy

import numpy as np
from typing import Callable, Any, Dict, Tuple

from model_compression_toolkit.constants import FLOAT_BITWIDTH, BITS_TO_BYTES
from model_compression_toolkit.core import FrameworkInfo, ResourceUtilization, CoreConfig
from model_compression_toolkit.core import FrameworkInfo, ResourceUtilization, CoreConfig, QuantizationErrorMethod
from model_compression_toolkit.core.common import Graph
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
from model_compression_toolkit.core.common.graph.edge import EDGE_SINK_INDEX
Expand Down Expand Up @@ -57,7 +59,7 @@ def compute_resource_utilization_data(in_model: Any,


"""

core_config = _create_core_config_for_ru(core_config)
# We assume that the resource_utilization_data API is used to compute the model resource utilization for
# mixed precision scenario, so we run graph preparation under the assumption of enabled mixed precision.
if transformed_graph is None:
Expand Down Expand Up @@ -222,6 +224,8 @@ def requires_mixed_precision(in_model: Any,
Returns: A boolean indicating if mixed precision is needed.
"""
is_mixed_precision = False
core_config = _create_core_config_for_ru(core_config)

transformed_graph = graph_preparation_runner(in_model,
representative_data_gen,
core_config.quantization_config,
Expand All @@ -247,3 +251,21 @@ def requires_mixed_precision(in_model: Any,
is_mixed_precision |= target_resource_utilization.total_memory < total_weights_memory_bytes + max_activation_tensor_size_bytes
is_mixed_precision |= target_resource_utilization.bops < bops_count
return is_mixed_precision


def _create_core_config_for_ru(core_config: CoreConfig) -> CoreConfig:
"""
Create a core config to use for resource utilization computation.

Args:
core_config: input core config

Returns:
Core config for resource utilization.
"""
core_config = copy.deepcopy(core_config)
# For resource utilization graph_preparation_runner runs with gptq=False (the default value). HMSE is not supported
# without GPTQ and will raise an error later so we replace it with MSE.
if core_config.quantization_config.weights_error_method == QuantizationErrorMethod.HMSE:
core_config.quantization_config.weights_error_method = QuantizationErrorMethod.MSE
return core_config
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,10 @@ def set_quantization_configuration_to_graph(graph: Graph,

if quant_config.weights_error_method == QuantizationErrorMethod.HMSE:
if not running_gptq:
Logger.warning(f"The HMSE error method for parameters selection is only supported when running GPTQ "
f"optimization due to long execution time that is not suitable for basic PTQ. "
f"Using the default MSE error method instead.")
quant_config.weights_error_method = QuantizationErrorMethod.MSE
else:
Logger.warning("Using the HMSE error method for weights quantization parameters search. "
"Note: This method may significantly increase runtime during the parameter search process.")
raise ValueError(f"The HMSE error method for parameters selection is only supported when running GPTQ "
f"optimization due to long execution time that is not suitable for basic PTQ.")
Logger.warning("Using the HMSE error method for weights quantization parameters search. "
"Note: This method may significantly increase runtime during the parameter search process.")

nodes_to_manipulate_bit_widths = {} if bit_width_config is None else bit_width_config.get_nodes_to_manipulate_bit_widths(graph)

Expand Down
46 changes: 7 additions & 39 deletions tests/keras_tests/function_tests/test_hmse_error_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,44 +165,11 @@ def test_uniform_threshold_selection_hmse_per_tensor(self):
self._verify_params_calculation_execution(RANGE_MAX)

def test_threshold_selection_hmse_no_gptq(self):
self._setup_with_args(quant_method=mct.target_platform.QuantizationMethod.SYMMETRIC, per_channel=True,
running_gptq=False)

def _verify_node_default_mse_error(node_type):
node = [n for n in self.graph.nodes if n.type == node_type]
self.assertTrue(len(node) == 1, f"Expecting exactly 1 {node_type} node in test model.")
node = node[0]

kernel_attr_error_method = (
node.candidates_quantization_cfg[0].weights_quantization_cfg.get_attr_config(KERNEL).weights_error_method)
self.assertTrue(kernel_attr_error_method == mct.core.QuantizationErrorMethod.MSE,
f"Expecting {node_type} node quantization parameter error method to be the default "
f"MSE when not running with GPTQ, but is set to {kernel_attr_error_method}.")

# verifying that the nodes quantization params error method is changed to the default MSE
_verify_node_default_mse_error(layers.Conv2D)
_verify_node_default_mse_error(layers.Dense)

calculate_quantization_params(self.graph, fw_impl=self.keras_impl, repr_data_gen_fn=representative_dataset,
hessian_info_service=self.his, num_hessian_samples=1)

def _verify_node_no_hessian_computed(node_type):
node = [n for n in self.graph.nodes if n.type == node_type]
self.assertTrue(len(node) == 1, f"Expecting exactly 1 {node_type} node in test model.")
node = node[0]

expected_hessian_request = HessianScoresRequest(mode=HessianMode.WEIGHTS,
granularity=HessianScoresGranularity.PER_ELEMENT,
data_loader=None,
n_samples=1,
target_nodes=[node])

with self.assertRaises(ValueError, msg='Not enough hessians are cached to fulfill the request') as e:
self.his.fetch_hessian(expected_hessian_request)

# verifying that no Hessian scores were computed
_verify_node_no_hessian_computed(layers.Conv2D)
_verify_node_no_hessian_computed(layers.Dense)
with self.assertRaises(ValueError) as e:
self._setup_with_args(quant_method=mct.target_platform.QuantizationMethod.SYMMETRIC, per_channel=True,
running_gptq=False)
self.assertTrue('The HMSE error method for parameters selection is only supported when running GPTQ '
'optimization due to long execution time that is not suitable for basic PTQ.' in e.exception.args[0])

def test_threshold_selection_hmse_no_kernel_attr(self):
def _generate_bn_quantization_tpc(quant_method, per_channel):
Expand Down Expand Up @@ -258,8 +225,9 @@ def _generate_bn_quantization_tpc(quant_method, per_channel):
n_samples=1,
target_nodes=[node])

with self.assertRaises(ValueError, msg='Not enough hessians are cached to fulfill the request') as e:
with self.assertRaises(ValueError) as e:
self.his.fetch_hessian(expected_hessian_request)
self.assertTrue('Not enough hessians are cached to fulfill the request' in e.exception.args[0])


if __name__ == '__main__':
Expand Down
Loading