diff --git a/model_compression_toolkit/core/common/framework_implementation.py b/model_compression_toolkit/core/common/framework_implementation.py
index 119bd242b..ee68f5ab5 100644
--- a/model_compression_toolkit/core/common/framework_implementation.py
+++ b/model_compression_toolkit/core/common/framework_implementation.py
@@ -66,6 +66,20 @@ def get_trace_hessian_calculator(self,
         """
         raise NotImplemented(f'{self.__class__.__name__} have to implement the '
                              f'framework\'s get_trace_hessian_calculator method.')  # pragma: no cover
+
+    @abstractmethod
+    def sample_single_representative_dataset(self, representative_dataset: Callable):
+        """
+        Get a single sample (namely, batch size of 1) from a representative dataset.
+
+        Args:
+            representative_dataset: Callable which returns the representative dataset at any batch size.
+
+        Returns: List of inputs from representative_dataset where each sample has a batch size of 1.
+        """
+        raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
+                                  f'framework\'s sample_single_representative_dataset method.')  # pragma: no cover
+
     @abstractmethod
     def to_numpy(self, tensor: Any) -> np.ndarray:
         """
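Keras and PyTorch implementations of this hook appear later in the diff. For orientation, here is a minimal framework-agnostic sketch of the expected contract (illustrative only; the NumPy backend and the standalone signature are assumptions, not part of the patch):

```python
# Illustrative sketch, not part of the patch: a concrete implementation draws
# one batch from the dataset generator and slices every input down to batch size 1.
import numpy as np
from typing import Callable, List

def sample_single_representative_dataset(representative_dataset: Callable) -> List[np.ndarray]:
    images = next(representative_dataset())  # one tensor per model input
    if not isinstance(images, list):
        raise TypeError(f'Expected a list of inputs, got {type(images)}')
    # Keep only the first sample of each input, preserving the batch axis
    return [image[:1] if image.shape[0] != 1 else image for image in images]
```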
""" - Logger.info(f"Computing Hessian-trace approximation for a sample.") + Logger.debug(f"Computing Hessian-trace approximation for a node {trace_hessian_request.target_node}.") # Sample images for the computation - images = self._sample_single_image_per_input() + images = self.representative_dataset() # Get the framework-specific calculator for trace Hessian approximation fw_hessian_calculator = self.fw_impl.get_trace_hessian_calculator(graph=self.graph, @@ -137,6 +129,8 @@ def fetch_hessian(self, The inner list length dependent on the granularity (1 for per-tensor, OC for per-output-channel when the requested node has OC output-channels, etc.) """ + Logger.info(f"Ensuring {required_size} Hessian-trace approximation for node {trace_hessian_request.target_node}.") + # Ensure the saved info has the required number of approximations self._populate_saved_info_to_size(trace_hessian_request, required_size) @@ -157,6 +151,10 @@ def _populate_saved_info_to_size(self, # Get the current number of saved approximations for the request current_existing_hessians = self.count_saved_info_of_request(trace_hessian_request) + Logger.info( + f"Found {current_existing_hessians} Hessian-trace approximations for node {trace_hessian_request.target_node}." + f" {required_size - current_existing_hessians} approximations left to compute...") + # Compute the required number of approximations to meet the required size for _ in range(required_size - current_existing_hessians): self.compute(trace_hessian_request) diff --git a/model_compression_toolkit/core/common/hessian/trace_hessian_calculator.py b/model_compression_toolkit/core/common/hessian/trace_hessian_calculator.py index 4f575ea76..8d8e79b40 100644 --- a/model_compression_toolkit/core/common/hessian/trace_hessian_calculator.py +++ b/model_compression_toolkit/core/common/hessian/trace_hessian_calculator.py @@ -58,6 +58,11 @@ def __init__(self, if len(self.input_images)!=len(graph.get_inputs()): Logger.error(f"Graph has {len(graph.get_inputs())} inputs, but provided representative dataset returns {len(self.input_images)} inputs") + # Assert all inputs have a batch size of 1 + for image in self.input_images: + if image.shape[0]!=1: + Logger.error(f"Hessian is calculated only for a single image (per input) but input shape is {image.shape}") + self.fw_impl = fw_impl self.hessian_request = trace_hessian_request diff --git a/model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py b/model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py index 742ce9701..f35c446a6 100644 --- a/model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +++ b/model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py @@ -17,7 +17,7 @@ import numpy as np from typing import Callable, Any, List -from model_compression_toolkit.constants import AXIS +from model_compression_toolkit.constants import AXIS, HESSIAN_OUTPUT_ALPHA from model_compression_toolkit.core import FrameworkInfo, MixedPrecisionQuantizationConfigV2 from model_compression_toolkit.core.common import Graph, BaseNode from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode @@ -26,6 +26,7 @@ from model_compression_toolkit.logger import Logger from model_compression_toolkit.core.common.hessian import TraceHessianRequest, HessianMode, \ HessianInfoGranularity, HessianInfoService +from model_compression_toolkit.core.common.hessian import hessian_info_utils as hessian_utils class SensitivityEvaluation: @@ -236,6 
diff --git a/model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py b/model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py
index 742ce9701..f35c446a6 100644
--- a/model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py
+++ b/model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py
@@ -17,7 +17,7 @@
 import numpy as np
 from typing import Callable, Any, List
 
-from model_compression_toolkit.constants import AXIS
+from model_compression_toolkit.constants import AXIS, HESSIAN_OUTPUT_ALPHA
 from model_compression_toolkit.core import FrameworkInfo, MixedPrecisionQuantizationConfigV2
 from model_compression_toolkit.core.common import Graph, BaseNode
 from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
@@ -26,6 +26,7 @@
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.core.common.hessian import TraceHessianRequest, HessianMode, \
     HessianInfoGranularity, HessianInfoService
+from model_compression_toolkit.core.common.hessian import hessian_info_utils as hessian_utils
 
 
 class SensitivityEvaluation:
@@ -236,6 +237,12 @@ def _compute_gradient_based_weights(self) -> np.ndarray:
                 assert len(compare_point_to_trace_hessian_approximations[target_node][image_idx]) == 1
                 # Append the single approximation value to the list for the current image
                 approx_by_image_per_interest_point.append(compare_point_to_trace_hessian_approximations[target_node][image_idx][0])
+
+            if self.quant_config.norm_weights:
+                approx_by_image_per_interest_point = hessian_utils.normalize_weights(trace_hessian_approximations=approx_by_image_per_interest_point,
+                                                                                     outputs_indices=self.output_nodes_indices,
+                                                                                     alpha=HESSIAN_OUTPUT_ALPHA)
+
             # Append the approximations for the current image to the main list
             approx_by_image.append(approx_by_image_per_interest_point)
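The exact semantics of `normalize_weights` are defined in `hessian_info_utils`, which this diff does not show. One plausible reading of the call above, where `outputs_indices` marks the model's output nodes and `HESSIAN_OUTPUT_ALPHA` is their fixed share of the total weight, is sketched below (an assumption for illustration, not the library's actual code):

```python
# Assumed behavior of normalize_weights; the real implementation lives in
# model_compression_toolkit's hessian_info_utils and may differ in detail.
import numpy as np

def normalize_weights_sketch(trace_hessian_approximations, outputs_indices, alpha):
    scores = np.asarray(trace_hessian_approximations, dtype=np.float64)
    weights = np.copy(scores)
    if outputs_indices:
        # Output nodes split a fixed share `alpha` of the total weight...
        out_mask = np.zeros(len(scores), dtype=bool)
        out_mask[outputs_indices] = True
        weights[out_mask] = alpha / out_mask.sum()
        # ...while the remaining interest points share 1 - alpha
        # proportionally to their Hessian-trace scores.
        weights[~out_mask] = (1 - alpha) * scores[~out_mask] / scores[~out_mask].sum()
        return weights
    # With no output indices (the GPTQ call below passes [] and alpha=0),
    # this reduces to plain score normalization.
    return weights / weights.sum()
```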
+ """ + images = next(representative_dataset()) + if not isinstance(images, list): + Logger.error(f'Images expected to be a list but is of type {type(images)}') + + # Ensure each image is a single sample, if not, take the first sample + return [torch.unsqueeze(image[0], 0) if image.shape[0] != 1 else image for image in images] \ No newline at end of file diff --git a/model_compression_toolkit/gptq/common/gptq_training.py b/model_compression_toolkit/gptq/common/gptq_training.py index 452376476..71ae33ad9 100644 --- a/model_compression_toolkit/gptq/common/gptq_training.py +++ b/model_compression_toolkit/gptq/common/gptq_training.py @@ -205,7 +205,7 @@ def _process_hessian_approximations(self, approximations: Dict[BaseNode, List[Li for image_idx in range(self.gptq_config.hessian_weights_config.hessians_num_samples): approx_by_interest_point = self._get_approximations_by_interest_point(approximations, image_idx) if self.gptq_config.hessian_weights_config.norm_weights: - approx_by_interest_point = hessian_utils.normalize_weights(approx_by_interest_point, []) + approx_by_interest_point = hessian_utils.normalize_weights(approx_by_interest_point, [], alpha=0) trace_hessian_approx_by_image.append(approx_by_interest_point) return trace_hessian_approx_by_image @@ -243,56 +243,6 @@ def _validate_trace_approximation(trace_approx: List): f"granularity=HessianInfoGranularity.PER_TENSOR) but has a length of {len(trace_approx)}" ) - # def compute_hessian_based_weights(self) -> np.ndarray: - # """ - # - # Returns: Trace hessian approximations per layer w.r.t activations of the interest points. - # - # """ - # # TODO: Add comments + rewrtie the loops for better clarity - # if self.gptq_config.use_hessian_based_weights: - # compare_point_to_trace_hessian_approximations = {} - # for target_node in self.compare_points: - # trace_hessian_request = TraceHessianRequest(mode=HessianMode.ACTIVATION, - # granularity=HessianInfoGranularity.PER_TENSOR, - # target_node=target_node) - # node_approximations = self.hessian_service.fetch_hessian(trace_hessian_request=trace_hessian_request, - # required_size=self.gptq_config.hessian_weights_config.hessians_num_samples) - # compare_point_to_trace_hessian_approximations[target_node] = node_approximations - # - # trace_hessian_approx_by_image = [] - # for image_idx in range(self.gptq_config.hessian_weights_config.hessians_num_samples): - # approx_by_interest_point = [] - # for target_node in self.compare_points: - # if not isinstance(compare_point_to_trace_hessian_approximations[target_node][image_idx], list): - # Logger.error(f"Trace approx was expected to be a list but has a type of {type(compare_point_to_trace_hessian_approximations[target_node][image_idx])}") - # if not len(compare_point_to_trace_hessian_approximations[target_node][image_idx])==1: - # Logger.error( - # f"Trace approx was expected to be of length 1 (when computing approximations with " - # f"granularity=HessianInfoGranularity.PER_TENSOR but has a length of {len(compare_point_to_trace_hessian_approximations[target_node][image_idx])}") - # approx_by_interest_point.append(compare_point_to_trace_hessian_approximations[target_node][image_idx][0]) - # - # if self.gptq_config.hessian_weights_config.norm_weights: - # approx_by_interest_point = hessian_utils.normalize_weights(approx_by_interest_point, - # []) - # trace_hessian_approx_by_image.append(approx_by_interest_point) - # - # if self.gptq_config.hessian_weights_config.log_norm: - # mean_approx_scores = np.mean(trace_hessian_approx_by_image, axis=0) - # 
diff --git a/model_compression_toolkit/gptq/common/gptq_training.py b/model_compression_toolkit/gptq/common/gptq_training.py
index 452376476..71ae33ad9 100644
--- a/model_compression_toolkit/gptq/common/gptq_training.py
+++ b/model_compression_toolkit/gptq/common/gptq_training.py
@@ -205,7 +205,7 @@ def _process_hessian_approximations(self, approximations: Dict[BaseNode, List[Li
         for image_idx in range(self.gptq_config.hessian_weights_config.hessians_num_samples):
             approx_by_interest_point = self._get_approximations_by_interest_point(approximations, image_idx)
             if self.gptq_config.hessian_weights_config.norm_weights:
-                approx_by_interest_point = hessian_utils.normalize_weights(approx_by_interest_point, [])
+                approx_by_interest_point = hessian_utils.normalize_weights(approx_by_interest_point, [], alpha=0)
             trace_hessian_approx_by_image.append(approx_by_interest_point)
         return trace_hessian_approx_by_image
 
@@ -243,56 +243,6 @@ def _validate_trace_approximation(trace_approx: List):
             f"granularity=HessianInfoGranularity.PER_TENSOR) but has a length of {len(trace_approx)}"
         )
 
-    # def compute_hessian_based_weights(self) -> np.ndarray:
-    #     """
-    #
-    #     Returns: Trace hessian approximations per layer w.r.t activations of the interest points.
-    #
-    #     """
-    #     # TODO: Add comments + rewrtie the loops for better clarity
-    #     if self.gptq_config.use_hessian_based_weights:
-    #         compare_point_to_trace_hessian_approximations = {}
-    #         for target_node in self.compare_points:
-    #             trace_hessian_request = TraceHessianRequest(mode=HessianMode.ACTIVATION,
-    #                                                         granularity=HessianInfoGranularity.PER_TENSOR,
-    #                                                         target_node=target_node)
-    #             node_approximations = self.hessian_service.fetch_hessian(trace_hessian_request=trace_hessian_request,
-    #                                                                      required_size=self.gptq_config.hessian_weights_config.hessians_num_samples)
-    #             compare_point_to_trace_hessian_approximations[target_node] = node_approximations
-    #
-    #         trace_hessian_approx_by_image = []
-    #         for image_idx in range(self.gptq_config.hessian_weights_config.hessians_num_samples):
-    #             approx_by_interest_point = []
-    #             for target_node in self.compare_points:
-    #                 if not isinstance(compare_point_to_trace_hessian_approximations[target_node][image_idx], list):
-    #                     Logger.error(f"Trace approx was expected to be a list but has a type of {type(compare_point_to_trace_hessian_approximations[target_node][image_idx])}")
-    #                 if not len(compare_point_to_trace_hessian_approximations[target_node][image_idx])==1:
-    #                     Logger.error(
-    #                         f"Trace approx was expected to be of length 1 (when computing approximations with "
-    #                         f"granularity=HessianInfoGranularity.PER_TENSOR but has a length of {len(compare_point_to_trace_hessian_approximations[target_node][image_idx])}")
-    #                 approx_by_interest_point.append(compare_point_to_trace_hessian_approximations[target_node][image_idx][0])
-    #
-    #             if self.gptq_config.hessian_weights_config.norm_weights:
-    #                 approx_by_interest_point = hessian_utils.normalize_weights(approx_by_interest_point,
-    #                                                                            [])
-    #             trace_hessian_approx_by_image.append(approx_by_interest_point)
-    #
-    #         if self.gptq_config.hessian_weights_config.log_norm:
-    #             mean_approx_scores = np.mean(trace_hessian_approx_by_image, axis=0)
-    #             mean_approx_scores = np.where(mean_approx_scores != 0, mean_approx_scores,
-    #                                           np.partition(mean_approx_scores, 1)[1])
-    #             log_weights = np.log10(mean_approx_scores)
-    #
-    #             if self.gptq_config.hessian_weights_config.scale_log_norm:
-    #                 return (log_weights - np.min(log_weights)) / (np.max(log_weights) - np.min(log_weights))
-    #
-    #             return log_weights - np.min(log_weights)
-    #         else:
-    #             return np.mean(trace_hessian_approx_by_image, axis=0)
-    #     else:
-    #         num_nodes = len(self.compare_points)
-    #         return np.asarray([1 / num_nodes for _ in range(num_nodes)])
-
     @staticmethod
     def _generate_images_batch(representative_data_gen: Callable, num_samples_for_loss: int) -> np.ndarray:
         """
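For reference, the log-normalization performed by the commented-out block deleted above can be condensed as follows (a cleanup of the removed lines, kept only to document the prior behavior):

```python
# Condensed from the removed compute_hessian_based_weights comment block.
import numpy as np

def log_norm_weights(trace_hessian_approx_by_image, scale_log_norm=False):
    mean_approx_scores = np.mean(trace_hessian_approx_by_image, axis=0)
    # Replace zeros with the second-smallest score before taking log10
    mean_approx_scores = np.where(mean_approx_scores != 0, mean_approx_scores,
                                  np.partition(mean_approx_scores, 1)[1])
    log_weights = np.log10(mean_approx_scores)
    if scale_log_norm:
        # Min-max scale the log weights to [0, 1]
        return (log_weights - np.min(log_weights)) / (np.max(log_weights) - np.min(log_weights))
    return log_weights - np.min(log_weights)
```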