From d13319f2148a3445a25beeb48ae4c3f013f2cdf7 Mon Sep 17 00:00:00 2001 From: Ofir Gordon Date: Wed, 5 Jun 2024 13:19:59 +0300 Subject: [PATCH] Activation Hessian computation runtime optimization (#1092) Improve Activation Hessian computation runtime for GPTQ and Mixed precision with the following optimizations: - Enable batch computation. - Enable computation on a set of nodes (instead of a single node only). - Other minor loop and implementation modifications. --------- Co-authored-by: Ofir Gordon --- model_compression_toolkit/constants.py | 5 +- .../common/hessian/hessian_info_service.py | 301 ++++++++++++++---- .../core/common/hessian/hessian_info_utils.py | 9 +- .../hessian/trace_hessian_calculator.py | 7 +- .../common/hessian/trace_hessian_request.py | 10 +- .../mixed_precision_quantization_config.py | 8 +- .../mixed_precision/sensitivity_evaluation.py | 56 +--- .../lfh_importance_metric.py | 6 +- .../error_functions.py | 2 +- ...tivation_trace_hessian_calculator_keras.py | 142 ++++----- .../weights_trace_hessian_calculator_keras.py | 42 ++- ...vation_trace_hessian_calculator_pytorch.py | 120 +++---- ...eights_trace_hessian_calculator_pytorch.py | 38 ++- model_compression_toolkit/core/runner.py | 3 +- .../gptq/common/gptq_config.py | 9 +- .../gptq/common/gptq_training.py | 46 +-- .../gptq/keras/quantization_facade.py | 11 +- .../gptq/pytorch/quantization_facade.py | 12 +- .../feature_networks/gptq/gptq_test.py | 8 +- .../weights_mixed_precision_tests.py | 7 +- .../test_features_runner.py | 4 +- .../function_tests/test_get_gptq_config.py | 102 +++--- .../test_hessian_info_calculator.py | 183 +++++------ .../function_tests/test_hessian_service.py | 135 ++++++-- .../function_tests/test_hmse_error_method.py | 11 +- ..._sensitivity_eval_non_suppoerted_output.py | 12 +- .../conv2d_conv2dtranspose_pruning_test.py | 3 +- .../conv2dtranspose_pruning_test.py | 3 +- .../pruning_keras_feature_test.py | 2 +- .../function_tests/get_gptq_config_test.py | 26 +- 
.../function_tests/test_function_runner.py | 20 +- .../test_hessian_info_calculator.py | 105 ++++-- .../function_tests/test_hessian_service.py | 297 +++++++++++++++++ ...t_sensitivity_eval_non_supported_output.py | 12 +- .../model_tests/feature_models/gptq_test.py | 28 +- .../pruning_pytorch_feature_test.py | 2 +- 36 files changed, 1210 insertions(+), 577 deletions(-) create mode 100644 tests/pytorch_tests/function_tests/test_hessian_service.py diff --git a/model_compression_toolkit/constants.py b/model_compression_toolkit/constants.py index d8c364c83..c90092009 100644 --- a/model_compression_toolkit/constants.py +++ b/model_compression_toolkit/constants.py @@ -122,8 +122,11 @@ # Hessian configuration default constants HESSIAN_OUTPUT_ALPHA = 0.3 -HESSIAN_NUM_ITERATIONS = 50 +HESSIAN_NUM_ITERATIONS = 100 HESSIAN_EPS = 1e-6 +ACT_HESSIAN_DEFAULT_BATCH_SIZE = 32 +GPTQ_HESSIAN_NUM_SAMPLES = 32 +MP_DEFAULT_NUM_SAMPLES = 32 # Pruning constants PRUNING_NUM_SCORE_APPROXIMATIONS = 32 \ No newline at end of file diff --git a/model_compression_toolkit/core/common/hessian/hessian_info_service.py b/model_compression_toolkit/core/common/hessian/hessian_info_service.py index b658a6420..6860a8777 100644 --- a/model_compression_toolkit/core/common/hessian/hessian_info_service.py +++ b/model_compression_toolkit/core/common/hessian/hessian_info_service.py @@ -12,12 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== +from collections.abc import Iterable +import numpy as np from functools import partial -from typing import Callable, List +from typing import Callable, List, Dict, Any, Tuple from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS -from model_compression_toolkit.core.common.hessian.trace_hessian_request import TraceHessianRequest +from model_compression_toolkit.core.common.hessian.trace_hessian_request import TraceHessianRequest, \ + HessianInfoGranularity, HessianMode from model_compression_toolkit.logger import Logger @@ -27,7 +30,7 @@ class HessianInfoService: This class provides functionalities to compute approximation based on the Hessian matrix based on the different parameters (such as number of iterations for approximating the info) - and input images (from representative_dataset). + and input images (using representative_dataset_gen). It also offers cache management capabilities for efficient computation and retrieval. Note: @@ -38,49 +41,97 @@ class HessianInfoService: def __init__(self, graph, - representative_dataset: Callable, + representative_dataset_gen: Callable, fw_impl, - num_iterations_for_approximation: int = HESSIAN_NUM_ITERATIONS - ): + num_iterations_for_approximation: int = HESSIAN_NUM_ITERATIONS): """ Args: graph: Float graph. - representative_dataset: A callable that provides a dataset for sampling. + representative_dataset_gen: A callable that provides a dataset for sampling. fw_impl: Framework-specific implementation for trace Hessian approximation computation. 
""" self.graph = graph - # Create a representative_data_gen with batch size of 1 - self.representative_dataset = partial(self._sample_single_representative_dataset, - representative_dataset=representative_dataset) + self.representative_dataset_gen = representative_dataset_gen self.fw_impl = fw_impl self.num_iterations_for_approximation = num_iterations_for_approximation self.trace_hessian_request_to_score_list = {} - def _sample_single_representative_dataset(self, representative_dataset: Callable): + def _sample_batch_representative_dataset(self, + representative_dataset: Iterable, + num_hessian_samples: int, + num_inputs: int, + last_iter_remain_samples: List[List[np.ndarray]] = None + ) -> Tuple[List[np.ndarray], List[List[np.ndarray]]]: """ - Get a single sample (namely, batch size of 1) from a representative dataset. + Get a batch of samples from a representative dataset with the requested num_hessian_samples. Args: - representative_dataset: Callable which returns the representative dataset at any batch size. - - Returns: List of inputs from representative_dataset where each sample has a batch size of 1. + representative_dataset: A generator which yields batches of input samples. + num_hessian_samples: Number of requested samples to compute batch Hessian approximation scores. + num_inputs: Number of input layers of the model on which the scores are computed. + last_iter_remain_samples: A list of input samples (for each input layer) with remaining samples from + previous iterations. + + Returns: A tuple with two lists: + (1) A list of inputs - a tensor of the requested batch size for each input layer. + (2) A list of remaining samples - for each input layer. """ - images = next(representative_dataset()) - if not isinstance(images, list): - Logger.critical(f'Expected images to be a list; found type: {type(images)}.') - # Ensure each image is a single sample, if not, take the first sample - return [image[0:1, ...] 
if image.shape[0] != 1 else image for image in images] + if num_inputs < 0: # pragma: no cover + Logger.critical(f"Number of images to compute Hessian approximation must be positive, " + f"but given {num_inputs}.") + + all_inp_hessian_samples = [[] for _ in range(num_inputs)] + # Collect the requested number of samples from the representative dataset + for batch in representative_dataset: + if not isinstance(batch, list): + Logger.critical(f'Expected batch to be a list; found type: {type(batch)}.') # pragma: no cover + all_inp_remaining_samples = [[] for _ in range(num_inputs)] + for inp_idx in range(len(batch)): + inp_batch = batch[inp_idx] + + if last_iter_remain_samples is not None and len(last_iter_remain_samples[inp_idx]): + # some samples remained from last batch of last computation iteration - + # include them in the current batch + inp_batch = np.concatenate((inp_batch, last_iter_remain_samples[inp_idx])) + + # Compute number of missing samples to get to the requested amount from the current batch + num_missing = min(num_hessian_samples - len(all_inp_hessian_samples[inp_idx]), inp_batch.shape[0]) + # Append each sample separately + samples = [s for s in inp_batch[0:num_missing, ...]] + remaining_samples = [s for s in inp_batch[num_missing:, ...]] + + all_inp_hessian_samples[inp_idx] += [sample.reshape(1, *sample.shape) for sample in samples] + # This list can only get filled on the last batch iteration + all_inp_remaining_samples[inp_idx] += (remaining_samples) + + if len(all_inp_hessian_samples[0]) > num_hessian_samples: + Logger.critical(f"Requested {num_hessian_samples} samples for computing Hessian approximation but " + f"{len(all_inp_hessian_samples[0])} were collected.") # pragma: no cover + elif len(all_inp_hessian_samples[0]) == num_hessian_samples: + # Collected enough samples, constructing a dataset with the requested batch size + hessian_samples_for_input = [] + for inp_samples in all_inp_hessian_samples: + inp_samples = 
np.concatenate(inp_samples, axis=0) + num_collected_samples = inp_samples.shape[0] + inp_samples = np.split(inp_samples, + num_collected_samples // min(num_collected_samples, num_hessian_samples)) + hessian_samples_for_input.append(inp_samples[0]) + + return hessian_samples_for_input, all_inp_remaining_samples + Logger.critical( + f"Not enough samples in the provided representative dataset to compute Hessian approximation on " + f"{num_hessian_samples} samples.") def _clear_saved_hessian_info(self): """Clears the saved info approximations.""" self.trace_hessian_request_to_score_list={} - def count_saved_info_of_request(self, hessian_request:TraceHessianRequest) -> int: + def count_saved_info_of_request(self, hessian_request: TraceHessianRequest) -> Dict: """ Counts the saved approximations of Hessian info (traces, for now) for a specific request. If some approximations were computed for this request before, the amount of approximations (per image) @@ -92,26 +143,37 @@ def count_saved_info_of_request(self, hessian_request:TraceHessianRequest) -> in Returns: Number of saved approximations for the given request. """ - # Replace request of a reused target node with a request of the 'reuse group'. - if hessian_request.target_node.reuse_group: - hessian_request = self._get_request_of_reuse_group(hessian_request) - # Check if the request is in the saved info and return its count, otherwise return 0 - return len(self.trace_hessian_request_to_score_list.get(hessian_request, [])) + per_node_counter = {} + + for n in hessian_request.target_nodes: + if n.reuse: + # Reused nodes supposed to have been replaced with a reuse_group + # representing node before calling this method. 
+ Logger.critical(f"Expecting the Hessian request to include only non-reused nodes at this point, " + f"but found node {n.name} with 'reuse' status.") + # Check if the request for this node is in the saved info and store its count, otherwise store 0 + per_node_counter[n] = len(self.trace_hessian_request_to_score_list.get(hessian_request, [])) + return per_node_counter - def compute(self, trace_hessian_request:TraceHessianRequest): + def compute(self, trace_hessian_request: TraceHessianRequest, representative_dataset_gen, num_hessian_samples: int, + last_iter_remain_samples: List[List[np.ndarray]] = None): """ Computes an approximation of the trace of the Hessian based on the provided request configuration and stores it in the cache. Args: trace_hessian_request: Configuration for which to compute the approximation. + representative_dataset_gen: A callable that provides a dataset for sampling. + num_hessian_samples: Number of requested samples to compute batch Hessian approximation scores. + last_iter_remain_samples: A list of input samples (for each input layer) with remaining samples from + previous iterations. 
""" - Logger.debug(f"Computing Hessian-trace approximation for a node {trace_hessian_request.target_node}.") + Logger.debug(f"Computing Hessian-trace approximation for nodes {trace_hessian_request.target_nodes}.") - # Sample images for the computation - images = self.representative_dataset() + images, next_iter_remain_samples = representative_dataset_gen(num_hessian_samples=num_hessian_samples, + last_iter_remain_samples=last_iter_remain_samples) # Get the framework-specific calculator for trace Hessian approximation fw_hessian_calculator = self.fw_impl.get_trace_hessian_calculator(graph=self.graph, @@ -119,20 +181,42 @@ def compute(self, trace_hessian_request:TraceHessianRequest): trace_hessian_request=trace_hessian_request, num_iterations_for_approximation=self.num_iterations_for_approximation) - # Compute the approximation trace_hessian = fw_hessian_calculator.compute() # Store the computed approximation in the saved info - if trace_hessian_request in self.trace_hessian_request_to_score_list: - self.trace_hessian_request_to_score_list[trace_hessian_request].append(trace_hessian) - else: - self.trace_hessian_request_to_score_list[trace_hessian_request] = [trace_hessian] - - + topo_sorted_nodes_names = [x.name for x in self.graph.get_topo_sorted_nodes()] + sorted_target_nodes = sorted(trace_hessian_request.target_nodes, + key=lambda x: topo_sorted_nodes_names.index(x.name)) + + for node, hessian in zip(sorted_target_nodes, trace_hessian): + single_node_request = self._construct_single_node_request(trace_hessian_request.mode, + trace_hessian_request.granularity, + node) + + # The hessian for each node is expected to be a tensor where the first axis represents the number of + # images in the batch on which the approximation was computed. + # We collect the results as a list of a result for images, which is combined across batches. 
+ # After conversion, trace_hessian_request_to_score_list for a request of a single node should be a list of + # results of all images, where each result is a tensor of the shape depending on the granularity. + if single_node_request in self.trace_hessian_request_to_score_list: + self.trace_hessian_request_to_score_list[single_node_request] += ( + self._convert_tensor_to_list_of_appx_results(hessian)) + else: + self.trace_hessian_request_to_score_list[single_node_request] = ( + self._convert_tensor_to_list_of_appx_results(hessian)) + + # In case that we are required to return a number of scores that is larger than the computation batch size + # and if in this case the computation batch size is smaller than the representative dataset batch size + # we need to carry over remaining samples from the last fetched batch to the next computation, otherwise, + # we might skip samples or remain without enough samples to complete the computations for the + # requested number of scores. + return next_iter_remain_samples if next_iter_remain_samples is not None and len(next_iter_remain_samples) > 0 \ + and len(next_iter_remain_samples[0]) > 0 else None def fetch_hessian(self, - trace_hessian_request: - TraceHessianRequest, required_size: int) -> List[List[float]]: + trace_hessian_request: TraceHessianRequest, + required_size: int, + batch_size: int = 1) -> List[List[np.ndarray]]: """ Fetches the computed approximations of the trace of the Hessian for the given request and required size. @@ -140,67 +224,152 @@ def fetch_hessian(self, Args: trace_hessian_request: Configuration for which to fetch the approximation. required_size: Number of approximations required. + batch_size: The Hessian computation batch size. Returns: - List[List[float]]: List of computed approximations. + List[List[np.ndarray]]: For each target node, returns a list of computed approximations. The outer list is per image (thus, has the length as required_size). 
The inner list length dependent on the granularity (1 for per-tensor, OC for per-output-channel when the requested node has OC output-channels, etc.) """ - if required_size==0: - return [] + if required_size == 0: + return [[] for _ in trace_hessian_request.target_nodes] - Logger.info(f"\nEnsuring {required_size} Hessian-trace approximation for node {trace_hessian_request.target_node}.") + Logger.info(f"\nEnsuring {required_size} Hessian-trace approximation for nodes " + f"{trace_hessian_request.target_nodes}.") - # Replace request of a reused target node with a request of the 'reuse group'. - if trace_hessian_request.target_node.reuse_group: - trace_hessian_request = self._get_request_of_reuse_group(trace_hessian_request) + # Replace node in reused target nodes with a representing node from the 'reuse group'. + for n in trace_hessian_request.target_nodes: + if n.reuse_group: + rep_node = self._get_representing_of_reuse_group(n) + trace_hessian_request.target_nodes.remove(n) + if rep_node not in trace_hessian_request.target_nodes: + trace_hessian_request.target_nodes.append(rep_node) # Ensure the saved info has the required number of approximations - self._populate_saved_info_to_size(trace_hessian_request, required_size) + self._populate_saved_info_to_size(trace_hessian_request, required_size, batch_size) # Return the saved approximations for the given request - return self.trace_hessian_request_to_score_list[trace_hessian_request] + return self._collect_saved_hessians_for_request(trace_hessian_request, required_size) - def _get_request_of_reuse_group(self, trace_hessian_request: TraceHessianRequest): + def _get_representing_of_reuse_group(self, node) -> Any: """ For each reused group we compute and fetch its members using a single request. This method creates and returns a request for the reused group the node is in. Args: - trace_hessian_request: Request to fetch its node's reused group request. + node: The node to get its reuse group representative node. 
- Returns: - TraceHessianRequest for all nodes in the reused group. + Returns: A reuse group representative node (BaseNode). """ - father_nodes = [n for n in self.graph.nodes if not n.reuse and n.reuse_group==trace_hessian_request.target_node.reuse_group] - if len(father_nodes)!=1: - Logger.critical(f"Expected a single non-reused node in the reused group, but found {len(father_nodes)}.") - reused_group_request = TraceHessianRequest(target_node=father_nodes[0], - granularity=trace_hessian_request.granularity, - mode=trace_hessian_request.mode) - return reused_group_request + father_nodes = [n for n in self.graph.nodes if not n.reuse and n.reuse_group == node.reuse_group] + if len(father_nodes) != 1: # pragma: no cover + Logger.critical(f"Expected a single non-reused node in the reused group, " + f"but found {len(father_nodes)}.") + return father_nodes[0] def _populate_saved_info_to_size(self, trace_hessian_request: TraceHessianRequest, - required_size: int): + required_size: int, + batch_size: int = 1): """ Ensures that the saved info has the required size of trace Hessian approximations for the given request. Args: trace_hessian_request: Configuration for which to ensure the saved info size. required_size: Required number of trace Hessian approximations. + batch_size: The Hessian computation batch size. """ - # Get the current number of saved approximations for the request + + # Get the current number of saved approximations for each node in the request current_existing_hessians = self.count_saved_info_of_request(trace_hessian_request) + # Compute the required number of approximations to meet the required size. + # Since we allow batch and multi-nodes computation, we take the node with the maximal number of missing + # approximations to compute, and run batch computations until meeting the requirement. 
+ min_exist_hessians = min(current_existing_hessians.values()) + max_remaining_hessians = required_size - min_exist_hessians + Logger.info( - f"Found {current_existing_hessians} Hessian-trace approximations for node {trace_hessian_request.target_node}." - f" {required_size - current_existing_hessians} approximations left to compute...") + f"Running Hessian approximation computation for {len(trace_hessian_request.target_nodes)} nodes.\n " + f"The node with minimal existing Hessian-trace approximations has {min_exist_hessians} " + f"approximations computed.\n" + f"{max_remaining_hessians} approximations left to compute...") + + hessian_representative_dataset = partial(self._sample_batch_representative_dataset, + num_inputs=len(self.graph.input_nodes), + representative_dataset=self.representative_dataset_gen()) + + next_iter_remaining_samples = None + while max_remaining_hessians > 0: + # If batch_size < max_remaining_hessians then we run each computation on a batch_size of images. + # This way, we always run a computation for a single batch. + size_to_compute = min(max_remaining_hessians, batch_size) + next_iter_remaining_samples = ( + self.compute(trace_hessian_request, hessian_representative_dataset, size_to_compute, + last_iter_remain_samples=next_iter_remaining_samples)) + max_remaining_hessians -= size_to_compute + + def _collect_saved_hessians_for_request(self, trace_hessian_request: TraceHessianRequest, required_size: int + ) -> List[List[np.ndarray]]: + """ + Collects Hessian approximation for the nodes in the given request. + + Args: + trace_hessian_request: Configuration for which to fetch the approximation. + required_size: Required number of trace Hessian approximations. + + Returns: A list with List of computed Hessian approximation (a tensor for each score) for each node + in the request. 
+ + """ + collected_results = [] + for node in trace_hessian_request.target_nodes: + single_node_request = self._construct_single_node_request(trace_hessian_request.mode, + trace_hessian_request.granularity, + node) - # Compute the required number of approximations to meet the required size - for _ in range(required_size - current_existing_hessians): - self.compute(trace_hessian_request) + res_for_node = self.trace_hessian_request_to_score_list.get(single_node_request) + if res_for_node is None: # pragma: no cover + Logger.critical(f"Couldn't find saved Hessian approximations for node {node.name}.") + if len(res_for_node) < required_size: # pragma: no cover + Logger.critical(f"Missing Hessian approximations for node {node.name}, requested {required_size} " + f"but found only {len(res_for_node)}.") + res_for_node = res_for_node[:required_size] + collected_results.append(res_for_node) + + return collected_results + + @staticmethod + def _construct_single_node_request(mode: HessianMode, granularity: HessianInfoGranularity, target_nodes: List + ) -> TraceHessianRequest: + """ + Constructs a Hessian request for a single node. Used for retrieving and maintaining cached results. + + Args: + mode (HessianMode): Mode of Hessian's trace approximation (w.r.t weights or activations). + granularity (HessianInfoGranularity): Granularity level for the approximation. + target_nodes (List[BaseNode]): The node in the float graph for which the Hessian's trace approximation is targeted. + + Returns: A TraceHessianRequest with the given details for the requested node. + + """ + return TraceHessianRequest(mode, + granularity, + target_nodes=[target_nodes]) + + @staticmethod + def _convert_tensor_to_list_of_appx_results(t: Any) -> List: + """ + Converts a tensor with batch computation results to a list of individual results for each sample in batch. + + Args: + t: A tensor with Hessian approximation results. + + Returns: A list with split batch into individual results. 
+ + """ + return [t[i] for i in range(t.shape[0])] diff --git a/model_compression_toolkit/core/common/hessian/hessian_info_utils.py b/model_compression_toolkit/core/common/hessian/hessian_info_utils.py index e8a75f094..4d0e5d3c8 100644 --- a/model_compression_toolkit/core/common/hessian/hessian_info_utils.py +++ b/model_compression_toolkit/core/common/hessian/hessian_info_utils.py @@ -17,18 +17,19 @@ from model_compression_toolkit.constants import EPS -def normalize_scores(hessian_approximations: List) -> np.ndarray: +def normalize_scores(hessian_approximations: List) -> List[np.ndarray]: """ Normalize Hessian information approximations by dividing the trace Hessian approximations value by the sum of all other values. Args: - hessian_approximations: Approximated average Hessian-based scores for each interest point. + hessian_approximations: Approximated Hessian-based scores for each image for each interest point. Returns: Normalized list of Hessian info approximations for each interest point. 
""" - scores_vec = np.asarray(hessian_approximations) + scores_vec = np.asarray(hessian_approximations) # Images x Nodes X Scores + norm_scores_per_image = scores_vec / (np.sum(scores_vec, axis=1, keepdims=True) + EPS) - return scores_vec / (np.sum(scores_vec) + EPS) + return [norm_scores_per_image[i] for i in range(norm_scores_per_image.shape[0])] diff --git a/model_compression_toolkit/core/common/hessian/trace_hessian_calculator.py b/model_compression_toolkit/core/common/hessian/trace_hessian_calculator.py index 539f4b14e..f98f805f4 100644 --- a/model_compression_toolkit/core/common/hessian/trace_hessian_calculator.py +++ b/model_compression_toolkit/core/common/hessian/trace_hessian_calculator.py @@ -56,14 +56,9 @@ def __init__(self, self.num_iterations_for_approximation = num_iterations_for_approximation # Validate representative dataset has same inputs as graph - if len(self.input_images)!=len(graph.get_inputs()): + if len(self.input_images) != len(graph.get_inputs()): # pragma: no cover Logger.critical(f"The graph requires {len(graph.get_inputs())} inputs, but the provided representative dataset contains {len(self.input_images)} inputs.") - # Assert all inputs have a batch size of 1 - for image in self.input_images: - if image.shape[0]!=1: - Logger.critical(f"Hessian calculations are restricted to a single-image per input. 
Found input with shape: {image.shape}.") - self.fw_impl = fw_impl self.hessian_request = trace_hessian_request diff --git a/model_compression_toolkit/core/common/hessian/trace_hessian_request.py b/model_compression_toolkit/core/common/hessian/trace_hessian_request.py index 8db818aab..d5b416d19 100644 --- a/model_compression_toolkit/core/common/hessian/trace_hessian_request.py +++ b/model_compression_toolkit/core/common/hessian/trace_hessian_request.py @@ -52,18 +52,18 @@ class TraceHessianRequest: def __init__(self, mode: HessianMode, granularity: HessianInfoGranularity, - target_node, + target_nodes: List, ): """ Attributes: mode (HessianMode): Mode of Hessian's trace approximation (w.r.t weights or activations). granularity (HessianInfoGranularity): Granularity level for the approximation. - target_node (BaseNode): The node in the float graph for which the Hessian's trace approximation is targeted. + target_nodes (List[BaseNode]): The node in the float graph for which the Hessian's trace approximation is targeted. """ self.mode = mode # w.r.t activations or weights self.granularity = granularity # per element, per layer, per channel - self.target_node = target_node # TODO: extend it list of nodes + self.target_nodes = target_nodes def __eq__(self, other): # Checks if the other object is an instance of TraceHessianRequest @@ -71,9 +71,9 @@ def __eq__(self, other): return isinstance(other, TraceHessianRequest) and \ self.mode == other.mode and \ self.granularity == other.granularity and \ - self.target_node == other.target_node + self.target_nodes == other.target_nodes def __hash__(self): # Computes the hash based on the attributes. # The use of a tuple here ensures that the hash is influenced by all the attributes. 
- return hash((self.mode, self.granularity, self.target_node)) \ No newline at end of file + return hash((self.mode, self.granularity, tuple(self.target_nodes))) \ No newline at end of file diff --git a/model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py b/model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py index 6b47b94ec..b07da0850 100644 --- a/model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +++ b/model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py @@ -15,6 +15,7 @@ from typing import List, Callable +from model_compression_toolkit.constants import MP_DEFAULT_NUM_SAMPLES, ACT_HESSIAN_DEFAULT_BATCH_SIZE from model_compression_toolkit.core.common.mixed_precision.distance_weighting import MpDistanceWeighting @@ -23,13 +24,14 @@ class MixedPrecisionQuantizationConfig: def __init__(self, compute_distance_fn: Callable = None, distance_weighting_method: MpDistanceWeighting = MpDistanceWeighting.AVG, - num_of_images: int = 32, + num_of_images: int = MP_DEFAULT_NUM_SAMPLES, configuration_overwrite: List[int] = None, num_interest_points_factor: float = 1.0, use_hessian_based_scores: bool = False, norm_scores: bool = True, refine_mp_solution: bool = True, - metric_normalization_threshold: float = 1e10): + metric_normalization_threshold: float = 1e10, + hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE): """ Class with mixed precision parameters to quantize the input model. @@ -43,6 +45,7 @@ def __init__(self, norm_scores (bool): Whether to normalize the returned scores for the weighted distance metric (to get values between 0 and 1). refine_mp_solution (bool): Whether to try to improve the final mixed-precision configuration using a greedy algorithm that searches layers to increase their bit-width, or not. 
metric_normalization_threshold (float): A threshold for checking the mixed precision distance metric values, In case of values larger than this threshold, the metric will be scaled to prevent numerical issues. + hessian_batch_size (int): The Hessian computation batch size. used only if using mixed precision with Hessian-based objective. """ @@ -60,6 +63,7 @@ def __init__(self, self.use_hessian_based_scores = use_hessian_based_scores self.norm_scores = norm_scores + self.hessian_batch_size = hessian_batch_size self.metric_normalization_threshold = metric_normalization_threshold diff --git a/model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py b/model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py index 9ce127f07..c7dde44a4 100644 --- a/model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +++ b/model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py @@ -26,7 +26,6 @@ from model_compression_toolkit.logger import Logger from model_compression_toolkit.core.common.hessian import TraceHessianRequest, HessianMode, \ HessianInfoGranularity, HessianInfoService -from model_compression_toolkit.core.common.hessian import hessian_info_utils as hessian_utils class SensitivityEvaluation: @@ -238,47 +237,24 @@ def _compute_hessian_based_scores(self) -> np.ndarray: to be used for the distance metric weighted average computation. 
""" - # Dictionary to store the trace Hessian approximations for each interest point (target node) - compare_point_to_trace_hessian_approximations = {} - - # Iterate over each interest point to fetch the trace Hessian approximations - for target_node in self.interest_points: - # Create a request for trace Hessian approximation with specific configurations - # (here we use per-tensor approximation of the Hessian's trace w.r.t the node's activations) - trace_hessian_request = TraceHessianRequest(mode=HessianMode.ACTIVATION, - granularity=HessianInfoGranularity.PER_TENSOR, - target_node=target_node) - - # Fetch the trace Hessian approximations for the current interest point - node_approximations = self.hessian_info_service.fetch_hessian(trace_hessian_request=trace_hessian_request, - required_size=self.quant_config.num_of_images) - # Store the fetched approximations in the dictionary - compare_point_to_trace_hessian_approximations[target_node] = node_approximations - - # List to store the approximations for each image - approx_by_image = [] - # Iterate over each image - for image_idx in range(self.quant_config.num_of_images): - # List to store approximations for the current image for each interest point - approx_by_image_per_interest_point = [] - # Iterate over each interest point to gather approximations - for target_node in self.interest_points: - # Ensure the approximation for the current interest point and image is a list - assert isinstance(compare_point_to_trace_hessian_approximations[target_node][image_idx], list) - # Ensure the approximation list contains only one element (since, granularity is per-tensor) - assert len(compare_point_to_trace_hessian_approximations[target_node][image_idx]) == 1 - # Append the single approximation value to the list for the current image - approx_by_image_per_interest_point.append(compare_point_to_trace_hessian_approximations[target_node][image_idx][0]) - - if self.quant_config.norm_scores: - approx_by_image_per_interest_point = \ 
- hessian_utils.normalize_scores(hessian_approximations=approx_by_image_per_interest_point) - - # Append the approximations for the current image to the main list - approx_by_image.append(approx_by_image_per_interest_point) + # Create a request for trace Hessian approximation with specific configurations + # (here we use per-tensor approximation of the Hessian's trace w.r.t the node's activations) + trace_hessian_request = TraceHessianRequest(mode=HessianMode.ACTIVATION, + granularity=HessianInfoGranularity.PER_TENSOR, + target_nodes=self.interest_points) + + # Fetch the trace Hessian approximations for the current interest point + nodes_approximations = self.hessian_info_service.fetch_hessian(trace_hessian_request=trace_hessian_request, + required_size=self.quant_config.num_of_images, + batch_size=self.quant_config.hessian_batch_size) + + # Store the approximations for each node for each image + approx_by_image = [[nodes_approximations[j][image_idx] + for j, _ in enumerate(self.interest_points)] + for image_idx in range(self.quant_config.num_of_images)] # Return the mean approximation value across all images for each interest point - return np.mean(approx_by_image, axis=0) + return np.mean(np.stack(approx_by_image), axis=0) def _configure_bitwidths_model(self, mp_model_configuration: List[int], diff --git a/model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py b/model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py index 09271f46d..483ce842c 100644 --- a/model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +++ b/model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py @@ -121,7 +121,7 @@ def _get_entry_node_to_score(self, entry_nodes: List[BaseNode]) -> Dict[BaseNode # Initialize HessianInfoService for score computation. 
hessian_info_service = HessianInfoService(graph=self.float_graph, - representative_dataset=self.representative_data_gen, + representative_dataset_gen=self.representative_data_gen, fw_impl=self.fw_impl) # Fetch and process Hessian scores for output channels of entry nodes. @@ -129,13 +129,13 @@ def _get_entry_node_to_score(self, entry_nodes: List[BaseNode]) -> Dict[BaseNode for node in entry_nodes: _request = TraceHessianRequest(mode=HessianMode.WEIGHTS, granularity=HessianInfoGranularity.PER_OUTPUT_CHANNEL, - target_node=node) + target_nodes=[node]) _scores_for_node = hessian_info_service.fetch_hessian(_request, required_size=self.pruning_config.num_score_approximations) nodes_scores.append(_scores_for_node) # Average and map scores to nodes. - self._entry_node_to_hessian_score = {node: np.mean(scores, axis=0) for node, scores in zip(entry_nodes, nodes_scores)} + self._entry_node_to_hessian_score = {node: np.mean(scores[0], axis=0) for node, scores in zip(entry_nodes, nodes_scores)} self._entry_node_count_oc_nparams = self._count_oc_nparams(entry_nodes=entry_nodes) _entry_node_l2_oc_norm = self._get_squaredl2norm(entry_nodes=entry_nodes) diff --git a/model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py b/model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py index 59fe72c8b..1f1a6b82b 100644 --- a/model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +++ b/model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py @@ -390,7 +390,7 @@ def _compute_hessian_for_hmse(node, """ _request = TraceHessianRequest(mode=HessianMode.WEIGHTS, granularity=HessianInfoGranularity.PER_ELEMENT, - target_node=node) + target_nodes=[node]) _scores_for_node = hessian_info_service.fetch_hessian(_request, required_size=num_hessian_samples) diff --git 
a/model_compression_toolkit/core/keras/hessian/activation_trace_hessian_calculator_keras.py b/model_compression_toolkit/core/keras/hessian/activation_trace_hessian_calculator_keras.py index d8e0a6b94..314cf21f8 100644 --- a/model_compression_toolkit/core/keras/hessian/activation_trace_hessian_calculator_keras.py +++ b/model_compression_toolkit/core/keras/hessian/activation_trace_hessian_calculator_keras.py @@ -53,22 +53,22 @@ def __init__(self, trace_hessian_request=trace_hessian_request, num_iterations_for_approximation=num_iterations_for_approximation) - def compute(self) -> List[float]: + def compute(self) -> List[np.ndarray]: """ - Compute the approximation of the trace of the Hessian w.r.t a node's activations. + Compute the approximation of the trace of the Hessian w.r.t the requested target nodes' activations. Returns: - List[float]: Approximated trace of the Hessian for an interest point. + List[np.ndarray]: Approximated trace of the Hessian for the requested nodes. """ if self.hessian_request.granularity == HessianInfoGranularity.PER_TENSOR: model_output_nodes = [ot.node for ot in self.graph.get_outputs()] - if self.hessian_request.target_node in model_output_nodes: + if len([n for n in self.hessian_request.target_nodes if n in model_output_nodes]) > 0: Logger.critical("Trying to compute activation Hessian approximation with respect to the model output. " - "This operation is not supported. " - "Remove the output node from the set of node targets in the Hessian request.") + "This operation is not supported. 
" + "Remove the output node from the set of node targets in the Hessian request.") - grad_model_outputs = [self.hessian_request.target_node] + model_output_nodes + grad_model_outputs = self.hessian_request.target_nodes + model_output_nodes # Building a model to run Hessian approximation on model, _ = FloatKerasModelBuilder(graph=self.graph, append2output=grad_model_outputs).build_model() @@ -82,88 +82,74 @@ def compute(self) -> List[float]: else: outputs = model(*self.input_images) - if len(outputs) != len(grad_model_outputs): + if len(outputs) != len(grad_model_outputs): # pragma: no cover Logger.critical( f"Model for computing activation Hessian approximation expects {len(grad_model_outputs)} " f"outputs, but got {len(outputs)} output tensors.") - # Extracting the intermediate activation tensors and the model real output - # TODO: we assume that the hessian request is for a single node. - # When we extend it to multiple nodes in the same request, then we should modify this part to take - # the first "num_target_nodes" outputs from the output list. - # We also assume that the target nodes are not part of the model output nodes, if this assumption changed, - # then the code should be modified accordingly. - target_activation_tensors = [outputs[0]] - output_tensors = outputs[1:] + # Extracting the intermediate activation tensors and the model real output. + # Note that we do not allow computing Hessian for output nodes, so there shouldn't be an overlap. 
+ num_target_nodes = len(self.hessian_request.target_nodes) + # Extract activation tensors of nodes for which we want to compute Hessian + target_activation_tensors = outputs[:num_target_nodes] + # Extract the model outputs + output_tensors = outputs[num_target_nodes:] # Unfold and concatenate all outputs to form a single tensor output = self._concat_tensors(output_tensors) # List to store the approximated trace of the Hessian for each interest point - trace_approx_by_node = [] - # Loop through each interest point activation tensor - for ipt in tqdm(target_activation_tensors): # Per Interest point activation tensor - interest_point_scores = [] # List to store scores for each interest point - for j in range(self.num_iterations_for_approximation): # Approximation iterations - # Getting a random vector with normal distribution - v = tf.random.normal(shape=output.shape, dtype=output.dtype) - f_v = tf.reduce_sum(v * output) + ipts_hessian_trace_approx = [tf.Variable([0.0], dtype=tf.float32, trainable=True) + for _ in range(len(target_activation_tensors))] + # Loop through each interest point activation tensor + prev_mean_results = None + for j in tqdm(range(self.num_iterations_for_approximation)): # Approximation iterations + # Getting a random vector with normal distribution + v = tf.random.normal(shape=output.shape, dtype=output.dtype) + f_v = tf.reduce_sum(v * output) + for i, ipt in enumerate(target_activation_tensors): # Per Interest point activation tensor + interest_point_scores = [] # List to store scores for each interest point with g.stop_recording(): # Computing the approximation by getting the gradient of (output * v) - gradients = g.gradient(f_v, ipt, unconnected_gradients=tf.UnconnectedGradients.ZERO) - # If a node has multiple outputs, gradients is a list of tensors. If it has only a single - # output gradients is a tensor. To handle both cases, we first convert gradients to a - # list if it's a single tensor. 
- if not isinstance(gradients, list): - gradients = [gradients] - - # Compute the approximation per node's output - score_approx_per_output = [] - for grad in gradients: - score_approx_per_output.append(tf.reduce_sum(tf.pow(grad, 2.0))) + hess_v = g.gradient(f_v, ipt) + + if hess_v is None: + # In case we have an output node, which is an interest point, but it is not + # differentiable, we consider its Hessian to be the initial value 0. + continue # pragma: no cover + + # Mean over all dims but the batch (CXHXW for conv) + hessian_trace_approx = tf.reduce_sum(hess_v ** 2.0, + axis=tuple(d for d in range(1, len(hess_v.shape)))) # Free gradients - del grad - del gradients - - # If the change to the mean approximation is insignificant (to all outputs) - # we stop the calculation. - if j > MIN_HESSIAN_ITER: - new_mean_per_output = [] - delta_per_output = [] - # Compute new means and deltas for each output index - for output_idx, score_approx in enumerate(score_approx_per_output): - prev_scores_output = [x[output_idx] for x in interest_point_scores] - new_mean = np.mean([score_approx, *prev_scores_output]) - delta = new_mean - np.mean(prev_scores_output) - new_mean_per_output.append(new_mean) - delta_per_output.append(delta) - - # Check if all outputs have converged - is_converged = all([np.abs(delta) / (np.abs(new_mean) + 1e-6) < HESSIAN_COMP_TOLERANCE for delta, new_mean in zip(delta_per_output, new_mean_per_output)]) - if is_converged: - interest_point_scores.append(score_approx_per_output) - break - - interest_point_scores.append(score_approx_per_output) - - final_approx_per_output = [] - # Compute the final approximation for each output index - num_node_outputs = len(interest_point_scores[0]) - for output_idx in range(num_node_outputs): - final_approx_per_output.append(tf.reduce_mean([x[output_idx] for x in interest_point_scores])) - - # final_approx_per_output is a list of all approximations (one per output), thus we average them to - # get the final score of a 
node. - trace_approx_by_node.append(tf.reduce_mean(final_approx_per_output)) # Get averaged squared trace approximation - - trace_approx_by_node = tf.reduce_mean([trace_approx_by_node], axis=0) # Just to get one tensor instead of list of tensors with single element - - # Free gradient tape - del g - - return trace_approx_by_node.numpy().tolist() - - else: - Logger.critical(f"{self.hessian_request.granularity} is not supported for Keras activation hessian\'s trace approximation calculator.") + del hess_v + + # Update node Hessian approximation mean over random iterations + ipts_hessian_trace_approx[i] = (j * ipts_hessian_trace_approx[i] + hessian_trace_approx) / (j + 1) + + # If the change to the mean approximation is insignificant (to all outputs) + # we stop the calculation. + if j > MIN_HESSIAN_ITER: + if prev_mean_results is not None: + new_mean_res = tf.reduce_mean(tf.stack(ipts_hessian_trace_approx), axis=1) + relative_delta_per_node = (tf.abs(new_mean_res - prev_mean_results) / + (tf.abs(new_mean_res) + 1e-6)) + max_delta = tf.reduce_max(relative_delta_per_node) + if max_delta < HESSIAN_COMP_TOLERANCE: + break + prev_mean_results = tf.reduce_mean(tf.stack(ipts_hessian_trace_approx), axis=1) + + # Convert results to list of numpy arrays + hessian_results = [h.numpy() for h in ipts_hessian_trace_approx] + # Extend the Hessian tensors shape to align with expected return type + # TODO: currently, only per-tensor Hessian is available for activation. + # Once implementing per-channel or per-element, this alignment needs to be verified and handled separately. 
+ hessian_results = [h[..., np.newaxis] for h in hessian_results] + + return hessian_results + + else: # pragma: no cover + Logger.critical(f"{self.hessian_request.granularity} " + f"is not supported for Keras activation hessian\'s trace approximation calculator.") diff --git a/model_compression_toolkit/core/keras/hessian/weights_trace_hessian_calculator_keras.py b/model_compression_toolkit/core/keras/hessian/weights_trace_hessian_calculator_keras.py index 5fddbb093..77aba36e7 100644 --- a/model_compression_toolkit/core/keras/hessian/weights_trace_hessian_calculator_keras.py +++ b/model_compression_toolkit/core/keras/hessian/weights_trace_hessian_calculator_keras.py @@ -46,13 +46,19 @@ def __init__(self, trace_hessian_request: Configuration request for which to compute the trace Hessian approximation. num_iterations_for_approximation: Number of iterations to use when approximating the Hessian trace. """ + + if len(trace_hessian_request.target_nodes) > 1: # pragma: no cover + Logger.critical(f"Weights Hessian approximation is currently supported only for a single target node," + f" but the provided request contains the following target nodes: " + f"{trace_hessian_request.target_nodes}.") + super(WeightsTraceHessianCalculatorKeras, self).__init__(graph=graph, input_images=input_images, fw_impl=fw_impl, trace_hessian_request=trace_hessian_request, num_iterations_for_approximation=num_iterations_for_approximation) - def compute(self) -> np.ndarray: + def compute(self) -> List[np.ndarray]: """ Compute the Hessian-based scores w.r.t target node's weights. Currently, supported nodes are [Conv2D, Dense, Conv2DTranspose, DepthwiseConv2D]. @@ -63,28 +69,32 @@ def compute(self) -> np.ndarray: for HessianInfoGranularity.PER_OUTPUT_CHANNEL the shape will be (2,) and for HessianInfoGranularity.PER_ELEMENT a shape of (3, 3, 3, 2). - Returns: The computed scores as numpy ndarray. + Returns: The computed scores as a list of numpy arrays. 
+ The function returns a list for compatibility reasons. """ - # Check if the target node's layer type is supported - if not DEFAULT_KERAS_INFO.is_kernel_op(self.hessian_request.target_node.type): - Logger.critical( - f"{self.hessian_request.target_node.type} is not supported for Hessian-based scoring with respect to weights.") + # Check if the target node's layer type is supported. + # We assume that weights Hessian computation is done only for a single node at each request. + target_node = self.hessian_request.target_nodes[0] + if not DEFAULT_KERAS_INFO.is_kernel_op(target_node.type): + Logger.critical(f"Hessian information with respect to weights is not supported for " + f"{target_node.type} layers.") # pragma: no cover # Construct the Keras float model for inference model, _ = FloatKerasModelBuilder(graph=self.graph).build_model() # Get the weight attributes for the target node type - weight_attributes = DEFAULT_KERAS_INFO.get_kernel_op_attributes(self.hessian_request.target_node.type) + weight_attributes = DEFAULT_KERAS_INFO.get_kernel_op_attributes(target_node.type) # Get the weight tensor for the target node - if len(weight_attributes) != 1: - Logger.critical(f"Hessian-based scoring with respect to weights is currently supported only for nodes with a single weight attribute. Found {len(weight_attributes)} attributes.") + if len(weight_attributes) != 1: # pragma: no cover + Logger.critical(f"Hessian-based scoring with respect to weights is currently supported only for nodes with " + f"a single weight attribute. 
Found {len(weight_attributes)} attributes.") - weight_tensor = getattr(model.get_layer(self.hessian_request.target_node.name), weight_attributes[0]) + weight_tensor = getattr(model.get_layer(target_node.name), weight_attributes[0]) # Get the output channel index (needed for HessianInfoGranularity.PER_OUTPUT_CHANNEL case) - output_channel_axis, _ = DEFAULT_KERAS_INFO.kernel_channels_mapping.get(self.hessian_request.target_node.type) + output_channel_axis, _ = DEFAULT_KERAS_INFO.kernel_channels_mapping.get(target_node.type) # Get number of scores that should be calculated by the granularity. num_of_scores = self._get_num_scores_by_granularity(weight_tensor, @@ -138,13 +148,17 @@ def compute(self) -> np.ndarray: del tape if self.hessian_request.granularity == HessianInfoGranularity.PER_TENSOR: - if final_approx.shape != (1,): + if final_approx.shape != (1,): # pragma: no cover Logger.critical(f"For HessianInfoGranularity.PER_TENSOR, the expected score shape is (1,), but found {final_approx.shape}.") elif self.hessian_request.granularity == HessianInfoGranularity.PER_ELEMENT: # Reshaping the scores to the original weight shape final_approx = tf.reshape(final_approx, weight_tensor.shape) - return final_approx.numpy() + # Add a batch axis to the Hessian approximation tensor (to align with the expected returned shape) + # We assume per-image computation, so the batch axis size is 1. + final_approx = final_approx[np.newaxis, ...] 
+ + return [final_approx.numpy()] def _reshape_gradients(self, gradients: tf.Tensor, @@ -193,5 +207,5 @@ def _get_num_scores_by_granularity(self, return weight_tensor.shape[output_channel_axis] elif self.hessian_request.granularity == HessianInfoGranularity.PER_ELEMENT: return tf.size(weight_tensor).numpy() - else: + else: # pragma: no cover Logger.critical(f"Unexpected granularity encountered: {self.hessian_request.granularity}.") diff --git a/model_compression_toolkit/core/pytorch/hessian/activation_trace_hessian_calculator_pytorch.py b/model_compression_toolkit/core/pytorch/hessian/activation_trace_hessian_calculator_pytorch.py index 88457c81c..4695fb438 100644 --- a/model_compression_toolkit/core/pytorch/hessian/activation_trace_hessian_calculator_pytorch.py +++ b/model_compression_toolkit/core/pytorch/hessian/activation_trace_hessian_calculator_pytorch.py @@ -17,6 +17,7 @@ from torch import autograd from tqdm import tqdm +import numpy as np from model_compression_toolkit.constants import MIN_HESSIAN_ITER, HESSIAN_COMP_TOLERANCE, HESSIAN_NUM_ITERATIONS from model_compression_toolkit.core.common import Graph @@ -28,6 +29,7 @@ from model_compression_toolkit.logger import Logger import torch + class ActivationTraceHessianCalculatorPytorch(TraceHessianCalculatorPytorch): """ Pytorch implementation of the Trace Hessian approximation Calculator for activations. @@ -53,20 +55,22 @@ def __init__(self, trace_hessian_request=trace_hessian_request, num_iterations_for_approximation=num_iterations_for_approximation) - def compute(self) -> List[float]: + def compute(self) -> List[np.ndarray]: """ - Compute the approximation of the trace of the Hessian w.r.t a node's activations. + Compute the approximation of the trace of the Hessian w.r.t the requested target nodes' activations. Returns: - List[float]: Approximated trace of the Hessian for an interest point. + List[np.ndarray]: Approximated trace of the Hessian for the requested nodes. 
""" if self.hessian_request.granularity == HessianInfoGranularity.PER_TENSOR: model_output_nodes = [ot.node for ot in self.graph.get_outputs()] - if self.hessian_request.target_node in model_output_nodes: - Logger.critical("Activation Hessian approximation cannot be computed for model outputs. Exclude output nodes from Hessian request targets.") - grad_model_outputs = [self.hessian_request.target_node] + model_output_nodes + if len([n for n in self.hessian_request.target_nodes if n in model_output_nodes]) > 0: + Logger.critical("Activation Hessian approximation cannot be computed for model outputs. " + "Exclude output nodes from Hessian request targets.") + + grad_model_outputs = self.hessian_request.target_nodes + model_output_nodes model, _ = FloatPyTorchModelBuilder(graph=self.graph, append2output=grad_model_outputs).build_model() model.eval() @@ -78,69 +82,71 @@ def compute(self) -> List[float]: outputs = model(*self.input_images) - if len(outputs) != len(grad_model_outputs): - Logger.critical(f"Mismatch in expected and actual model outputs for activation Hessian approximation. Expected {len(grad_model_outputs)} outputs, received {len(outputs)}.") - - # Extracting the intermediate activation tensors and the model real output - # TODO: we are assuming that the hessian request is for a single node. - # When we extend it to multiple nodes in the same request, then we should modify this part to take - # the first "num_target_nodes" outputs from the output list. - # We also assume that the target nodes are not part of the model output nodes, if this assumption changed, - # then the code should be modified accordingly. - target_activation_tensors = [outputs[0]] - output_tensors = outputs[1:] + if len(outputs) != len(grad_model_outputs): # pragma: no cover + Logger.critical(f"Mismatch in expected and actual model outputs for activation Hessian approximation. 
" + f"Expected {len(grad_model_outputs)} outputs, received {len(outputs)}.") + + # Extracting the intermediate activation tensors and the model real output. + # Note that we do not allow computing Hessian for output nodes, so there shouldn't be an overlap. + num_target_nodes = len(self.hessian_request.target_nodes) + # Extract activation tensors of nodes for which we want to compute Hessian + target_activation_tensors = outputs[:num_target_nodes] + # Extract the model outputs + output_tensors = outputs[num_target_nodes:] device = output_tensors[0].device # Concat outputs # First, we need to unfold all outputs that are given as list, to extract the actual output tensors output = self.concat_tensors(output_tensors) - ipts_hessian_trace_approx = [] - for ipt_tensor in tqdm(target_activation_tensors): # Per Interest point activation tensor - trace_hv = [] - for j in range(self.num_iterations_for_approximation): # Approximation iterations - # Getting a random vector with normal distribution - v = torch.randn(output.shape, device=device) - f_v = torch.sum(v * output) - + ipts_hessian_trace_approx = [torch.tensor([0.0], + requires_grad=True, + device=device) + for _ in range(len(target_activation_tensors))] + prev_mean_results = None + for j in tqdm(range(self.num_iterations_for_approximation), "Hessian random iterations"): # Approximation iterations + # Getting a random vector with normal distribution + v = torch.randn(output.shape, device=device) + f_v = torch.sum(v * output) + for i, ipt_tensor in enumerate(target_activation_tensors): # Per Interest point activation tensor # Computing the hessian trace approximation by getting the gradient of (output * v) hess_v = autograd.grad(outputs=f_v, - inputs=ipt_tensor, - retain_graph=True, - allow_unused=True)[0] + inputs=ipt_tensor, + retain_graph=True, + allow_unused=True)[0] + if hess_v is None: # In case we have an output node, which is an interest point, but it is not differentiable, - # we still want to set some weight 
for it. For this, we need to add this dummy tensor to the ipt - # Hessian traces list. - trace_hv.append(torch.tensor([0.0], - requires_grad=True, - device=device)) - break - hessian_trace_approx = torch.sum(torch.pow(hess_v, 2.0)) - - # If the change to the mean Hessian approximation is insignificant we stop the calculation - if j > MIN_HESSIAN_ITER: - new_mean = torch.mean(torch.stack([hessian_trace_approx, *trace_hv])) - delta = new_mean - torch.mean(torch.stack(trace_hv)) - if torch.abs(delta) / (torch.abs(new_mean) + 1e-6) < HESSIAN_COMP_TOLERANCE: - trace_hv.append(hessian_trace_approx) + # we consider its Hessian to be the initial value 0. + continue # pragma: no cover + + # Mean over all dims but the batch (CXHXW for conv) + hessian_trace_approx = torch.sum(hess_v ** 2.0, dim=tuple(d for d in range(1, len(hess_v.shape)))) + + # Update node Hessian approximation mean over random iterations + ipts_hessian_trace_approx[i] = (j * ipts_hessian_trace_approx[i] + hessian_trace_approx) / (j + 1) + + # If the change to the maximal mean Hessian approximation is insignificant we stop the calculation + if j > MIN_HESSIAN_ITER: + if prev_mean_results is not None: + new_mean_res = torch.mean(torch.stack(ipts_hessian_trace_approx), dim=1) + relative_delta_per_node = (torch.abs(new_mean_res - prev_mean_results) / + (torch.abs(new_mean_res) + 1e-6)) + max_delta = torch.max(relative_delta_per_node) + if max_delta < HESSIAN_COMP_TOLERANCE: break + prev_mean_results = torch.mean(torch.stack(ipts_hessian_trace_approx), dim=1) - trace_hv.append(hessian_trace_approx) - - ipts_hessian_trace_approx.append(torch.mean(torch.stack(trace_hv))) # Get averaged Hessian trace approximation - - # If a node has multiple outputs, it means that multiple approximations were computed - # (one per output since granularity is per-tensor). In this case we average the approximations. 
- if len(ipts_hessian_trace_approx) > 1: - # Stack tensors and compute the average - ipts_hessian_trace_approx = [torch.stack(ipts_hessian_trace_approx).mean()] - - ipts_hessian_trace_approx = torch_tensor_to_numpy(torch.Tensor( - ipts_hessian_trace_approx)) # Just to get one tensor instead of list of tensors with single element + # Convert results to list of numpy arrays + hessian_results = [torch_tensor_to_numpy(h) for h in ipts_hessian_trace_approx] + # Extend the Hessian tensors shape to align with expected return type + # TODO: currently, only per-tensor Hessian is available for activation. + # Once implementing per-channel or per-element, this alignment needs to be verified and handled separately. + hessian_results = [h[..., np.newaxis] for h in hessian_results] - return ipts_hessian_trace_approx.tolist() + return hessian_results - else: - Logger.critical(f"PyTorch activation Hessian's trace approximation does not support {self.hessian_request.granularity} granularity.") + else: # pragma: no cover + Logger.critical(f"PyTorch activation Hessian's trace approximation does not support " + f"{self.hessian_request.granularity} granularity.") diff --git a/model_compression_toolkit/core/pytorch/hessian/weights_trace_hessian_calculator_pytorch.py b/model_compression_toolkit/core/pytorch/hessian/weights_trace_hessian_calculator_pytorch.py index 82117589a..feaf02577 100644 --- a/model_compression_toolkit/core/pytorch/hessian/weights_trace_hessian_calculator_pytorch.py +++ b/model_compression_toolkit/core/pytorch/hessian/weights_trace_hessian_calculator_pytorch.py @@ -47,14 +47,19 @@ def __init__(self, trace_hessian_request: Configuration request for which to compute the trace Hessian approximation. num_iterations_for_approximation: Number of iterations to use when approximating the Hessian trace. 
""" + + if len(trace_hessian_request.target_nodes) > 1: # pragma: no cover + Logger.critical(f"Weights Hessian approximation is currently supported only for a single target node," + f" but the provided request contains the following target nodes: " + f"{trace_hessian_request.target_nodes}.") + super(WeightsTraceHessianCalculatorPytorch, self).__init__(graph=graph, input_images=input_images, fw_impl=fw_impl, trace_hessian_request=trace_hessian_request, num_iterations_for_approximation=num_iterations_for_approximation) - - def compute(self) -> np.ndarray: + def compute(self) -> List[np.ndarray]: """ Compute the Hessian-based scores w.r.t target node's weights. The computed scores are returned in a numpy array. The shape of the result differs @@ -65,27 +70,32 @@ def compute(self) -> np.ndarray: HessianInfoGranularity.PER_ELEMENT a shape of (2, 3, 3, 3). Returns: - The computed scores as numpy ndarray for target node's weights. + The computed scores as a list of numpy ndarray for target node's weights. + The function returns a list for compatibility reasons. """ - # Check if the target node's layer type is supported - if not DEFAULT_PYTORCH_INFO.is_kernel_op(self.hessian_request.target_node.type): - Logger.critical(f"Hessian information with respect to weights is not supported for {self.hessian_request.target_node.type} layers.") # pragma: no cover + # Check if the target node's layer type is supported. + # We assume that weights Hessian computation is done only for a single node at each request. 
+ target_node = self.hessian_request.target_nodes[0] + if not DEFAULT_PYTORCH_INFO.is_kernel_op(target_node.type): + Logger.critical(f"Hessian information with respect to weights is not supported for " + f"{target_node.type} layers.") # pragma: no cover # Float model model, _ = FloatPyTorchModelBuilder(graph=self.graph).build_model() # Get the weight attributes for the target node type - weights_attributes = DEFAULT_PYTORCH_INFO.get_kernel_op_attributes(self.hessian_request.target_node.type) + weights_attributes = DEFAULT_PYTORCH_INFO.get_kernel_op_attributes(target_node.type) # Get the weight tensor for the target node - if len(weights_attributes) != 1: - Logger.critical(f"Currently, Hessian scores with respect to weights are supported only for nodes with a single weight attribute. {len(weights_attributes)} attributes found.") + if len(weights_attributes) != 1: # pragma: no cover + Logger.critical(f"Currently, Hessian scores with respect to weights are supported only for nodes with a " + f"single weight attribute. 
{len(weights_attributes)} attributes found.") - weights_tensor = getattr(getattr(model,self.hessian_request.target_node.name),weights_attributes[0]) + weights_tensor = getattr(getattr(model, target_node.name), weights_attributes[0]) # Get the output channel index - output_channel_axis, _ = DEFAULT_PYTORCH_INFO.kernel_channels_mapping.get(self.hessian_request.target_node.type) + output_channel_axis, _ = DEFAULT_PYTORCH_INFO.kernel_channels_mapping.get(target_node.type) shape_channel_axis = [i for i in range(len(weights_tensor.shape))] if self.hessian_request.granularity == HessianInfoGranularity.PER_OUTPUT_CHANNEL: shape_channel_axis.remove(output_channel_axis) @@ -128,5 +138,9 @@ def compute(self) -> np.ndarray: if self.hessian_request.granularity == HessianInfoGranularity.PER_TENSOR: final_approx = final_approx.reshape(1) - return final_approx.detach().cpu().numpy() + # Add a batch axis to the Hessian approximation tensor (to align with the expected returned shape). + # We assume per-image computation, so the batch axis size is 1. + final_approx = final_approx[np.newaxis, ...] 
+ + return [final_approx.detach().cpu().numpy()] diff --git a/model_compression_toolkit/core/runner.py b/model_compression_toolkit/core/runner.py index 780ee40e9..99a4303ca 100644 --- a/model_compression_toolkit/core/runner.py +++ b/model_compression_toolkit/core/runner.py @@ -112,8 +112,7 @@ def core_runner(in_model: Any, mixed_precision_enable=core_config.mixed_precision_enable, running_gptq=running_gptq) - hessian_info_service = HessianInfoService(graph=graph, - representative_dataset=representative_data_gen, + hessian_info_service = HessianInfoService(graph=graph, representative_dataset_gen=representative_data_gen, fw_impl=fw_impl) tg = quantization_preparation_runner(graph=graph, diff --git a/model_compression_toolkit/gptq/common/gptq_config.py b/model_compression_toolkit/gptq/common/gptq_config.py index 071201f4e..b15eb8eb1 100644 --- a/model_compression_toolkit/gptq/common/gptq_config.py +++ b/model_compression_toolkit/gptq/common/gptq_config.py @@ -14,6 +14,8 @@ # ============================================================================== from enum import Enum from typing import Callable, Any, Dict + +from model_compression_toolkit.constants import GPTQ_HESSIAN_NUM_SAMPLES, ACT_HESSIAN_DEFAULT_BATCH_SIZE from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT @@ -36,10 +38,11 @@ class GPTQHessianScoresConfig: """ def __init__(self, - hessians_num_samples: int = 16, + hessians_num_samples: int = GPTQ_HESSIAN_NUM_SAMPLES, norm_scores: bool = True, log_norm: bool = True, - scale_log_norm: bool = False): + scale_log_norm: bool = False, + hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE): """ Initialize a GPTQHessianWeightsConfig. @@ -49,12 +52,14 @@ def __init__(self, norm_scores (bool): Whether to normalize the returned scores of the weighted loss function (to get values between 0 and 1). log_norm (bool): Whether to use log normalization for the GPTQ Hessian-based scores. 
scale_log_norm (bool): Whether to scale the final vector of the Hessian-based scores. + hessian_batch_size (int): The Hessian computation batch size. used only if using GPTQ with Hessian-based objective. """ self.hessians_num_samples = hessians_num_samples self.norm_scores = norm_scores self.log_norm = log_norm self.scale_log_norm = scale_log_norm + self.hessian_batch_size = hessian_batch_size class GradientPTQConfig: diff --git a/model_compression_toolkit/gptq/common/gptq_training.py b/model_compression_toolkit/gptq/common/gptq_training.py index 78ddeaee1..a172c8569 100644 --- a/model_compression_toolkit/gptq/common/gptq_training.py +++ b/model_compression_toolkit/gptq/common/gptq_training.py @@ -17,6 +17,7 @@ import numpy as np from typing import Callable, List, Any, Dict +from model_compression_toolkit.constants import ACT_HESSIAN_DEFAULT_BATCH_SIZE from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig from model_compression_toolkit.core.common import Graph, BaseNode from model_compression_toolkit.core.common.framework_info import FrameworkInfo @@ -144,15 +145,19 @@ def compute_hessian_based_weights(self) -> np.ndarray: # Fetch hessian approximations for each target node compare_point_to_trace_hessian_approximations = self._fetch_hessian_approximations() # Process the fetched hessian approximations to gather them per images - trace_hessian_approx_by_image = self._process_hessian_approximations(compare_point_to_trace_hessian_approximations) + trace_hessian_approx_by_image = ( + self._process_hessian_approximations(compare_point_to_trace_hessian_approximations)) # Check if log normalization is enabled in the configuration if self.gptq_config.hessian_weights_config.log_norm: # Calculate the mean of the approximations across images mean_approx_scores = np.mean(trace_hessian_approx_by_image, axis=0) + # Reduce unnecessary dims, should remain with one dimension for the number of nodes + mean_approx_scores = np.squeeze(mean_approx_scores) # 
Handle zero values to avoid log(0) mean_approx_scores = np.where(mean_approx_scores != 0, mean_approx_scores, np.partition(mean_approx_scores, 1)[1]) + # Calculate log weights log_weights = np.log10(mean_approx_scores) @@ -167,7 +172,6 @@ def compute_hessian_based_weights(self) -> np.ndarray: # If log normalization is not enabled, return the mean of the approximations across images return np.mean(trace_hessian_approx_by_image, axis=0) - def _fetch_hessian_approximations(self) -> Dict[BaseNode, List[List[float]]]: """ Fetches hessian approximations for each target node. @@ -176,17 +180,20 @@ def _fetch_hessian_approximations(self) -> Dict[BaseNode, List[List[float]]]: Mapping of target nodes to their hessian approximations. """ approximations = {} - for target_node in self.compare_points: - trace_hessian_request = TraceHessianRequest( - mode=HessianMode.ACTIVATION, - granularity=HessianInfoGranularity.PER_TENSOR, - target_node=target_node - ) - node_approximations = self.hessian_service.fetch_hessian( - trace_hessian_request=trace_hessian_request, - required_size=self.gptq_config.hessian_weights_config.hessians_num_samples - ) - approximations[target_node] = node_approximations + trace_hessian_request = TraceHessianRequest( + mode=HessianMode.ACTIVATION, + granularity=HessianInfoGranularity.PER_TENSOR, + target_nodes=self.compare_points + ) + node_approximations = self.hessian_service.fetch_hessian( + trace_hessian_request=trace_hessian_request, + required_size=self.gptq_config.hessian_weights_config.hessians_num_samples, + batch_size=self.gptq_config.hessian_weights_config.hessian_batch_size + ) + + for i, target_node in enumerate(self.compare_points): + approximations[target_node] = node_approximations[i] + return approximations def _process_hessian_approximations(self, approximations: Dict[BaseNode, List[List[float]]]) -> List: @@ -203,12 +210,13 @@ def _process_hessian_approximations(self, approximations: Dict[BaseNode, List[Li Processed approximations as a list 
of lists where each inner list is the approximations per image to all interest points. """ - trace_hessian_approx_by_image = [] - for image_idx in range(self.gptq_config.hessian_weights_config.hessians_num_samples): - approx_by_interest_point = self._get_approximations_by_interest_point(approximations, image_idx) - if self.gptq_config.hessian_weights_config.norm_scores: - approx_by_interest_point = hessian_utils.normalize_scores(approx_by_interest_point) - trace_hessian_approx_by_image.append(approx_by_interest_point) + trace_hessian_approx_by_image = [[approximations[target_node][image_idx] for target_node in self.compare_points] + for image_idx in + range(self.gptq_config.hessian_weights_config.hessians_num_samples)] + + if self.gptq_config.hessian_weights_config.norm_scores: + trace_hessian_approx_by_image = hessian_utils.normalize_scores(trace_hessian_approx_by_image) + return trace_hessian_approx_by_image def _get_approximations_by_interest_point(self, approximations: Dict, image_idx: int) -> List: diff --git a/model_compression_toolkit/gptq/keras/quantization_facade.py b/model_compression_toolkit/gptq/keras/quantization_facade.py index c4d48b5a4..47740f734 100644 --- a/model_compression_toolkit/gptq/keras/quantization_facade.py +++ b/model_compression_toolkit/gptq/keras/quantization_facade.py @@ -21,9 +21,9 @@ from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT from model_compression_toolkit.logger import Logger -from model_compression_toolkit.constants import TENSORFLOW, FOUND_TF +from model_compression_toolkit.constants import TENSORFLOW, FOUND_TF, ACT_HESSIAN_DEFAULT_BATCH_SIZE from model_compression_toolkit.core.common.user_info import UserInformation -from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig +from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, 
GPTQHessianScoresConfig from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfig from model_compression_toolkit.core import CoreConfig @@ -68,7 +68,8 @@ def get_keras_gptq_config(n_epochs: int, loss: Callable = GPTQMultipleTensorsLoss(), log_function: Callable = None, use_hessian_based_weights: bool = True, - regularization_factor: float = REG_DEFAULT) -> GradientPTQConfig: + regularization_factor: float = REG_DEFAULT, + hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE) -> GradientPTQConfig: """ Create a GradientPTQConfigV2 instance for Keras models. @@ -80,6 +81,7 @@ def get_keras_gptq_config(n_epochs: int, log_function (Callable): Function to log information about the gptq process. use_hessian_based_weights (bool): Whether to use Hessian-based weights for weighted average loss. regularization_factor (float): A floating point number that defines the regularization factor. + hessian_batch_size (int): Batch size for Hessian computation in Hessian-based weights GPTQ. returns: a GradientPTQConfigV2 object to use when fine-tuning the quantized model using gptq. 
@@ -112,7 +114,8 @@ def get_keras_gptq_config(n_epochs: int, train_bias=True, optimizer_bias=bias_optimizer, use_hessian_based_weights=use_hessian_based_weights, - regularization_factor=regularization_factor) + regularization_factor=regularization_factor, + hessian_weights_config=GPTQHessianScoresConfig(hessian_batch_size=hessian_batch_size)) def keras_gradient_post_training_quantization(in_model: Model, representative_data_gen: Callable, diff --git a/model_compression_toolkit/gptq/pytorch/quantization_facade.py b/model_compression_toolkit/gptq/pytorch/quantization_facade.py index 83e64d7e8..0f0b3f58b 100644 --- a/model_compression_toolkit/gptq/pytorch/quantization_facade.py +++ b/model_compression_toolkit/gptq/pytorch/quantization_facade.py @@ -16,12 +16,12 @@ from typing import Callable from model_compression_toolkit.core import common -from model_compression_toolkit.constants import FOUND_TORCH +from model_compression_toolkit.constants import FOUND_TORCH, ACT_HESSIAN_DEFAULT_BATCH_SIZE from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT from model_compression_toolkit.logger import Logger from model_compression_toolkit.constants import PYTORCH -from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig +from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, GPTQHessianScoresConfig from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization from model_compression_toolkit.core.runner import core_runner @@ -57,7 +57,9 @@ def get_pytorch_gptq_config(n_epochs: int, loss: Callable = multiple_tensors_mse_loss, log_function: Callable = None, use_hessian_based_weights: bool = True, - regularization_factor: float = 
REG_DEFAULT) -> GradientPTQConfig: + regularization_factor: float = REG_DEFAULT, + hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE + ) -> GradientPTQConfig: """ Create a GradientPTQConfigV2 instance for Pytorch models. @@ -69,6 +71,7 @@ def get_pytorch_gptq_config(n_epochs: int, log_function (Callable): Function to log information about the gptq process. use_hessian_based_weights (bool): Whether to use Hessian-based weights for weighted average loss. regularization_factor (float): A floating point number that defines the regularization factor. + hessian_batch_size (int): Batch size for Hessian computation in Hessian-based weights GPTQ. returns: a GradientPTQConfigV2 object to use when fine-tuning the quantized model using gptq. @@ -92,7 +95,8 @@ def get_pytorch_gptq_config(n_epochs: int, return GradientPTQConfig(n_epochs, optimizer, optimizer_rest=optimizer_rest, loss=loss, log_function=log_function, train_bias=True, optimizer_bias=bias_optimizer, use_hessian_based_weights=use_hessian_based_weights, - regularization_factor=regularization_factor) + regularization_factor=regularization_factor, + hessian_weights_config=GPTQHessianScoresConfig(hessian_batch_size=hessian_batch_size)) def pytorch_gradient_post_training_quantization(model: Module, diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/gptq/gptq_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/gptq/gptq_test.py index 0b3108491..0b7c60acc 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/gptq/gptq_test.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/gptq/gptq_test.py @@ -19,6 +19,7 @@ import model_compression_toolkit as mct from model_compression_toolkit import DefaultDict +from model_compression_toolkit.constants import GPTQ_HESSIAN_NUM_SAMPLES from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, RoundingType, GradientPTQConfig, \ GPTQHessianScoresConfig from 
model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod @@ -59,9 +60,10 @@ class GradientPTQBaseTest(BaseKerasFeatureNetworkTest): def __init__(self, unit_test, quant_method=QuantizationMethod.SYMMETRIC, rounding_type=RoundingType.STE, per_channel=True, input_shape=(1, 16, 16, 3), hessian_weights=True, log_norm_weights=True, scaled_log_norm=False, - quantization_parameter_learning=True): + quantization_parameter_learning=True, num_calibration_iter=GPTQ_HESSIAN_NUM_SAMPLES): super().__init__(unit_test, - input_shape=input_shape ) + input_shape=input_shape, + num_calibration_iter=num_calibration_iter) self.quant_method = quant_method self.rounding_type = rounding_type @@ -157,7 +159,7 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info= class GradientPTQWeightsUpdateTest(GradientPTQBaseTest): def get_gptq_config(self): - return GradientPTQConfig(50, optimizer=tf.keras.optimizers.Adam( + return GradientPTQConfig(20, optimizer=tf.keras.optimizers.Adam( learning_rate=1e-2), optimizer_rest=tf.keras.optimizers.Adam( learning_rate=1e-1), loss=multiple_tensors_mse_loss, train_bias=True, rounding_type=self.rounding_type, gptq_quantizer_params_override=self.override_params) diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/weights_mixed_precision_tests.py b/tests/keras_tests/feature_networks_tests/feature_networks/weights_mixed_precision_tests.py index 4543e983c..3942e71d9 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/weights_mixed_precision_tests.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/weights_mixed_precision_tests.py @@ -99,9 +99,8 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info= class MixedPercisionSearchTest(MixedPercisionBaseTest): - def __init__(self, unit_test, distance_metric=MpDistanceWeighting.AVG, expected_mp_config=[0,0]): + def __init__(self, unit_test, 
distance_metric=MpDistanceWeighting.AVG): super().__init__(unit_test, val_batch_size=2) - self.expected_mp_config = expected_mp_config self.distance_metric = distance_metric def get_resource_utilization(self): @@ -113,7 +112,9 @@ def get_mixed_precision_config(self): def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): conv_layers = get_layers_from_model_by_type(quantized_model, layers.Conv2D) - self.unit_test.assertTrue((quantization_info.mixed_precision_cfg == self.expected_mp_config).all()) + self.unit_test.assertTrue(any([b != 0 for b in quantization_info.mixed_precision_cfg]), + "At least one of the conv layers is expected to be quantized to meet the required " + "resource utilization target.") for i in range(32): # quantized per channel self.unit_test.assertTrue( np.unique(conv_layers[0].get_quantized_weights()['kernel'][:, :, :, i]).flatten().shape[0] <= 256) diff --git a/tests/keras_tests/feature_networks_tests/test_features_runner.py b/tests/keras_tests/feature_networks_tests/test_features_runner.py index fd4bdb9ad..e5f442365 100644 --- a/tests/keras_tests/feature_networks_tests/test_features_runner.py +++ b/tests/keras_tests/feature_networks_tests/test_features_runner.py @@ -216,8 +216,8 @@ def test_mixed_precision_search_4bits_avg_nms(self): MixedPercisionCombinedNMSTest(self).run_test() def test_mixed_precision_search(self): - MixedPercisionSearchTest(self, distance_metric=MpDistanceWeighting.AVG, expected_mp_config=[0, 1]).run_test() - MixedPercisionSearchTest(self, distance_metric=MpDistanceWeighting.LAST_LAYER, expected_mp_config=[1, 0]).run_test() + MixedPercisionSearchTest(self, distance_metric=MpDistanceWeighting.AVG).run_test() + MixedPercisionSearchTest(self, distance_metric=MpDistanceWeighting.LAST_LAYER).run_test() def test_requires_mixed_recision(self): RequiresMixedPrecisionWeights(self, weights_memory=True).run_test() diff --git a/tests/keras_tests/function_tests/test_get_gptq_config.py 
b/tests/keras_tests/function_tests/test_get_gptq_config.py index 7ea3e7dad..552c3730d 100644 --- a/tests/keras_tests/function_tests/test_get_gptq_config.py +++ b/tests/keras_tests/function_tests/test_get_gptq_config.py @@ -54,11 +54,8 @@ def build_model(in_input_shape: List[int]) -> tf.keras.Model: def random_datagen(): - return [np.random.random(SHAPE)] - - -def random_datagen_experimental(): - yield [np.random.random(SHAPE)] + for _ in range(20): + yield [np.random.random(SHAPE)] class TestGetGPTQConfig(unittest.TestCase): @@ -72,45 +69,46 @@ def setUp(self): log_norm=True, scale_log_norm=True) - self.gptqv2_configurations = [GradientPTQConfig(1, optimizer=tf.keras.optimizers.RMSprop(), - optimizer_rest=tf.keras.optimizers.RMSprop(), - train_bias=True, - loss=multiple_tensors_mse_loss, - rounding_type=RoundingType.SoftQuantizer), - GradientPTQConfig(1, optimizer=tf.keras.optimizers.Adam(), - optimizer_rest=tf.keras.optimizers.Adam(), - train_bias=True, - loss=multiple_tensors_mse_loss, - rounding_type=RoundingType.SoftQuantizer), - GradientPTQConfig(1, optimizer=tf.keras.optimizers.Adam(), - optimizer_rest=tf.keras.optimizers.Adam(), - train_bias=True, - loss=multiple_tensors_mse_loss, - rounding_type=RoundingType.SoftQuantizer, - regularization_factor=15), - GradientPTQConfig(1, optimizer=tf.keras.optimizers.Adam(), - optimizer_rest=tf.keras.optimizers.Adam(), - train_bias=True, - loss=multiple_tensors_mse_loss, - rounding_type=RoundingType.SoftQuantizer, - gptq_quantizer_params_override={QUANT_PARAM_LEARNING_STR: True}), - GradientPTQConfig(1, optimizer=tf.keras.optimizers.Adam(), - optimizer_rest=tf.keras.optimizers.Adam(), - train_bias=True, - loss=multiple_tensors_mse_loss, - rounding_type=RoundingType.SoftQuantizer, - hessian_weights_config=test_hessian_weights_config), - GradientPTQConfig(1, optimizer=tf.keras.optimizers.Adam(), - optimizer_rest=tf.keras.optimizers.Adam(), - train_bias=True, - loss=multiple_tensors_mse_loss, - 
rounding_type=RoundingType.STE, - gptq_quantizer_params_override={MAX_LSB_STR: DefaultDict(default_value=1)}), - get_keras_gptq_config(n_epochs=1, - optimizer=tf.keras.optimizers.Adam()), - get_keras_gptq_config(n_epochs=1, - optimizer=tf.keras.optimizers.Adam(), - regularization_factor=0.001)] + self.gptq_configurations = [GradientPTQConfig(1, optimizer=tf.keras.optimizers.RMSprop(), + optimizer_rest=tf.keras.optimizers.RMSprop(), + train_bias=True, + loss=multiple_tensors_mse_loss, + rounding_type=RoundingType.SoftQuantizer), + GradientPTQConfig(1, optimizer=tf.keras.optimizers.Adam(), + optimizer_rest=tf.keras.optimizers.Adam(), + train_bias=True, + loss=multiple_tensors_mse_loss, + rounding_type=RoundingType.SoftQuantizer), + GradientPTQConfig(1, optimizer=tf.keras.optimizers.Adam(), + optimizer_rest=tf.keras.optimizers.Adam(), + train_bias=True, + loss=multiple_tensors_mse_loss, + rounding_type=RoundingType.SoftQuantizer, + regularization_factor=15), + GradientPTQConfig(1, optimizer=tf.keras.optimizers.Adam(), + optimizer_rest=tf.keras.optimizers.Adam(), + train_bias=True, + loss=multiple_tensors_mse_loss, + rounding_type=RoundingType.SoftQuantizer, + gptq_quantizer_params_override={QUANT_PARAM_LEARNING_STR: True}), + GradientPTQConfig(1, optimizer=tf.keras.optimizers.Adam(), + optimizer_rest=tf.keras.optimizers.Adam(), + train_bias=True, + loss=multiple_tensors_mse_loss, + rounding_type=RoundingType.SoftQuantizer, + hessian_weights_config=test_hessian_weights_config), + GradientPTQConfig(1, optimizer=tf.keras.optimizers.Adam(), + optimizer_rest=tf.keras.optimizers.Adam(), + train_bias=True, + loss=multiple_tensors_mse_loss, + rounding_type=RoundingType.STE, + gptq_quantizer_params_override={ + MAX_LSB_STR: DefaultDict(default_value=1)}), + get_keras_gptq_config(n_epochs=1, + optimizer=tf.keras.optimizers.Adam()), + get_keras_gptq_config(n_epochs=1, + optimizer=tf.keras.optimizers.Adam(), + regularization_factor=0.001)] pot_tp = 
generate_test_tp_model({'weights_quantization_method': QuantizationMethod.POWER_OF_TWO}) @@ -123,9 +121,13 @@ def test_get_keras_gptq_config_pot(self): # This call removes the effect of @tf.function decoration and executes the decorated function eagerly, which # enabled tracing for code coverage. tf.config.run_functions_eagerly(True) - for i, gptq_config in enumerate(self.gptqv2_configurations): + for i, gptq_config in enumerate(self.gptq_configurations): + # Reducing the default number of samples for GPTQ Hessian approximation + # to allow quick execution of the test + gptq_config.hessian_weights_config.hessians_num_samples = 2 + keras_gradient_post_training_quantization(in_model=build_model(SHAPE[1:]), - representative_data_gen=random_datagen_experimental, + representative_data_gen=random_datagen, gptq_config=gptq_config, core_config=self.cc, target_platform_capabilities=self.pot_weights_tpc) @@ -136,9 +138,13 @@ def test_get_keras_gptq_config_symmetric(self): # enabled tracing for code coverage. 
tf.config.run_functions_eagerly(True) - for i, gptq_config in enumerate(self.gptqv2_configurations): + for i, gptq_config in enumerate(self.gptq_configurations): + # Reducing the default number of samples for GPTQ Hessian approximation + # to allow quick execution of the test + gptq_config.hessian_weights_config.hessians_num_samples = 2 + keras_gradient_post_training_quantization(in_model=build_model(SHAPE[1:]), - representative_data_gen=random_datagen_experimental, + representative_data_gen=random_datagen, gptq_config=gptq_config, core_config=self.cc, target_platform_capabilities=self.symmetric_weights_tpc) diff --git a/tests/keras_tests/function_tests/test_hessian_info_calculator.py b/tests/keras_tests/function_tests/test_hessian_info_calculator.py index 2360a35e9..c3c7d869f 100644 --- a/tests/keras_tests/function_tests/test_hessian_info_calculator.py +++ b/tests/keras_tests/function_tests/test_hessian_info_calculator.py @@ -24,6 +24,7 @@ import model_compression_toolkit as mct import model_compression_toolkit.core.common.hessian as hessian_common +from model_compression_toolkit.core.keras.constants import KERNEL from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import generate_keras_tpc @@ -86,21 +87,39 @@ def representative_dataset(input_shape, num_of_inputs=1): yield [np.random.randn(*input_shape).astype(np.float32)] * num_of_inputs +def get_expected_shape(t_shape, granularity, node_type): + if granularity == hessian_common.HessianInfoGranularity.PER_ELEMENT: + return (1, *t_shape) + elif granularity == hessian_common.HessianInfoGranularity.PER_TENSOR: + return (1, 1) + else: + return (1, t_shape[-1] * t_shape[2]) if node_type == DepthwiseConv2D else \ + (1, t_shape[2]) if node_type == Conv2DTranspose else \ + (1, t_shape[-1]) + + class 
TestHessianInfoCalculatorBase(unittest.TestCase): - def _fetch_scores(self, hessian_info, target_node, granularity, mode, num_scores=1): + def _fetch_scores(self, hessian_info, target_nodes, granularity, mode, num_scores=1): request = hessian_common.TraceHessianRequest(mode=mode, granularity=granularity, - target_node=target_node) + target_nodes=target_nodes) info = hessian_info.fetch_hessian(request, num_scores) assert len(info) == num_scores, f"fetched {num_scores} score but {len(info)} scores were fetched" return np.mean(np.stack(info), axis=0) - def _test_score_shape(self, hessian_service, interest_point, granularity, mode, expected_shape, num_scores=1): + def _test_score_shape(self, hessian_service, interest_point, granularity, mode, num_scores=1): + score = self._fetch_scores(hessian_info=hessian_service, - target_node=interest_point, # linear op + target_nodes=[interest_point], # linear op granularity=granularity, mode=mode, num_scores=num_scores) + + kernel_attr_name = [w for w in interest_point.weights if KERNEL in w] + self.assertTrue(len(kernel_attr_name) == 1, "Expecting exactly 1 kernel attribute.") + expected_shape = ( + get_expected_shape(interest_point.weights[kernel_attr_name[0]].shape, granularity, interest_point.type)) + self.assertTrue(isinstance(score, np.ndarray), f"scores expected to be a numpy array but is {type(score)}") self.assertTrue(score.shape == expected_shape, f"Tensor shape is expected to be {expected_shape} but has shape {score.shape}") # per tensor @@ -126,99 +145,83 @@ def test_conv2d_granularity(self): graph, _repr_dataset, keras_impl = self._setup(layer=Conv2D(filters=2, kernel_size=3)) sorted_graph_nodes = graph.get_topo_sorted_nodes() interest_points = [n for n in sorted_graph_nodes] - hessian_service = hessian_common.HessianInfoService(graph=graph, - representative_dataset=_repr_dataset, + hessian_service = hessian_common.HessianInfoService(graph=graph, representative_dataset_gen=_repr_dataset, fw_impl=keras_impl) 
self._test_score_shape(hessian_service, interest_points[1], granularity=hessian_common.HessianInfoGranularity.PER_TENSOR, - mode=hessian_common.HessianMode.WEIGHTS, - expected_shape=(1,)) + mode=hessian_common.HessianMode.WEIGHTS) self._test_score_shape(hessian_service, interest_points[1], granularity=hessian_common.HessianInfoGranularity.PER_OUTPUT_CHANNEL, - mode=hessian_common.HessianMode.WEIGHTS, - expected_shape=(2,)) + mode=hessian_common.HessianMode.WEIGHTS) self._test_score_shape(hessian_service, interest_points[1], granularity=hessian_common.HessianInfoGranularity.PER_ELEMENT, - mode=hessian_common.HessianMode.WEIGHTS, - expected_shape=(3, 3, 3, 2)) + mode=hessian_common.HessianMode.WEIGHTS) del hessian_service def test_dense_granularity(self): graph, _repr_dataset, keras_impl = self._setup(layer=Dense(2), input_shape=(1, 8)) sorted_graph_nodes = graph.get_topo_sorted_nodes() interest_points = [n for n in sorted_graph_nodes] - hessian_service = hessian_common.HessianInfoService(graph=graph, - representative_dataset=_repr_dataset, + hessian_service = hessian_common.HessianInfoService(graph=graph, representative_dataset_gen=_repr_dataset, fw_impl=keras_impl) self._test_score_shape(hessian_service, interest_points[1], granularity=hessian_common.HessianInfoGranularity.PER_TENSOR, - mode=hessian_common.HessianMode.WEIGHTS, - expected_shape=(1,)) + mode=hessian_common.HessianMode.WEIGHTS) self._test_score_shape(hessian_service, interest_points[1], granularity=hessian_common.HessianInfoGranularity.PER_OUTPUT_CHANNEL, - mode=hessian_common.HessianMode.WEIGHTS, - expected_shape=(2,)) + mode=hessian_common.HessianMode.WEIGHTS) self._test_score_shape(hessian_service, interest_points[1], granularity=hessian_common.HessianInfoGranularity.PER_ELEMENT, - mode=hessian_common.HessianMode.WEIGHTS, - expected_shape=(8, 2)) + mode=hessian_common.HessianMode.WEIGHTS) del hessian_service def test_conv2dtranspose_granularity(self): graph, _repr_dataset, keras_impl = 
self._setup(layer=Conv2DTranspose(filters=2, kernel_size=3)) sorted_graph_nodes = graph.get_topo_sorted_nodes() interest_points = [n for n in sorted_graph_nodes] - hessian_service = hessian_common.HessianInfoService(graph=graph, - representative_dataset=_repr_dataset, + hessian_service = hessian_common.HessianInfoService(graph=graph, representative_dataset_gen=_repr_dataset, fw_impl=keras_impl) self._test_score_shape(hessian_service, interest_points[1], granularity=hessian_common.HessianInfoGranularity.PER_TENSOR, - mode=hessian_common.HessianMode.WEIGHTS, - expected_shape=(1,)) + mode=hessian_common.HessianMode.WEIGHTS) self._test_score_shape(hessian_service, interest_points[1], granularity=hessian_common.HessianInfoGranularity.PER_OUTPUT_CHANNEL, - mode=hessian_common.HessianMode.WEIGHTS, - expected_shape=(2,)) + mode=hessian_common.HessianMode.WEIGHTS) self._test_score_shape(hessian_service, interest_points[1], granularity=hessian_common.HessianInfoGranularity.PER_ELEMENT, - mode=hessian_common.HessianMode.WEIGHTS, - expected_shape=(3, 3, 2, 3)) + mode=hessian_common.HessianMode.WEIGHTS) del hessian_service def test_depthwiseconv2d_granularity(self): graph, _repr_dataset, keras_impl = self._setup(layer=DepthwiseConv2D(kernel_size=3)) sorted_graph_nodes = graph.get_topo_sorted_nodes() interest_points = [n for n in sorted_graph_nodes] - hessian_service = hessian_common.HessianInfoService(graph=graph, - representative_dataset=_repr_dataset, + hessian_service = hessian_common.HessianInfoService(graph=graph, representative_dataset_gen=_repr_dataset, fw_impl=keras_impl) self._test_score_shape(hessian_service, interest_points[1], granularity=hessian_common.HessianInfoGranularity.PER_TENSOR, - mode=hessian_common.HessianMode.WEIGHTS, - expected_shape=(1,)) + mode=hessian_common.HessianMode.WEIGHTS) self._test_score_shape(hessian_service, interest_points[1], granularity=hessian_common.HessianInfoGranularity.PER_OUTPUT_CHANNEL, - mode=hessian_common.HessianMode.WEIGHTS, - 
expected_shape=(3,)) + mode=hessian_common.HessianMode.WEIGHTS) self._test_score_shape(hessian_service, interest_points[1], granularity=hessian_common.HessianInfoGranularity.PER_ELEMENT, - mode=hessian_common.HessianMode.WEIGHTS, - expected_shape=(3, 3, 3, 1)) + mode=hessian_common.HessianMode.WEIGHTS) del hessian_service def test_reused_layer(self): @@ -238,35 +241,36 @@ def test_reused_layer(self): # Two nodes representing the same reused layer interest_points = [n for n in sorted_graph_nodes if n.is_match_type(Conv2D)] - self.assertTrue(len(interest_points)==2, f"Expected to find 2 Conv2D nodes but found {len(interest_points)}") + self.assertTrue(len(interest_points) == 2, f"Expected to find 2 Conv2D nodes but found {len(interest_points)}") - hessian_service = hessian_common.HessianInfoService(graph=graph, - representative_dataset=_repr_dataset, + hessian_service = hessian_common.HessianInfoService(graph=graph, representative_dataset_gen=_repr_dataset, fw_impl=keras_impl) node1_approx = self._test_score_shape(hessian_service, interest_points[0], granularity=hessian_common.HessianInfoGranularity.PER_TENSOR, - mode=hessian_common.HessianMode.WEIGHTS, - expected_shape=(1,)) + mode=hessian_common.HessianMode.WEIGHTS) node2_approx = self._test_score_shape(hessian_service, interest_points[1], granularity=hessian_common.HessianInfoGranularity.PER_TENSOR, - mode=hessian_common.HessianMode.WEIGHTS, - expected_shape=(1,)) - self.assertTrue(np.all(node1_approx==node2_approx), f'Approximations of nodes of a reused layer ' - f'should be equal') + mode=hessian_common.HessianMode.WEIGHTS) + self.assertTrue(np.all(node1_approx == node2_approx), f'Approximations of nodes of a reused layer ' + f'should be equal') - node1_count = hessian_service.count_saved_info_of_request( - hessian_common.TraceHessianRequest(target_node=interest_points[0], - mode=hessian_common.HessianMode.WEIGHTS, - granularity=hessian_common.HessianInfoGranularity.PER_TENSOR)) - self.assertTrue(node1_count == 
1) + # Expecting call for count_saved_info_of_request for reused node to be with a reconstructed request with the + # node's representative group member for its reuse group + with self.assertRaises(Exception) as e: + node1_count = hessian_service.count_saved_info_of_request( + hessian_common.TraceHessianRequest(target_nodes=[interest_points[0]], + mode=hessian_common.HessianMode.WEIGHTS, + granularity=hessian_common.HessianInfoGranularity.PER_TENSOR)) + self.assertTrue("Expecting the Hessian request to include only non-reused nodes at this point" + in str(e.exception)) node2_count = hessian_service.count_saved_info_of_request( - hessian_common.TraceHessianRequest(target_node=interest_points[1], + hessian_common.TraceHessianRequest(target_nodes=[interest_points[1]], mode=hessian_common.HessianMode.WEIGHTS, granularity=hessian_common.HessianInfoGranularity.PER_TENSOR)) - self.assertTrue(node2_count == 1) + self.assertTrue(node2_count[interest_points[1]] == 1) self.assertTrue(len(hessian_service.trace_hessian_request_to_score_list) == 1) del hessian_service @@ -293,24 +297,20 @@ def _test_advanced_graph(self, float_model, _repr_dataset): # This test assumes the first Conv2D interest point is the node that # we fetch its scores and test their shapes correctness. 
interest_points = [n for n in sorted_graph_nodes if n.type==Conv2D][0] - hessian_service = hessian_common.HessianInfoService(graph=graph, - representative_dataset=_repr_dataset, + hessian_service = hessian_common.HessianInfoService(graph=graph, representative_dataset_gen=_repr_dataset, fw_impl=keras_impl) self._test_score_shape(hessian_service, interest_points, granularity=hessian_common.HessianInfoGranularity.PER_TENSOR, - mode=hessian_common.HessianMode.WEIGHTS, - expected_shape=(1,)) + mode=hessian_common.HessianMode.WEIGHTS) self._test_score_shape(hessian_service, interest_points, granularity=hessian_common.HessianInfoGranularity.PER_OUTPUT_CHANNEL, - mode=hessian_common.HessianMode.WEIGHTS, - expected_shape=(2,)) + mode=hessian_common.HessianMode.WEIGHTS) self._test_score_shape(hessian_service, interest_points, granularity=hessian_common.HessianInfoGranularity.PER_ELEMENT, - mode=hessian_common.HessianMode.WEIGHTS, - expected_shape=(3, 3, 3, 2)) + mode=hessian_common.HessianMode.WEIGHTS) del hessian_service @@ -344,14 +344,12 @@ def test_conv2d_granularity(self): graph, _repr_dataset, keras_impl = self._setup(layer=Conv2D(filters=2, kernel_size=3)) sorted_graph_nodes = graph.get_topo_sorted_nodes() interest_points = [n for n in sorted_graph_nodes] - hessian_service = hessian_common.HessianInfoService(graph=graph, - representative_dataset=_repr_dataset, + hessian_service = hessian_common.HessianInfoService(graph=graph, representative_dataset_gen=_repr_dataset, fw_impl=keras_impl) self._test_score_shape(hessian_service, interest_points[1], granularity=hessian_common.HessianInfoGranularity.PER_TENSOR, - mode=hessian_common.HessianMode.ACTIVATION, - expected_shape=(1,)) + mode=hessian_common.HessianMode.ACTIVATION) del hessian_service @@ -359,15 +357,13 @@ def test_dense_granularity(self): graph, _repr_dataset, keras_impl = self._setup(layer=Dense(2), input_shape=(1, 8)) sorted_graph_nodes = graph.get_topo_sorted_nodes() interest_points = [n for n in 
sorted_graph_nodes] - hessian_service = hessian_common.HessianInfoService(graph=graph, - representative_dataset=_repr_dataset, + hessian_service = hessian_common.HessianInfoService(graph=graph, representative_dataset_gen=_repr_dataset, fw_impl=keras_impl) self._test_score_shape(hessian_service, interest_points[1], granularity=hessian_common.HessianInfoGranularity.PER_TENSOR, - mode=hessian_common.HessianMode.ACTIVATION, - expected_shape=(1,)) + mode=hessian_common.HessianMode.ACTIVATION) del hessian_service @@ -375,15 +371,13 @@ def test_conv2dtranspose_granularity(self): graph, _repr_dataset, keras_impl = self._setup(layer=Conv2DTranspose(filters=2, kernel_size=3)) sorted_graph_nodes = graph.get_topo_sorted_nodes() interest_points = [n for n in sorted_graph_nodes] - hessian_service = hessian_common.HessianInfoService(graph=graph, - representative_dataset=_repr_dataset, + hessian_service = hessian_common.HessianInfoService(graph=graph, representative_dataset_gen=_repr_dataset, fw_impl=keras_impl) self._test_score_shape(hessian_service, interest_points[1], granularity=hessian_common.HessianInfoGranularity.PER_TENSOR, - mode=hessian_common.HessianMode.ACTIVATION, - expected_shape=(1,)) + mode=hessian_common.HessianMode.ACTIVATION) del hessian_service @@ -391,15 +385,13 @@ def test_depthwiseconv2d_granularity(self): graph, _repr_dataset, keras_impl = self._setup(layer=DepthwiseConv2D(kernel_size=3)) sorted_graph_nodes = graph.get_topo_sorted_nodes() interest_points = [n for n in sorted_graph_nodes] - hessian_service = hessian_common.HessianInfoService(graph=graph, - representative_dataset=_repr_dataset, + hessian_service = hessian_common.HessianInfoService(graph=graph, representative_dataset_gen=_repr_dataset, fw_impl=keras_impl) self._test_score_shape(hessian_service, interest_points[1], granularity=hessian_common.HessianInfoGranularity.PER_TENSOR, - mode=hessian_common.HessianMode.ACTIVATION, - expected_shape=(1,)) + mode=hessian_common.HessianMode.ACTIVATION) del 
hessian_service @@ -422,34 +414,36 @@ def test_reused_layer(self): interest_points = [n for n in sorted_graph_nodes if n.is_match_type(Conv2D)] self.assertTrue(len(interest_points)==2, f"Expected to find 2 Conv2D nodes but found {len(interest_points)}") - hessian_service = hessian_common.HessianInfoService(graph=graph, - representative_dataset=_repr_dataset, + hessian_service = hessian_common.HessianInfoService(graph=graph, representative_dataset_gen=_repr_dataset, fw_impl=keras_impl) node1_approx = self._test_score_shape(hessian_service, interest_points[0], granularity=hessian_common.HessianInfoGranularity.PER_TENSOR, - mode=hessian_common.HessianMode.ACTIVATION, - expected_shape=(1,)) + mode=hessian_common.HessianMode.ACTIVATION) node2_approx = self._test_score_shape(hessian_service, interest_points[1], granularity=hessian_common.HessianInfoGranularity.PER_TENSOR, - mode=hessian_common.HessianMode.ACTIVATION, - expected_shape=(1,)) + mode=hessian_common.HessianMode.ACTIVATION) self.assertTrue(np.all(node1_approx == node2_approx), f'Approximations of nodes of a reused layer ' f'should be equal') - node1_count = hessian_service.count_saved_info_of_request( - hessian_common.TraceHessianRequest(target_node=interest_points[0], - mode=hessian_common.HessianMode.ACTIVATION, - granularity=hessian_common.HessianInfoGranularity.PER_TENSOR)) - self.assertTrue(node1_count == 1) + # Expecting call for count_saved_info_of_request for reused node to be with a reconstructed request with the + # node's representative group member for its reuse group + with self.assertRaises(Exception) as e: + node1_count = hessian_service.count_saved_info_of_request( + hessian_common.TraceHessianRequest(target_nodes=[interest_points[0]], + mode=hessian_common.HessianMode.ACTIVATION, + granularity=hessian_common.HessianInfoGranularity.PER_TENSOR)) + self.assertTrue("Expecting the Hessian request to include only non-reused nodes at this point" + in str(e.exception)) node2_count = 
hessian_service.count_saved_info_of_request( - hessian_common.TraceHessianRequest(target_node=interest_points[1], + hessian_common.TraceHessianRequest(target_nodes=[interest_points[1]], mode=hessian_common.HessianMode.ACTIVATION, granularity=hessian_common.HessianInfoGranularity.PER_TENSOR)) - self.assertTrue(node2_count == 1) + self.assertTrue(node2_count[interest_points[1]] == 1) + self.assertTrue(len(hessian_service.trace_hessian_request_to_score_list) == 1) del hessian_service @@ -477,14 +471,12 @@ def _test_advanced_graph(self, float_model, _repr_dataset): # This test assumes the first Conv2D interest point is the node that # we fetch its scores and test their shapes correctness. interest_points = [n for n in sorted_graph_nodes if n.type==Conv2D][0] - hessian_service = hessian_common.HessianInfoService(graph=graph, - representative_dataset=_repr_dataset, + hessian_service = hessian_common.HessianInfoService(graph=graph, representative_dataset_gen=_repr_dataset, fw_impl=keras_impl) self._test_score_shape(hessian_service, interest_points, granularity=hessian_common.HessianInfoGranularity.PER_TENSOR, - mode=hessian_common.HessianMode.ACTIVATION, - expected_shape=(1,)) + mode=hessian_common.HessianMode.ACTIVATION) del hessian_service @@ -512,13 +504,12 @@ def test_multiple_outputs_to_intermediate_node(self): def test_activation_hessian_output_exception(self): graph, _repr_dataset, keras_impl = self._setup(layer=Conv2D(filters=2, kernel_size=3)) - hessian_service = hessian_common.HessianInfoService(graph=graph, - representative_dataset=_repr_dataset, + hessian_service = hessian_common.HessianInfoService(graph=graph, representative_dataset_gen=_repr_dataset, fw_impl=keras_impl) with self.assertRaises(Exception) as e: request = hessian_common.TraceHessianRequest(granularity=hessian_common.HessianInfoGranularity.PER_TENSOR, mode=hessian_common.HessianMode.ACTIVATION, - target_node=graph.get_outputs()[0].node) + target_nodes=[graph.get_outputs()[0].node]) _ = 
hessian_service.fetch_hessian(request, required_size=1) self.assertTrue("Trying to compute activation Hessian approximation with respect to the model output" diff --git a/tests/keras_tests/function_tests/test_hessian_service.py b/tests/keras_tests/function_tests/test_hessian_service.py index 0935f350b..02bbb7b68 100644 --- a/tests/keras_tests/function_tests/test_hessian_service.py +++ b/tests/keras_tests/function_tests/test_hessian_service.py @@ -39,8 +39,21 @@ def basic_model(input_shape): return keras.Model(inputs=inputs, outputs=outputs) -def representative_dataset(num_of_inputs=1): - yield [np.random.randn(2, 8, 8, 3).astype(np.float32)] * num_of_inputs +def multiple_act_nodes_model(input_shape): + random_uniform = initializers.random_uniform(0, 1) + inputs = Input(shape=input_shape) + x = Conv2D(2, 3, padding='same', name="conv2d")(inputs) + x_bn = BatchNormalization(gamma_initializer='random_normal', beta_initializer='random_normal', + moving_mean_initializer='random_normal', moving_variance_initializer=random_uniform, + name="bn1")(x) + x_relu = ReLU()(x_bn) + outputs = Conv2D(2, 3, padding='same', name="conv2d_2")(x_relu) + return keras.Model(inputs=inputs, outputs=outputs) + + +def representative_dataset(): + for _ in range(2): + yield [np.random.randn(2, 8, 8, 3).astype(np.float32)] class TestHessianService(unittest.TestCase): @@ -58,53 +71,139 @@ def setUp(self): representative_dataset, generate_keras_tpc) + self.hessian_service = HessianInfoService(graph=self.graph, representative_dataset_gen=representative_dataset, + fw_impl=self.keras_impl) + + self.assertEqual(self.hessian_service.graph, self.graph) + self.assertEqual(self.hessian_service.fw_impl, self.keras_impl) + + def test_fetch_activation_hessian(self): + request = TraceHessianRequest(mode=HessianMode.ACTIVATION, + granularity=HessianInfoGranularity.PER_TENSOR, + target_nodes=[list(self.graph.get_topo_sorted_nodes())[0]]) + hessian = self.hessian_service.fetch_hessian(request, 2) + 
self.assertEqual(len(hessian), 1, "Expecting returned Hessian list to include one list of " + "approximation, for the single target node.") + self.assertEqual(len(hessian[0]), 2, "Expecting 2 Hessian scores.") + + def test_fetch_weights_hessian(self): + request = TraceHessianRequest(mode=HessianMode.WEIGHTS, + granularity=HessianInfoGranularity.PER_OUTPUT_CHANNEL, + target_nodes=[list(self.graph.get_topo_sorted_nodes())[1]]) + hessian = self.hessian_service.fetch_hessian(request, 2) + self.assertEqual(len(hessian), 1, "Expecting returned Hessian list to include one list of " + "approximation, for the single target node.") + self.assertEqual(len(hessian[0]), 2, "Expecting 2 Hessian scores.") + + def test_fetch_not_enough_samples_throw(self): + request = TraceHessianRequest(mode=HessianMode.ACTIVATION, + granularity=HessianInfoGranularity.PER_TENSOR, + target_nodes=[list(self.graph.get_topo_sorted_nodes())[0]]) + + with self.assertRaises(Exception) as e: + hessian = self.hessian_service.fetch_hessian(request, 5, batch_size=2) # representative dataset produces 4 images total + + self.assertTrue('Not enough samples in the provided representative dataset' in str(e.exception)) + + def test_fetch_not_enough_samples_small_batch_throw(self): + request = TraceHessianRequest(mode=HessianMode.ACTIVATION, + granularity=HessianInfoGranularity.PER_TENSOR, + target_nodes=[list(self.graph.get_topo_sorted_nodes())[0]]) + + with self.assertRaises(Exception) as e: + hessian = self.hessian_service.fetch_hessian(request, 5, batch_size=1) # representative dataset produces 4 images total + + self.assertTrue('Not enough samples in the provided representative dataset' in str(e.exception)) + + def test_fetch_compute_batch_larger_than_repr_batch(self): + request = TraceHessianRequest(mode=HessianMode.ACTIVATION, + granularity=HessianInfoGranularity.PER_TENSOR, + target_nodes=[list(self.graph.get_topo_sorted_nodes())[0]]) + + hessian = self.hessian_service.fetch_hessian(request, 3, 
batch_size=3) # representative batch size is 2 + self.assertEqual(len(hessian), 1, "Expecting returned Hessian list to include one list of " + "approximation, for the single target node.") + self.assertEqual(len(hessian[0]), 3, "Expecting 3 Hessian scores.") + + def test_fetch_required_zero(self): + request = TraceHessianRequest(mode=HessianMode.ACTIVATION, + granularity=HessianInfoGranularity.PER_TENSOR, + target_nodes=[list(self.graph.get_topo_sorted_nodes())[0]],) + + hessian = self.hessian_service.fetch_hessian(request, 0) + + self.assertEqual(len(hessian), 1, "Expecting returned Hessian list to include one list of " + "approximation, for the single target node.") + self.assertEqual(len(hessian[0]), 0, "Expecting an empty Hessian scores list.") + + def test_fetch_multiple_nodes(self): + input_shape = (8, 8, 3) + self.float_model = multiple_act_nodes_model(input_shape) + self.keras_impl = KerasImplementation() + self.graph = prepare_graph_with_configs(self.float_model, + self.keras_impl, + DEFAULT_KERAS_INFO, + representative_dataset, + generate_keras_tpc) + self.hessian_service = HessianInfoService(graph=self.graph, - representative_dataset=representative_dataset, + representative_dataset_gen=representative_dataset, fw_impl=self.keras_impl) self.assertEqual(self.hessian_service.graph, self.graph) self.assertEqual(self.hessian_service.fw_impl, self.keras_impl) - def test_fetch_hessian(self): + graph_nodes = list(self.graph.get_topo_sorted_nodes()) request = TraceHessianRequest(mode=HessianMode.ACTIVATION, granularity=HessianInfoGranularity.PER_TENSOR, - target_node=list(self.graph.nodes)[1]) + target_nodes=[graph_nodes[0], graph_nodes[2]]) + hessian = self.hessian_service.fetch_hessian(request, 2) - self.assertEqual(len(hessian), 2) + + self.assertEqual(len(hessian), 2, "Expecting returned Hessian list to include two list of " + "approximation, for the two target nodes.") + self.assertEqual(len(hessian[0]), 2, f"Expecting 2 Hessian scores for layer 
{graph_nodes[0].name}.") + self.assertEqual(len(hessian[1]), 2, f"Expecting 2 Hessian scores for layer {graph_nodes[2].name}.") def test_clear_cache(self): self.hessian_service._clear_saved_hessian_info() + target_node = list(self.graph.nodes)[1] request = TraceHessianRequest(mode=HessianMode.ACTIVATION, granularity=HessianInfoGranularity.PER_TENSOR, - target_node=list(self.graph.nodes)[1]) - self.assertEqual(self.hessian_service.count_saved_info_of_request(request), 0) + target_nodes=[target_node]) + self.assertEqual(self.hessian_service.count_saved_info_of_request(request)[target_node], 0) self.hessian_service.fetch_hessian(request, 1) - self.assertEqual(self.hessian_service.count_saved_info_of_request(request), 1) + self.assertEqual(self.hessian_service.count_saved_info_of_request(request)[target_node], 1) self.hessian_service._clear_saved_hessian_info() - self.assertEqual(self.hessian_service.count_saved_info_of_request(request), 0) - + self.assertEqual(self.hessian_service.count_saved_info_of_request(request)[target_node], 0) def test_double_fetch_hessian(self): self.hessian_service._clear_saved_hessian_info() + target_node = list(self.graph.nodes)[1] request = TraceHessianRequest(mode=HessianMode.ACTIVATION, granularity=HessianInfoGranularity.PER_TENSOR, - target_node=list(self.graph.nodes)[1]) + target_nodes=[target_node]) hessian = self.hessian_service.fetch_hessian(request, 2) - self.assertEqual(len(hessian), 2) - self.assertEqual(self.hessian_service.count_saved_info_of_request(request), 2) + self.assertEqual(len(hessian), 1, "Expecting returned Hessian list to include one list of " + "approximation, for the single target node.") + self.assertEqual(len(hessian[0]), 2, "Expecting 2 Hessian scores.") + self.assertEqual(self.hessian_service.count_saved_info_of_request(request)[target_node], 2) hessian = self.hessian_service.fetch_hessian(request, 2) - self.assertEqual(len(hessian), 2) - 
self.assertEqual(self.hessian_service.count_saved_info_of_request(request), 2) + self.assertEqual(len(hessian), 1, "Expecting returned Hessian list to include one list of " + "approximation, for the single target node.") + self.assertEqual(len(hessian[0]), 2, "Expecting 2 Hessian scores.") + self.assertEqual(self.hessian_service.count_saved_info_of_request(request)[target_node], 2) def test_populate_cache_to_size(self): self.hessian_service._clear_saved_hessian_info() + target_node = list(self.graph.nodes)[1] request = TraceHessianRequest(mode=HessianMode.ACTIVATION, granularity=HessianInfoGranularity.PER_TENSOR, - target_node=list(self.graph.nodes)[1]) + target_nodes=[target_node]) self.hessian_service._populate_saved_info_to_size(request, 2) - self.assertEqual(self.hessian_service.count_saved_info_of_request(request), 2) + self.assertEqual(self.hessian_service.count_saved_info_of_request(request)[target_node], 2) if __name__ == "__main__": diff --git a/tests/keras_tests/function_tests/test_hmse_error_method.py b/tests/keras_tests/function_tests/test_hmse_error_method.py index 6e6dc806a..291383148 100644 --- a/tests/keras_tests/function_tests/test_hmse_error_method.py +++ b/tests/keras_tests/function_tests/test_hmse_error_method.py @@ -89,8 +89,7 @@ def _setup_with_args(self, quant_method, per_channel, running_gptq=True, tpc_fn= running_gptq=running_gptq # to enable HMSE in params calculation if needed ) - self.his = HessianInfoService(graph=self.graph, - representative_dataset=representative_dataset, + self.his = HessianInfoService(graph=self.graph, representative_dataset_gen=representative_dataset, fw_impl=self.keras_impl) mi = ModelCollector(self.graph, @@ -114,9 +113,9 @@ def _run_node_verification(node_type): expected_hessian_request = TraceHessianRequest(mode=HessianMode.WEIGHTS, granularity=HessianInfoGranularity.PER_ELEMENT, - target_node=node) + target_nodes=[node]) - self.assertTrue(self.his.count_saved_info_of_request(expected_hessian_request) > 0, + 
self.assertTrue(self.his.count_saved_info_of_request(expected_hessian_request)[node] > 0, f"No Hessian-based scores were computed for node {node}, " "but expected parameters selection to run with HMSE.") @@ -246,9 +245,9 @@ def _generate_bn_quantization_tpc(quant_method, per_channel): expected_hessian_request = TraceHessianRequest(mode=HessianMode.WEIGHTS, granularity=HessianInfoGranularity.PER_ELEMENT, - target_node=node) + target_nodes=[node]) - self.assertTrue(self.his.count_saved_info_of_request(expected_hessian_request) == 0, + self.assertTrue(self.his.count_saved_info_of_request(expected_hessian_request)[node] == 0, f"Hessian-based scores were computed for node {node}, " "but expected parameters selection to run with MSE without computing Hessians.") diff --git a/tests/keras_tests/function_tests/test_sensitivity_eval_non_suppoerted_output.py b/tests/keras_tests/function_tests/test_sensitivity_eval_non_suppoerted_output.py index 209b2ac3b..6d65215b0 100644 --- a/tests/keras_tests/function_tests/test_sensitivity_eval_non_suppoerted_output.py +++ b/tests/keras_tests/function_tests/test_sensitivity_eval_non_suppoerted_output.py @@ -67,8 +67,10 @@ def nms_output_model(input_shape): model = keras.Model(inputs=inputs, outputs=outputs) return model + def representative_dataset(): - yield [np.random.randn(1, 8, 8, 3).astype(np.float32)] + for _ in range(2): + yield [np.random.randn(1, 8, 8, 3).astype(np.float32)] class TestSensitivityEvalWithNonSupportedOutputNodes(unittest.TestCase): @@ -83,12 +85,14 @@ def verify_test_for_model(self, model): input_shape=(1, 8, 8, 3), mixed_precision_enabled=True) - hessian_info_service = hess.HessianInfoService(graph=graph, - representative_dataset=representative_dataset, + hessian_info_service = hess.HessianInfoService(graph=graph, representative_dataset_gen=representative_dataset, fw_impl=keras_impl) + # Reducing the default number of samples for Mixed precision Hessian approximation + # to allow quick execution of the test se = 
keras_impl.get_sensitivity_evaluator(graph, - MixedPrecisionQuantizationConfig(use_hessian_based_scores=True), + MixedPrecisionQuantizationConfig(use_hessian_based_scores=True, + num_of_images=2), representative_dataset, DEFAULT_KERAS_INFO, hessian_info_service=hessian_info_service) diff --git a/tests/keras_tests/pruning_tests/feature_networks/networks_tests/conv2d_conv2dtranspose_pruning_test.py b/tests/keras_tests/pruning_tests/feature_networks/networks_tests/conv2d_conv2dtranspose_pruning_test.py index 3bd51ddb3..fbd92e8b6 100644 --- a/tests/keras_tests/pruning_tests/feature_networks/networks_tests/conv2d_conv2dtranspose_pruning_test.py +++ b/tests/keras_tests/pruning_tests/feature_networks/networks_tests/conv2d_conv2dtranspose_pruning_test.py @@ -65,7 +65,8 @@ def get_tpc(self): def get_pruning_config(self): if self.use_constant_importance_metric: add_const_importance_metric(first_num_oc=6, second_num_oc=4, simd=self.simd) - return mct.pruning.PruningConfig(importance_metric=ConstImportanceMetric.CONST) + return mct.pruning.PruningConfig(importance_metric=ConstImportanceMetric.CONST, + num_score_approximations=super().get_pruning_config().num_score_approximations) return super().get_pruning_config() def get_resource_utilization(self): # Remove only one group of channels only one parameter should be pruned diff --git a/tests/keras_tests/pruning_tests/feature_networks/networks_tests/conv2dtranspose_pruning_test.py b/tests/keras_tests/pruning_tests/feature_networks/networks_tests/conv2dtranspose_pruning_test.py index aaf1ccf3a..4ec74bf08 100644 --- a/tests/keras_tests/pruning_tests/feature_networks/networks_tests/conv2dtranspose_pruning_test.py +++ b/tests/keras_tests/pruning_tests/feature_networks/networks_tests/conv2dtranspose_pruning_test.py @@ -67,7 +67,8 @@ def get_tpc(self): def get_pruning_config(self): if self.use_constant_importance_metric: add_const_importance_metric(first_num_oc=6, second_num_oc=4, simd=self.simd) - return 
mct.pruning.PruningConfig(importance_metric=ConstImportanceMetric.CONST) + return mct.pruning.PruningConfig(importance_metric=ConstImportanceMetric.CONST, + num_score_approximations=super().get_pruning_config().num_score_approximations) return super().get_pruning_config() def get_resource_utilization(self): diff --git a/tests/keras_tests/pruning_tests/feature_networks/pruning_keras_feature_test.py b/tests/keras_tests/pruning_tests/feature_networks/pruning_keras_feature_test.py index ee4149ec0..ab4d1c1f3 100644 --- a/tests/keras_tests/pruning_tests/feature_networks/pruning_keras_feature_test.py +++ b/tests/keras_tests/pruning_tests/feature_networks/pruning_keras_feature_test.py @@ -20,7 +20,7 @@ class PruningKerasFeatureTest(BaseKerasFeatureNetworkTest): def __init__(self, unit_test, - num_calibration_iter=1, + num_calibration_iter=2, val_batch_size=1, num_of_inputs=1, input_shape=(8, 8, 3)): diff --git a/tests/pytorch_tests/function_tests/get_gptq_config_test.py b/tests/pytorch_tests/function_tests/get_gptq_config_test.py index 05650ef42..0b8aa03d1 100644 --- a/tests/pytorch_tests/function_tests/get_gptq_config_test.py +++ b/tests/pytorch_tests/function_tests/get_gptq_config_test.py @@ -49,7 +49,7 @@ def forward(self, inp): return x -def random_datagen_experimental(): +def random_datagen(): for _ in range(20): yield [np.random.random((1, 3, 8, 8))] @@ -69,20 +69,24 @@ def run_test(self): weights_bias_correction=False) # disable bias correction when working with GPTQ cc = CoreConfig(quantization_config=qc) - gptqv2_config = get_pytorch_gptq_config(n_epochs=1, - optimizer=torch.optim.Adam([torch.Tensor([])], lr=1e-4), - regularization_factor=0.001) - gptqv2_config.rounding_type = self.rounding_type - gptqv2_config.train_bias = self.train_bias + gptq_config = get_pytorch_gptq_config(n_epochs=1, + optimizer=torch.optim.Adam([torch.Tensor([])], lr=1e-4), + regularization_factor=0.001) + + # Decreasing the default number of samples for GPTQ Hessian approximation to allow 
quick execution of the test + gptq_config.hessian_weights_config.hessians_num_samples = 2 + + gptq_config.rounding_type = self.rounding_type + gptq_config.train_bias = self.train_bias if self.rounding_type == RoundingType.SoftQuantizer: - gptqv2_config.gptq_quantizer_params_override = \ + gptq_config.gptq_quantizer_params_override = \ {QUANT_PARAM_LEARNING_STR: self.quantization_parameters_learning} elif self.rounding_type == RoundingType.STE: - gptqv2_config.gptq_quantizer_params_override = \ + gptq_config.gptq_quantizer_params_override = \ {MAX_LSB_STR: DefaultDict(default_value=1)} else: - gptqv2_config.gptq_quantizer_params_override = None + gptq_config.gptq_quantizer_params_override = None tp = generate_test_tp_model({'weights_quantization_method': self.quantization_method}) symmetric_weights_tpc = generate_pytorch_tpc(name="gptq_config_test", tp_model=tp) @@ -90,7 +94,7 @@ def run_test(self): float_model = TestModel() quant_model, _ = pytorch_gradient_post_training_quantization(model=float_model, - representative_data_gen=random_datagen_experimental, + representative_data_gen=random_datagen, core_config=cc, - gptq_config=gptqv2_config, + gptq_config=gptq_config, target_platform_capabilities=symmetric_weights_tpc) diff --git a/tests/pytorch_tests/function_tests/test_function_runner.py b/tests/pytorch_tests/function_tests/test_function_runner.py index 1d2fb2715..c88a3d57c 100644 --- a/tests/pytorch_tests/function_tests/test_function_runner.py +++ b/tests/pytorch_tests/function_tests/test_function_runner.py @@ -27,6 +27,10 @@ from tests.pytorch_tests.function_tests.set_device_test import SetDeviceTest from tests.pytorch_tests.function_tests.set_layer_to_bitwidth_test import TestSetLayerToBitwidthWeights, \ TestSetLayerToBitwidthActivation +from tests.pytorch_tests.function_tests.test_hessian_service import FetchActivationHessianTest, FetchWeightsHessianTest, \ + FetchHessianNotEnoughSamplesThrowTest, FetchHessianNotEnoughSamplesSmallBatchThrowTest, \ + 
FetchComputeBatchLargerThanReprBatchTest, FetchHessianRequiredZeroTest, FetchHessianMultipleNodesTest, \ + DoubleFetchHessianTest from tests.pytorch_tests.function_tests.test_sensitivity_eval_non_supported_output import \ TestSensitivityEvalWithArgmaxNode from tests.pytorch_tests.function_tests.test_hessian_info_calculator import WeightsHessianTraceBasicModelTest, \ @@ -34,7 +38,7 @@ WeightsHessianTraceMultipleOutputsModelTest, WeightsHessianTraceReuseModelTest, \ ActivationHessianTraceBasicModelTest, ActivationHessianTraceAdvanceModelTest, \ ActivationHessianTraceMultipleOutputsModelTest, ActivationHessianTraceReuseModelTest, \ - ActivationHessianOutputExceptionTest + ActivationHessianOutputExceptionTest, ActivationHessianTraceMultipleInputsModelTest class FunctionTestRunner(unittest.TestCase): @@ -111,6 +115,7 @@ def test_activation_hessian_trace(self): ActivationHessianTraceMultipleOutputsModelTest(self).run_test() ActivationHessianTraceReuseModelTest(self).run_test() ActivationHessianOutputExceptionTest(self).run_test() + ActivationHessianTraceMultipleInputsModelTest(self).run_test() def test_weights_hessian_trace(self): """ @@ -121,6 +126,19 @@ def test_weights_hessian_trace(self): WeightsHessianTraceMultipleOutputsModelTest(self).run_test() WeightsHessianTraceReuseModelTest(self).run_test() + def test_hessian_service(self): + """ + This test checks the Hessian service features with pytorch computation workflow. + """ + FetchActivationHessianTest(self).run_test() + FetchWeightsHessianTest(self).run_test() + FetchHessianNotEnoughSamplesThrowTest(self).run_test() + FetchHessianNotEnoughSamplesSmallBatchThrowTest(self).run_test() + FetchComputeBatchLargerThanReprBatchTest(self).run_test() + FetchHessianRequiredZeroTest(self).run_test() + FetchHessianMultipleNodesTest(self).run_test() + DoubleFetchHessianTest(self).run_test() + def test_layer_fusing(self): """ This test checks the Fusion mechanism in Pytorch. 
diff --git a/tests/pytorch_tests/function_tests/test_hessian_info_calculator.py b/tests/pytorch_tests/function_tests/test_hessian_info_calculator.py index ee643a2a9..61c41f058 100644 --- a/tests/pytorch_tests/function_tests/test_hessian_info_calculator.py +++ b/tests/pytorch_tests/function_tests/test_hessian_info_calculator.py @@ -16,6 +16,7 @@ import torch from torch.nn import Conv2d, BatchNorm2d, ReLU, Linear, Hardswish +from model_compression_toolkit.core.pytorch.constants import KERNEL from model_compression_toolkit.core.pytorch.utils import to_torch_tensor import numpy as np @@ -91,6 +92,18 @@ def forward(self, inp): return x1, x2, x3 +class multiple_inputs_model(torch.nn.Module): + def __init__(self): + super(multiple_inputs_model, self).__init__() + self.conv1 = Conv2d(3, 3, kernel_size=1, stride=1) + self.conv2 = Conv2d(3, 3, kernel_size=1, stride=1) + + def forward(self, inp1, inp2): + x1 = self.conv1(inp1) + x2 = self.conv2(inp2) + return x1 + x2 + + class reused_model(torch.nn.Module): def __init__(self): super(reused_model, self).__init__() @@ -111,30 +124,25 @@ def forward(self, inp): def generate_inputs(inputs_shape): - inputs = [] - for in_shape in inputs_shape: - t = torch.randn(*in_shape) - t.requires_grad_() - inputs.append(t) - inputs = to_torch_tensor(inputs) - return inputs + return [1 + np.random.random(in_shape) for in_shape in inputs_shape] def get_expected_shape(t_shape, granularity): if granularity == hessian_common.HessianInfoGranularity.PER_ELEMENT: - return t_shape + return (1, *t_shape) elif granularity == hessian_common.HessianInfoGranularity.PER_TENSOR: - return (1,) + return (1, 1) else: - return (t_shape[0],) + return (1, t_shape[0]) class BaseHessianTraceBasicModelTest(BasePytorchTest): - def __init__(self, unit_test, model): + def __init__(self, unit_test, model, n_iters=2): super().__init__(unit_test) self.val_batch_size = 1 self.model = model + self.n_iters = n_iters def create_inputs_shape(self): return [[self.val_batch_size, 
3, 8, 8]] @@ -145,25 +153,31 @@ def generate_inputs(input_shapes): def representative_data_gen(self): input_shapes = self.create_inputs_shape() - yield self.generate_inputs(input_shapes) + for _ in range(self.n_iters): + yield self.generate_inputs(input_shapes) def test_hessian_trace_approx(self, hessian_service, interest_point, mode, granularity=hessian_common.HessianInfoGranularity.PER_OUTPUT_CHANNEL, - num_scores=1): + num_scores=1, + batch_size=1): request = hessian_common.TraceHessianRequest(mode=mode, granularity=granularity, - target_node=interest_point) - expected_shape = get_expected_shape(interest_point.weights['weight'].shape, granularity) - info = hessian_service.fetch_hessian(request, num_scores) - assert len(info) == num_scores, f"fetched {num_scores} score but {len(info)} scores were fetched" - score = np.mean(np.stack(info), axis=0) + target_nodes=[interest_point]) + expected_shape = get_expected_shape(interest_point.weights[KERNEL].shape, granularity) + info = hessian_service.fetch_hessian(request, num_scores, batch_size=batch_size) + # The call for fetch_hessian returns the requested number of scores for each target node. + # Since in this test we request computation for a single node, we need to extract its results from the list. 
self.unit_test.assertTrue(isinstance(info, list)) + + info = info[0] self.unit_test.assertTrue(len(info) == num_scores, - f"fetched {num_scores} score but {len(info)} scores were fetched") + f"Requested {num_scores} score but {len(info)} scores were fetched") + + score = np.mean(np.stack(info), axis=0, keepdims=True) self.unit_test.assertTrue(score.shape == expected_shape, f"Tensor shape is expected to be {expected_shape} but has shape {score.shape}") @@ -184,7 +198,7 @@ def __init__(self, unit_test): def run_test(self, seed=0): graph, pytorch_impl = self._setup() hessian_service = hessian_common.HessianInfoService(graph=graph, - representative_dataset=self.representative_data_gen, + representative_dataset_gen=self.representative_data_gen, fw_impl=pytorch_impl) ipts = [n for n in graph.get_topo_sorted_nodes() if len(n.weights) > 0] for ipt in ipts: @@ -204,13 +218,13 @@ def run_test(self, seed=0): class WeightsHessianTraceAdvanceModelTest(BaseHessianTraceBasicModelTest): def __init__(self, unit_test): - super().__init__(unit_test, model=advanced_model) + super().__init__(unit_test, model=advanced_model, n_iters=3) self.val_batch_size = 2 def run_test(self, seed=0): graph, pytorch_impl = self._setup() hessian_service = hessian_common.HessianInfoService(graph=graph, - representative_dataset=self.representative_data_gen, + representative_dataset_gen=self.representative_data_gen, fw_impl=pytorch_impl) ipts = [n for n in graph.get_topo_sorted_nodes() if len(n.weights) > 0] for ipt in ipts: @@ -233,13 +247,13 @@ def run_test(self, seed=0): class WeightsHessianTraceMultipleOutputsModelTest(BaseHessianTraceBasicModelTest): def __init__(self, unit_test): - super().__init__(unit_test, model=multiple_outputs_model) + super().__init__(unit_test, model=multiple_outputs_model, n_iters=3) self.val_batch_size = 1 def run_test(self, seed=0): graph, pytorch_impl = self._setup() hessian_service = hessian_common.HessianInfoService(graph=graph, - 
representative_dataset=self.representative_data_gen, + representative_dataset_gen=self.representative_data_gen, fw_impl=pytorch_impl) ipts = [n for n in graph.get_topo_sorted_nodes() if len(n.weights) > 0] for ipt in ipts: @@ -262,13 +276,13 @@ def run_test(self, seed=0): class WeightsHessianTraceReuseModelTest(BaseHessianTraceBasicModelTest): def __init__(self, unit_test): - super().__init__(unit_test, model=reused_model) + super().__init__(unit_test, model=reused_model, n_iters=3) self.val_batch_size = 1 def run_test(self, seed=0): graph, pytorch_impl = self._setup() hessian_service = hessian_common.HessianInfoService(graph=graph, - representative_dataset=self.representative_data_gen, + representative_dataset_gen=self.representative_data_gen, fw_impl=pytorch_impl) ipts = [n for n in graph.get_topo_sorted_nodes() if len(n.weights) > 0] for ipt in ipts: @@ -297,7 +311,7 @@ def __init__(self, unit_test): def run_test(self, seed=0): graph, pytorch_impl = self._setup() hessian_service = hessian_common.HessianInfoService(graph=graph, - representative_dataset=self.representative_data_gen, + representative_dataset_gen=self.representative_data_gen, fw_impl=pytorch_impl) ipts = [n for n in graph.get_topo_sorted_nodes() if len(n.weights) > 0] for ipt in ipts: @@ -315,7 +329,7 @@ def __init__(self, unit_test): def run_test(self, seed=0): graph, pytorch_impl = self._setup() hessian_service = hessian_common.HessianInfoService(graph=graph, - representative_dataset=self.representative_data_gen, + representative_dataset_gen=self.representative_data_gen, fw_impl=pytorch_impl) # removing last layer cause we do not allow activation Hessian computation for the output layer @@ -324,6 +338,7 @@ def run_test(self, seed=0): self.test_hessian_trace_approx(hessian_service, interest_point=ipt, num_scores=2, + batch_size=2, granularity=hessian_common.HessianInfoGranularity.PER_TENSOR, mode=hessian_common.HessianMode.ACTIVATION) @@ -336,7 +351,7 @@ def __init__(self, unit_test): def 
run_test(self, seed=0): graph, pytorch_impl = self._setup() hessian_service = hessian_common.HessianInfoService(graph=graph, - representative_dataset=self.representative_data_gen, + representative_dataset_gen=self.representative_data_gen, fw_impl=pytorch_impl) # removing last layer cause we do not allow activation Hessian computation for the output layer @@ -357,7 +372,7 @@ def __init__(self, unit_test): def run_test(self, seed=0): graph, pytorch_impl = self._setup() hessian_service = hessian_common.HessianInfoService(graph=graph, - representative_dataset=self.representative_data_gen, + representative_dataset_gen=self.representative_data_gen, fw_impl=pytorch_impl) ipts = [n for n in graph.get_topo_sorted_nodes() if len(n.weights) > 0] @@ -368,6 +383,7 @@ def run_test(self, seed=0): granularity=hessian_common.HessianInfoGranularity.PER_TENSOR, mode=hessian_common.HessianMode.ACTIVATION) + class ActivationHessianOutputExceptionTest(BaseHessianTraceBasicModelTest): def __init__(self, unit_test): super().__init__(unit_test, model=basic_model) @@ -376,14 +392,37 @@ def __init__(self, unit_test): def run_test(self, seed=0): graph, pytorch_impl = self._setup() hessian_service = hessian_common.HessianInfoService(graph=graph, - representative_dataset=self.representative_data_gen, + representative_dataset_gen=self.representative_data_gen, fw_impl=pytorch_impl) with self.unit_test.assertRaises(Exception) as e: request = hessian_common.TraceHessianRequest(mode=hessian_common.HessianMode.ACTIVATION, granularity=hessian_common.HessianInfoGranularity.PER_TENSOR, - target_node=graph.get_outputs()[0].node) + target_nodes=[graph.get_outputs()[0].node]) _ = hessian_service.fetch_hessian(request, required_size=1) self.unit_test.assertTrue("Activation Hessian approximation cannot be computed for model outputs" - in str(e.exception)) \ No newline at end of file + in str(e.exception)) + + +class ActivationHessianTraceMultipleInputsModelTest(BaseHessianTraceBasicModelTest): + def 
__init__(self, unit_test): + super().__init__(unit_test, model=multiple_inputs_model) + self.val_batch_size = 3 + + def create_inputs_shape(self): + return [[self.val_batch_size, 3, 8, 8], [self.val_batch_size, 3, 8, 8]] + + def run_test(self, seed=0): + graph, pytorch_impl = self._setup() + hessian_service = hessian_common.HessianInfoService(graph=graph, + representative_dataset_gen=self.representative_data_gen, + fw_impl=pytorch_impl) + ipts = [n for n in graph.get_topo_sorted_nodes() if len(n.weights) > 0] + for ipt in ipts: + self.test_hessian_trace_approx(hessian_service, + interest_point=ipt, + granularity=hessian_common.HessianInfoGranularity.PER_TENSOR, + mode=hessian_common.HessianMode.ACTIVATION, + num_scores=3, + batch_size=2) diff --git a/tests/pytorch_tests/function_tests/test_hessian_service.py b/tests/pytorch_tests/function_tests/test_hessian_service.py new file mode 100644 index 000000000..148a16ab7 --- /dev/null +++ b/tests/pytorch_tests/function_tests/test_hessian_service.py @@ -0,0 +1,297 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +import unittest + +from torch import nn +import numpy as np + +from model_compression_toolkit.core.common.hessian import HessianInfoService, TraceHessianRequest, HessianMode, \ + HessianInfoGranularity +from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO +from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation +from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import generate_pytorch_tpc +from tests.common_tests.helpers.prep_graph_for_func_test import prepare_graph_with_configs +from tests.pytorch_tests.model_tests.base_pytorch_test import BasePytorchTest + + +class BasicModel(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(3, 3, kernel_size=(3, 3)) + self.bn = nn.BatchNorm2d(3) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + y = self.relu(x) + return y + + +class MultipleActNodesModel(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(3, 3, kernel_size=(3, 3)) + self.conv2 = nn.Conv2d(3, 3, kernel_size=(3, 3)) + self.bn = nn.BatchNorm2d(3) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.conv1(x) + x = self.bn(x) + x = self.relu(x) + y = self.conv2(x) + return y + + +def representative_dataset(): + for _ in range(2): + yield [np.random.randn(2, 3, 8, 8).astype(np.float32)] + + +class BaseHessianServiceTest(BasePytorchTest): + + def __init__(self, unit_test, model, compute_hessian=True, run_verification=True): + super().__init__(unit_test) + + self.float_model = model() + self.request = None + self.num_scores = None + self.num_nodes = None + self.graph = None + self.pytorch_impl = PytorchImplementation() + self.run_verification = run_verification + self.compute_hessian = compute_hessian + + def verify_hessian(self): + 
self.unit_test.assertEqual(len(self.hessian), self.num_nodes, f"Expecting returned Hessian list to include " + f"{self.num_nodes} list of approximation.") + + for i in range(len(self.request.target_nodes)): + self.unit_test.assertEqual(len(self.hessian[i]), self.num_scores, + f"Expecting {self.num_scores} Hessian scores.") + + def run_test(self, seed=0): + # This is just an internal assertion for the test setup + assert (self.request is not None and self.num_scores is not None and self.num_nodes is not None + and self.graph is not None), "Test parameters are not initialized." + + self.hessian_service = HessianInfoService(graph=self.graph, representative_dataset_gen=representative_dataset, + fw_impl=self.pytorch_impl) + + self.unit_test.assertEqual(self.hessian_service.graph, self.graph) + self.unit_test.assertEqual(self.hessian_service.fw_impl, self.pytorch_impl) + + if self.compute_hessian: + self.hessian = self.hessian_service.fetch_hessian(self.request, self.num_scores) + + if self.run_verification: + self.verify_hessian() + + +class FetchActivationHessianTest(BaseHessianServiceTest): + def __init__(self, unit_test): + super().__init__(unit_test, model=BasicModel) + + self.num_nodes = 1 + self.num_scores = 2 + + def run_test(self, seed=0): + self.graph = prepare_graph_with_configs(self.float_model, + self.pytorch_impl, + DEFAULT_PYTORCH_INFO, + representative_dataset, + generate_pytorch_tpc) + + self.request = TraceHessianRequest(mode=HessianMode.ACTIVATION, + granularity=HessianInfoGranularity.PER_TENSOR, + target_nodes=[list(self.graph.get_topo_sorted_nodes())[0]]) + + super().run_test() + + +class FetchWeightsHessianTest(BaseHessianServiceTest): + def __init__(self, unit_test): + super().__init__(unit_test, model=BasicModel) + + self.num_nodes = 1 + self.num_scores = 2 + + def run_test(self, seed=0): + self.graph = prepare_graph_with_configs(self.float_model, + self.pytorch_impl, + DEFAULT_PYTORCH_INFO, + representative_dataset, + generate_pytorch_tpc) + + 
self.request = TraceHessianRequest(mode=HessianMode.WEIGHTS, + granularity=HessianInfoGranularity.PER_OUTPUT_CHANNEL, + target_nodes=[list(self.graph.get_topo_sorted_nodes())[1]]) + + super().run_test() + + +class FetchHessianNotEnoughSamplesThrowTest(BaseHessianServiceTest): + def __init__(self, unit_test): + super().__init__(unit_test, model=BasicModel, compute_hessian=False, run_verification=False) + + self.num_nodes = 1 + self.num_scores = 5 + + def run_test(self, seed=0): + self.graph = prepare_graph_with_configs(self.float_model, + self.pytorch_impl, + DEFAULT_PYTORCH_INFO, + representative_dataset, + generate_pytorch_tpc) + + self.request = TraceHessianRequest(mode=HessianMode.ACTIVATION, + granularity=HessianInfoGranularity.PER_TENSOR, + target_nodes=[list(self.graph.get_topo_sorted_nodes())[0]]) + + super().run_test() + + with self.unit_test.assertRaises(Exception) as e: + hessian = self.hessian_service.fetch_hessian(self.request, self.num_scores, batch_size=2) # representative dataset produces 4 images total + + self.unit_test.assertTrue('Not enough samples in the provided representative dataset' in str(e.exception)) + + +class FetchHessianNotEnoughSamplesSmallBatchThrowTest(BaseHessianServiceTest): + def __init__(self, unit_test): + super().__init__(unit_test, model=BasicModel, compute_hessian=False, run_verification=False) + + self.num_nodes = 1 + self.num_scores = 5 + + def run_test(self, seed=0): + self.graph = prepare_graph_with_configs(self.float_model, + self.pytorch_impl, + DEFAULT_PYTORCH_INFO, + representative_dataset, + generate_pytorch_tpc) + + self.request = TraceHessianRequest(mode=HessianMode.ACTIVATION, + granularity=HessianInfoGranularity.PER_TENSOR, + target_nodes=[list(self.graph.get_topo_sorted_nodes())[0]]) + + super().run_test() + + with self.unit_test.assertRaises(Exception) as e: + hessian = self.hessian_service.fetch_hessian(self.request, self.num_scores, + batch_size=1) # representative dataset produces 4 images total + + 
self.unit_test.assertTrue('Not enough samples in the provided representative dataset' in str(e.exception)) + + +class FetchComputeBatchLargerThanReprBatchTest(BaseHessianServiceTest): + def __init__(self, unit_test): + super().__init__(unit_test, model=BasicModel, compute_hessian=False, run_verification=False) + + self.num_nodes = 1 + self.num_scores = 3 + + def run_test(self, seed=0): + self.graph = prepare_graph_with_configs(self.float_model, + self.pytorch_impl, + DEFAULT_PYTORCH_INFO, + representative_dataset, + generate_pytorch_tpc) + + self.request = TraceHessianRequest(mode=HessianMode.ACTIVATION, + granularity=HessianInfoGranularity.PER_TENSOR, + target_nodes=[list(self.graph.get_topo_sorted_nodes())[0]]) + + super().run_test() + self.hessian = self.hessian_service.fetch_hessian(self.request, 3, batch_size=3) # representative batch size is 2 + super().verify_hessian() + + +class FetchHessianRequiredZeroTest(BaseHessianServiceTest): + def __init__(self, unit_test): + super().__init__(unit_test, model=BasicModel) + + self.num_nodes = 1 + self.num_scores = 0 + + def run_test(self, seed=0): + self.graph = prepare_graph_with_configs(self.float_model, + self.pytorch_impl, + DEFAULT_PYTORCH_INFO, + representative_dataset, + generate_pytorch_tpc) + + self.request = TraceHessianRequest(mode=HessianMode.ACTIVATION, + granularity=HessianInfoGranularity.PER_TENSOR, + target_nodes=[list(self.graph.get_topo_sorted_nodes())[0]]) + + super().run_test() + + +class FetchHessianMultipleNodesTest(BaseHessianServiceTest): + def __init__(self, unit_test): + super().__init__(unit_test, model=MultipleActNodesModel) + + self.num_nodes = 2 + self.num_scores = 2 + + def run_test(self, seed=0): + self.graph = prepare_graph_with_configs(self.float_model, + self.pytorch_impl, + DEFAULT_PYTORCH_INFO, + representative_dataset, + generate_pytorch_tpc) + + nodes = list(self.graph.get_topo_sorted_nodes()) + self.request = TraceHessianRequest(mode=HessianMode.ACTIVATION, + 
granularity=HessianInfoGranularity.PER_TENSOR, + target_nodes=[nodes[0], nodes[2]]) + + super().run_test() + + +class DoubleFetchHessianTest(BaseHessianServiceTest): + def __init__(self, unit_test): + super().__init__(unit_test, model=MultipleActNodesModel, compute_hessian=False, run_verification=False) + + self.num_nodes = 2 + self.num_scores = 2 + + def run_test(self, seed=0): + self.graph = prepare_graph_with_configs(self.float_model, + self.pytorch_impl, + DEFAULT_PYTORCH_INFO, + representative_dataset, + generate_pytorch_tpc) + + target_node = list(self.graph.get_topo_sorted_nodes())[0] + self.request = TraceHessianRequest(mode=HessianMode.ACTIVATION, + granularity=HessianInfoGranularity.PER_TENSOR, + target_nodes=[target_node]) + + super().run_test() + + hessian = self.hessian_service.fetch_hessian(self.request, 2) + self.unit_test.assertEqual(len(hessian), 1, "Expecting returned Hessian list to include one list of " + "approximation, for the single target node.") + self.unit_test.assertEqual(len(hessian[0]), 2, "Expecting 2 Hessian scores.") + self.unit_test.assertEqual(self.hessian_service.count_saved_info_of_request(self.request)[target_node], 2) + + hessian = self.hessian_service.fetch_hessian(self.request, 2) + self.unit_test.assertEqual(len(hessian), 1, "Expecting returned Hessian list to include one list of " + "approximation, for the single target node.") + self.unit_test.assertEqual(len(hessian[0]), 2, "Expecting 2 Hessian scores.") + self.unit_test.assertEqual(self.hessian_service.count_saved_info_of_request(self.request)[target_node], 2) diff --git a/tests/pytorch_tests/function_tests/test_sensitivity_eval_non_supported_output.py b/tests/pytorch_tests/function_tests/test_sensitivity_eval_non_supported_output.py index 4e35a60bc..e42f28308 100644 --- a/tests/pytorch_tests/function_tests/test_sensitivity_eval_non_supported_output.py +++ b/tests/pytorch_tests/function_tests/test_sensitivity_eval_non_supported_output.py @@ -13,7 +13,9 @@ # limitations 
under the License. # ============================================================================== import torch +import numpy as np +from model_compression_toolkit.constants import MP_DEFAULT_NUM_SAMPLES from model_compression_toolkit.core import MixedPrecisionQuantizationConfig from model_compression_toolkit.core.common.hessian import HessianInfoService from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO @@ -44,7 +46,10 @@ class TestSensitivityEvalWithNonSupportedOutputBase(BasePytorchTest): def create_inputs_shape(self): return [[1, 3, 16, 16]] - def representative_data_gen(self, n_iters=1): + def generate_inputs(self, input_shapes): + return [np.random.randn(*in_shape) for in_shape in input_shapes] + + def representative_data_gen(self, n_iters=MP_DEFAULT_NUM_SAMPLES): input_shapes = self.create_inputs_shape() for _ in range(n_iters): yield self.generate_inputs(input_shapes) @@ -61,9 +66,8 @@ def verify_test_for_model(self, model): generate_pytorch_tpc, input_shape=(1, 3, 16, 16), mixed_precision_enabled=True) - hessian_info_service = HessianInfoService(graph=graph, - fw_impl=pytorch_impl, - representative_dataset=self.representative_data_gen) + hessian_info_service = HessianInfoService(graph=graph, representative_dataset_gen=self.representative_data_gen, + fw_impl=pytorch_impl) se = pytorch_impl.get_sensitivity_evaluator(graph, MixedPrecisionQuantizationConfig(use_hessian_based_scores=True), diff --git a/tests/pytorch_tests/model_tests/feature_models/gptq_test.py b/tests/pytorch_tests/model_tests/feature_models/gptq_test.py index d6e18befc..7561d7afa 100644 --- a/tests/pytorch_tests/model_tests/feature_models/gptq_test.py +++ b/tests/pytorch_tests/model_tests/feature_models/gptq_test.py @@ -20,6 +20,7 @@ import mct_quantizers from model_compression_toolkit import DefaultDict +from model_compression_toolkit.constants import GPTQ_HESSIAN_NUM_SAMPLES from 
model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod from model_compression_toolkit.gptq.common.gptq_constants import QUANT_PARAM_LEARNING_STR, MAX_LSB_STR from tests.pytorch_tests.model_tests.base_pytorch_feature_test import BasePytorchFeatureNetworkTest @@ -55,8 +56,9 @@ def forward(self, inp): class GPTQBaseTest(BasePytorchFeatureNetworkTest): def __init__(self, unit_test, weights_bits=8, weights_quant_method=QuantizationMethod.SYMMETRIC, rounding_type=RoundingType.STE, per_channel=True, - hessian_weights=True, log_norm_weights=True, scaled_log_norm=False, params_learning=True): - super().__init__(unit_test, input_shape=(3, 16, 16)) + hessian_weights=True, log_norm_weights=True, scaled_log_norm=False, params_learning=True, + num_calibration_iter=GPTQ_HESSIAN_NUM_SAMPLES): + super().__init__(unit_test, input_shape=(3, 16, 16), num_calibration_iter=num_calibration_iter) self.seed = 0 self.rounding_type = rounding_type self.weights_bits = weights_bits @@ -125,16 +127,6 @@ def get_gptq_config(self): scale_log_norm=self.scaled_log_norm), gptq_quantizer_params_override=self.override_params) - def get_gptq_config(self): - return GradientPTQConfig(5, optimizer=torch.optim.Adam([torch.Tensor([])], lr=1e-4), - optimizer_rest=torch.optim.Adam([torch.Tensor([])], lr=1e-4), - loss=multiple_tensors_mse_loss, train_bias=True, rounding_type=self.rounding_type, - use_hessian_based_weights=self.hessian_weights, - optimizer_bias=torch.optim.Adam([torch.Tensor([])], lr=0.4), - hessian_weights_config=GPTQHessianScoresConfig(log_norm=self.log_norm_weights, - scale_log_norm=self.scaled_log_norm), - gptq_quantizer_params_override=self.override_params) - def gptq_compare(self, ptq_model, gptq_model, input_x=None): ptq_weights = torch_tensor_to_numpy(list(ptq_model.parameters())) gptq_weights = torch_tensor_to_numpy(list(gptq_model.parameters())) @@ -144,12 +136,6 @@ def gptq_compare(self, ptq_model, gptq_model, input_x=None): class 
GPTQWeightsUpdateTest(GPTQBaseTest): - def get_gptq_config(self): - return GradientPTQConfig(50, optimizer=torch.optim.Adam([torch.Tensor([])], lr=0.5), - optimizer_rest=torch.optim.Adam([torch.Tensor([])], lr=0.5), - loss=multiple_tensors_mse_loss, train_bias=True, rounding_type=self.rounding_type, - gptq_quantizer_params_override=self.override_params) - def get_gptq_config(self): return GradientPTQConfig(50, optimizer=torch.optim.Adam([torch.Tensor([])], lr=0.5), optimizer_rest=torch.optim.Adam([torch.Tensor([])], lr=0.5), @@ -171,12 +157,6 @@ def compare(self, ptq_model, gptq_model, input_x=None, max_change=None): class GPTQLearnRateZeroTest(GPTQBaseTest): - def get_gptq_config(self): - return GradientPTQConfig(5, optimizer=torch.optim.Adam([torch.Tensor([])], lr=0), - optimizer_rest=torch.optim.Adam([torch.Tensor([])], lr=0), - loss=multiple_tensors_mse_loss, train_bias=False, rounding_type=self.rounding_type, - gptq_quantizer_params_override=self.override_params) - def get_gptq_config(self): return GradientPTQConfig(5, optimizer=torch.optim.Adam([torch.Tensor([])], lr=0), optimizer_rest=torch.optim.Adam([torch.Tensor([])], lr=0), diff --git a/tests/pytorch_tests/pruning_tests/feature_networks/pruning_pytorch_feature_test.py b/tests/pytorch_tests/pruning_tests/feature_networks/pruning_pytorch_feature_test.py index 90ded1040..657995f1e 100644 --- a/tests/pytorch_tests/pruning_tests/feature_networks/pruning_pytorch_feature_test.py +++ b/tests/pytorch_tests/pruning_tests/feature_networks/pruning_pytorch_feature_test.py @@ -29,7 +29,7 @@ class PruningPytorchFeatureTest(BasePytorchFeatureNetworkTest): def __init__(self, unit_test, - num_calibration_iter=1, + num_calibration_iter=2, val_batch_size=1, num_of_inputs=1, input_shape=(3, 8, 8)):