Fix hessian refactor issues (#829)
Fixes to the Hessian refactor (#823):

- Mixed precision: Added missing scores normalization.
- GPTQ: Added missing alpha param (set to 0).
- Modified logs to be more informative.
- Removed commented-out code.
- Added a framework-implementation method that creates a dataset with a batch size of 1 (instead of inside the HessianInfoService); see the sketch below.
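For context, a minimal sketch of the new flow (the toy generator and the numpy stand-in are illustrative, not part of this commit; the real per-framework implementations appear further down in the diff):

import numpy as np
from functools import partial

def representative_dataset():  # toy generator: one model input at batch size 4
    yield [np.random.randn(4, 8, 8, 3).astype(np.float32)]

def sample_single_representative_dataset(representative_dataset):
    # Stand-in for the per-framework method added in this commit (numpy version).
    images = next(representative_dataset())
    return [np.expand_dims(image[0], 0) if image.shape[0] != 1 else image
            for image in images]

# HessianInfoService now stores a batch-1 sampler instead of the raw generator:
sample_fn = partial(sample_single_representative_dataset,
                    representative_dataset=representative_dataset)
images = sample_fn()  # a list with one array of shape (1, 8, 8, 3)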

---------

Co-authored-by: reuvenp <reuvenp@altair-semi.com>
reuvenperetz and reuvenp authored Oct 15, 2023
1 parent 90721d0 commit fb5917a
Showing 7 changed files with 75 additions and 68 deletions.
14 changes: 14 additions & 0 deletions model_compression_toolkit/core/common/framework_implementation.py
@@ -66,6 +66,20 @@ def get_trace_hessian_calculator(self,
"""
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_trace_hessian_calculator method.') # pragma: no cover

@abstractmethod
def sample_single_representative_dataset(self, representative_dataset: Callable):
"""
Get a single sample (namely, batch size of 1) from a representative dataset.
Args:
representative_dataset: Callable which returns the representative dataset at any batch size.
Returns: List of inputs from representative_dataset where each sample has a batch size of 1.
"""
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s sample_single_representative_dataset method.') # pragma: no cover

@abstractmethod
def to_numpy(self, tensor: Any) -> np.ndarray:
"""
@@ -20,6 +20,7 @@
from model_compression_toolkit.core.common import Graph
from model_compression_toolkit.core.common.hessian.trace_hessian_request import TraceHessianRequest
from model_compression_toolkit.logger import Logger
from functools import partial


class HessianInfoService:
@@ -51,7 +52,11 @@ def __init__(self,
fw_impl: Framework-specific implementation for trace Hessian approximation computation.
"""
self.graph = graph
self.representative_dataset = representative_dataset

# Create a representative_data_gen with batch size of 1
self.representative_dataset = partial(fw_impl.sample_single_representative_dataset,
representative_dataset=representative_dataset)

self.fw_impl = fw_impl
self.num_iterations_for_approximation = num_iterations_for_approximation

@@ -76,19 +81,6 @@ def count_saved_info_of_request(self, hessian_request:TraceHessianRequest) -> in
# Check if the request is in the saved info and return its count, otherwise return 0
return len(self.trace_hessian_request_to_score_list.get(hessian_request, []))

def _sample_single_image_per_input(self) -> List[Any]:
"""
Samples a single image per input from the representative dataset.
Returns:
List: List of sampled images.
"""
images = next(self.representative_dataset())
if not isinstance(images, list):
Logger.error(f'Images expected to be a list but is of type {type(images)}')

# Ensure each image is a single sample, if not, take the first sample
return [np.expand_dims(image[0], 0) if image.shape[0] != 1 else image for image in images]

def compute(self, trace_hessian_request:TraceHessianRequest):
"""
@@ -98,10 +90,10 @@ def compute(self, trace_hessian_request:TraceHessianRequest):
Args:
trace_hessian_request: Configuration for which to compute the approximation.
"""
Logger.info(f"Computing Hessian-trace approximation for a sample.")
Logger.debug(f"Computing Hessian-trace approximation for a node {trace_hessian_request.target_node}.")

# Sample images for the computation
images = self._sample_single_image_per_input()
images = self.representative_dataset()

# Get the framework-specific calculator for trace Hessian approximation
fw_hessian_calculator = self.fw_impl.get_trace_hessian_calculator(graph=self.graph,
@@ -137,6 +129,8 @@ def fetch_hessian(self,
The inner list length dependent on the granularity (1 for per-tensor,
OC for per-output-channel when the requested node has OC output-channels, etc.)
"""
Logger.info(f"Ensuring {required_size} Hessian-trace approximation for node {trace_hessian_request.target_node}.")

# Ensure the saved info has the required number of approximations
self._populate_saved_info_to_size(trace_hessian_request, required_size)

@@ -157,6 +151,10 @@ def _populate_saved_info_to_size(self,
# Get the current number of saved approximations for the request
current_existing_hessians = self.count_saved_info_of_request(trace_hessian_request)

Logger.info(
f"Found {current_existing_hessians} Hessian-trace approximations for node {trace_hessian_request.target_node}."
f" {required_size - current_existing_hessians} approximations left to compute...")

# Compute the required number of approximations to meet the required size
for _ in range(required_size - current_existing_hessians):
self.compute(trace_hessian_request)
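
The populate-to-size loop above is a standard cache-filling idiom; here is a self-contained sketch under hypothetical names (ScoreCache is not an MCT class; requests only need to be hashable, as TraceHessianRequest is when used as a dict key above):

from collections import defaultdict

class ScoreCache:
    """Caches per-request scores and computes only what is missing."""

    def __init__(self, compute_one):
        self._cache = defaultdict(list)   # request -> list of computed scores
        self._compute_one = compute_one   # callable: request -> one score

    def fetch(self, request, required_size):
        missing = required_size - len(self._cache[request])
        for _ in range(max(0, missing)):  # top up the cache only as needed
            self._cache[request].append(self._compute_one(request))
        return self._cache[request][:required_size]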
@@ -58,6 +58,11 @@ def __init__(self,
if len(self.input_images) != len(graph.get_inputs()):
Logger.error(f"Graph has {len(graph.get_inputs())} inputs, but the provided representative dataset returns {len(self.input_images)} inputs")

# Validate that every input has a batch size of 1
for image in self.input_images:
if image.shape[0] != 1:
Logger.error(f"Hessian is computed for a single image per input, but the input shape is {image.shape}")

self.fw_impl = fw_impl
self.hessian_request = trace_hessian_request

@@ -17,7 +17,7 @@
import numpy as np
from typing import Callable, Any, List

from model_compression_toolkit.constants import AXIS
from model_compression_toolkit.constants import AXIS, HESSIAN_OUTPUT_ALPHA
from model_compression_toolkit.core import FrameworkInfo, MixedPrecisionQuantizationConfigV2
from model_compression_toolkit.core.common import Graph, BaseNode
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
@@ -26,6 +26,7 @@
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.core.common.hessian import TraceHessianRequest, HessianMode, \
HessianInfoGranularity, HessianInfoService
from model_compression_toolkit.core.common.hessian import hessian_info_utils as hessian_utils


class SensitivityEvaluation:
@@ -236,6 +237,12 @@ def _compute_gradient_based_weights(self) -> np.ndarray:
assert len(compare_point_to_trace_hessian_approximations[target_node][image_idx]) == 1
# Append the single approximation value to the list for the current image
approx_by_image_per_interest_point.append(compare_point_to_trace_hessian_approximations[target_node][image_idx][0])

if self.quant_config.norm_weights:
approx_by_image_per_interest_point = hessian_utils.normalize_weights(trace_hessian_approximations=approx_by_image_per_interest_point,
outputs_indices=self.output_nodes_indices,
alpha=HESSIAN_OUTPUT_ALPHA)

# Append the approximations for the current image to the main list
approx_by_image.append(approx_by_image_per_interest_point)
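
This diff does not show what normalize_weights itself does; below is a hedged sketch of one plausible scheme, assuming output nodes receive a fixed share alpha (HESSIAN_OUTPUT_ALPHA here) of the total weight while the remaining scores are scaled to sum to 1 - alpha. It is illustrative, not the exact MCT implementation:

import numpy as np

def normalize_weights_sketch(trace_hessian_approximations, outputs_indices, alpha):
    # Illustrative normalization: outputs split a fixed total share `alpha`;
    # all other interest points share `1 - alpha` proportionally to their scores.
    scores = np.asarray(trace_hessian_approximations, dtype=np.float64)
    out_mask = np.zeros(len(scores), dtype=bool)
    out_mask[list(outputs_indices)] = True
    normalized = np.empty_like(scores)
    normalized[out_mask] = alpha / max(out_mask.sum(), 1)
    rest = scores[~out_mask]
    normalized[~out_mask] = (1 - alpha) * rest / rest.sum()  # assumes a non-zero sum
    return normalized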

16 changes: 16 additions & 0 deletions model_compression_toolkit/core/keras/keras_implementation.py
@@ -591,3 +591,19 @@ def sensitivity_eval_inference(self,
"""

return model(inputs)

def sample_single_representative_dataset(self, representative_dataset: Callable):
"""
Get a single sample (namely, batch size of 1) from a representative dataset.
Args:
representative_dataset: Callable which returns the representative dataset at any batch size.
Returns: List of inputs from representative_dataset where each sample has a batch size of 1.
"""
images = next(representative_dataset())
if not isinstance(images, list):
Logger.error(f'Expected images to be a list, but got type {type(images)}')

# Ensure each image is a single sample, if not, take the first sample
return [tf.expand_dims(image[0], 0) if image.shape[0] != 1 else image for image in images]
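
A usage sketch for the Keras variant (the toy generator is illustrative; it assumes KerasImplementation, the concrete class defined in this file, has a no-argument constructor):

import numpy as np
from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation

def rep_dataset():  # yields one model input at batch size 4
    while True:
        yield [np.random.randn(4, 32, 32, 3).astype(np.float32)]

impl = KerasImplementation()
sample = impl.sample_single_representative_dataset(rep_dataset)
# sample is a list with one tensor of shape (1, 32, 32, 3)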
17 changes: 17 additions & 0 deletions model_compression_toolkit/core/pytorch/pytorch_implementation.py
@@ -86,6 +86,7 @@
from model_compression_toolkit.core.pytorch.statistics_correction.apply_second_moment_correction import \
pytorch_apply_second_moment_correction
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, torch_tensor_to_numpy, set_model
from model_compression_toolkit.logger import Logger


class PytorchImplementation(FrameworkImplementation):
@@ -538,3 +539,19 @@ def get_trace_hessian_calculator(self,
input_images=input_images,
fw_impl=self,
num_iterations_for_approximation=num_iterations_for_approximation)

def sample_single_representative_dataset(self, representative_dataset: Callable):
"""
Get a single sample (namely, batch size of 1) from a representative dataset.
Args:
representative_dataset: Callable which returns the representative dataset at any batch size.
Returns: List of inputs from representative_dataset where each sample has a batch size of 1.
"""
images = next(representative_dataset())
if not isinstance(images, list):
Logger.error(f'Expected images to be a list, but got type {type(images)}')

# Ensure each image is a single sample, if not, take the first sample
return [torch.unsqueeze(image[0], 0) if image.shape[0] != 1 else image for image in images]
52 changes: 1 addition & 51 deletions model_compression_toolkit/gptq/common/gptq_training.py
@@ -205,7 +205,7 @@ def _process_hessian_approximations(self, approximations: Dict[BaseNode, List[Li
for image_idx in range(self.gptq_config.hessian_weights_config.hessians_num_samples):
approx_by_interest_point = self._get_approximations_by_interest_point(approximations, image_idx)
if self.gptq_config.hessian_weights_config.norm_weights:
approx_by_interest_point = hessian_utils.normalize_weights(approx_by_interest_point, [])
approx_by_interest_point = hessian_utils.normalize_weights(approx_by_interest_point, [], alpha=0)
trace_hessian_approx_by_image.append(approx_by_interest_point)
return trace_hessian_approx_by_image

@@ -243,56 +243,6 @@ def _validate_trace_approximation(trace_approx: List):
f"granularity=HessianInfoGranularity.PER_TENSOR) but has a length of {len(trace_approx)}"
)

# def compute_hessian_based_weights(self) -> np.ndarray:
# """
#
# Returns: Trace hessian approximations per layer w.r.t activations of the interest points.
#
# """
# # TODO: Add comments + rewrtie the loops for better clarity
# if self.gptq_config.use_hessian_based_weights:
# compare_point_to_trace_hessian_approximations = {}
# for target_node in self.compare_points:
# trace_hessian_request = TraceHessianRequest(mode=HessianMode.ACTIVATION,
# granularity=HessianInfoGranularity.PER_TENSOR,
# target_node=target_node)
# node_approximations = self.hessian_service.fetch_hessian(trace_hessian_request=trace_hessian_request,
# required_size=self.gptq_config.hessian_weights_config.hessians_num_samples)
# compare_point_to_trace_hessian_approximations[target_node] = node_approximations
#
# trace_hessian_approx_by_image = []
# for image_idx in range(self.gptq_config.hessian_weights_config.hessians_num_samples):
# approx_by_interest_point = []
# for target_node in self.compare_points:
# if not isinstance(compare_point_to_trace_hessian_approximations[target_node][image_idx], list):
# Logger.error(f"Trace approx was expected to be a list but has a type of {type(compare_point_to_trace_hessian_approximations[target_node][image_idx])}")
# if not len(compare_point_to_trace_hessian_approximations[target_node][image_idx])==1:
# Logger.error(
# f"Trace approx was expected to be of length 1 (when computing approximations with "
# f"granularity=HessianInfoGranularity.PER_TENSOR but has a length of {len(compare_point_to_trace_hessian_approximations[target_node][image_idx])}")
# approx_by_interest_point.append(compare_point_to_trace_hessian_approximations[target_node][image_idx][0])
#
# if self.gptq_config.hessian_weights_config.norm_weights:
# approx_by_interest_point = hessian_utils.normalize_weights(approx_by_interest_point,
# [])
# trace_hessian_approx_by_image.append(approx_by_interest_point)
#
# if self.gptq_config.hessian_weights_config.log_norm:
# mean_approx_scores = np.mean(trace_hessian_approx_by_image, axis=0)
# mean_approx_scores = np.where(mean_approx_scores != 0, mean_approx_scores,
# np.partition(mean_approx_scores, 1)[1])
# log_weights = np.log10(mean_approx_scores)
#
# if self.gptq_config.hessian_weights_config.scale_log_norm:
# return (log_weights - np.min(log_weights)) / (np.max(log_weights) - np.min(log_weights))
#
# return log_weights - np.min(log_weights)
# else:
# return np.mean(trace_hessian_approx_by_image, axis=0)
# else:
# num_nodes = len(self.compare_points)
# return np.asarray([1 / num_nodes for _ in range(num_nodes)])

@staticmethod
def _generate_images_batch(representative_data_gen: Callable, num_samples_for_loss: int) -> np.ndarray:
"""
