Skip to content

Commit

Permalink
Activation Hessian computation runtime optimization (#1092)
Browse files Browse the repository at this point in the history
Improve Activation Hessian computation runtime for GPTQ and Mixed precision with the following optimizations:

- Enable batch computation.
- Enable computation on a set of nodes (instead of a single node only).
- Other minor loop and implementation modifications.

---------

Co-authored-by: Ofir Gordon <Ofir.Gordon@altair-semi.com>
  • Loading branch information
ofirgo and Ofir Gordon authored Jun 5, 2024
1 parent 3fcf49d commit d13319f
Show file tree
Hide file tree
Showing 36 changed files with 1,210 additions and 577 deletions.
5 changes: 4 additions & 1 deletion model_compression_toolkit/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,11 @@

# Hessian configuration default constants
HESSIAN_OUTPUT_ALPHA = 0.3
HESSIAN_NUM_ITERATIONS = 50
HESSIAN_NUM_ITERATIONS = 100
HESSIAN_EPS = 1e-6
ACT_HESSIAN_DEFAULT_BATCH_SIZE = 32
GPTQ_HESSIAN_NUM_SAMPLES = 32
MP_DEFAULT_NUM_SAMPLES = 32

# Pruning constants
PRUNING_NUM_SCORE_APPROXIMATIONS = 32
301 changes: 235 additions & 66 deletions model_compression_toolkit/core/common/hessian/hessian_info_service.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,19 @@
from model_compression_toolkit.constants import EPS


def normalize_scores(hessian_approximations: List) -> np.ndarray:
def normalize_scores(hessian_approximations: List) -> List[np.ndarray]:
"""
Normalize Hessian information approximations by dividing each trace Hessian approximation value by the sum of the
values of all interest points for the same image.
Args:
hessian_approximations: Approximated average Hessian-based scores for each interest point.
hessian_approximations: Approximated Hessian-based scores for each image for each interest point.
Returns:
Normalized list of Hessian info approximations for each interest point.
"""
scores_vec = np.asarray(hessian_approximations)
scores_vec = np.asarray(hessian_approximations) # Images x Nodes X Scores
norm_scores_per_image = scores_vec / (np.sum(scores_vec, axis=1, keepdims=True) + EPS)

return scores_vec / (np.sum(scores_vec) + EPS)
return [norm_scores_per_image[i] for i in range(norm_scores_per_image.shape[0])]

Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,9 @@ def __init__(self,
self.num_iterations_for_approximation = num_iterations_for_approximation

# Validate representative dataset has same inputs as graph
if len(self.input_images)!=len(graph.get_inputs()):
if len(self.input_images) != len(graph.get_inputs()): # pragma: no cover
Logger.critical(f"The graph requires {len(graph.get_inputs())} inputs, but the provided representative dataset contains {len(self.input_images)} inputs.")

# Assert all inputs have a batch size of 1
for image in self.input_images:
if image.shape[0]!=1:
Logger.critical(f"Hessian calculations are restricted to a single-image per input. Found input with shape: {image.shape}.")

self.fw_impl = fw_impl
self.hessian_request = trace_hessian_request

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,28 +52,28 @@ class TraceHessianRequest:
def __init__(self,
mode: HessianMode,
granularity: HessianInfoGranularity,
target_node,
target_nodes: List,
):
"""
Attributes:
mode (HessianMode): Mode of Hessian's trace approximation (w.r.t weights or activations).
granularity (HessianInfoGranularity): Granularity level for the approximation.
target_node (BaseNode): The node in the float graph for which the Hessian's trace approximation is targeted.
target_nodes (List[BaseNode]): The nodes in the float graph for which the Hessian's trace approximation is targeted.
"""

self.mode = mode # w.r.t activations or weights
self.granularity = granularity # per element, per layer, per channel
self.target_node = target_node # TODO: extend it list of nodes
self.target_nodes = target_nodes

def __eq__(self, other):
# Checks if the other object is an instance of TraceHessianRequest
# and then checks if all attributes are equal.
return isinstance(other, TraceHessianRequest) and \
self.mode == other.mode and \
self.granularity == other.granularity and \
self.target_node == other.target_node
self.target_nodes == other.target_nodes

def __hash__(self):
# Computes the hash based on the attributes.
# The use of a tuple here ensures that the hash is influenced by all the attributes.
return hash((self.mode, self.granularity, self.target_node))
return hash((self.mode, self.granularity, tuple(self.target_nodes)))
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from typing import List, Callable

from model_compression_toolkit.constants import MP_DEFAULT_NUM_SAMPLES, ACT_HESSIAN_DEFAULT_BATCH_SIZE
from model_compression_toolkit.core.common.mixed_precision.distance_weighting import MpDistanceWeighting


Expand All @@ -23,13 +24,14 @@ class MixedPrecisionQuantizationConfig:
def __init__(self,
compute_distance_fn: Callable = None,
distance_weighting_method: MpDistanceWeighting = MpDistanceWeighting.AVG,
num_of_images: int = 32,
num_of_images: int = MP_DEFAULT_NUM_SAMPLES,
configuration_overwrite: List[int] = None,
num_interest_points_factor: float = 1.0,
use_hessian_based_scores: bool = False,
norm_scores: bool = True,
refine_mp_solution: bool = True,
metric_normalization_threshold: float = 1e10):
metric_normalization_threshold: float = 1e10,
hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE):
"""
Class with mixed precision parameters to quantize the input model.
Expand All @@ -43,6 +45,7 @@ def __init__(self,
norm_scores (bool): Whether to normalize the returned scores for the weighted distance metric (to get values between 0 and 1).
refine_mp_solution (bool): Whether to try to improve the final mixed-precision configuration using a greedy algorithm that searches layers to increase their bit-width, or not.
metric_normalization_threshold (float): A threshold for checking the mixed precision distance metric values. In case of values larger than this threshold, the metric will be scaled to prevent numerical issues.
hessian_batch_size (int): The Hessian computation batch size. Used only when running mixed precision with a Hessian-based objective.
"""

Expand All @@ -60,6 +63,7 @@ def __init__(self,

self.use_hessian_based_scores = use_hessian_based_scores
self.norm_scores = norm_scores
self.hessian_batch_size = hessian_batch_size

self.metric_normalization_threshold = metric_normalization_threshold

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.core.common.hessian import TraceHessianRequest, HessianMode, \
HessianInfoGranularity, HessianInfoService
from model_compression_toolkit.core.common.hessian import hessian_info_utils as hessian_utils


class SensitivityEvaluation:
Expand Down Expand Up @@ -238,47 +237,24 @@ def _compute_hessian_based_scores(self) -> np.ndarray:
to be used for the distance metric weighted average computation.
"""
# Dictionary to store the trace Hessian approximations for each interest point (target node)
compare_point_to_trace_hessian_approximations = {}

# Iterate over each interest point to fetch the trace Hessian approximations
for target_node in self.interest_points:
# Create a request for trace Hessian approximation with specific configurations
# (here we use per-tensor approximation of the Hessian's trace w.r.t the node's activations)
trace_hessian_request = TraceHessianRequest(mode=HessianMode.ACTIVATION,
granularity=HessianInfoGranularity.PER_TENSOR,
target_node=target_node)

# Fetch the trace Hessian approximations for the current interest point
node_approximations = self.hessian_info_service.fetch_hessian(trace_hessian_request=trace_hessian_request,
required_size=self.quant_config.num_of_images)
# Store the fetched approximations in the dictionary
compare_point_to_trace_hessian_approximations[target_node] = node_approximations

# List to store the approximations for each image
approx_by_image = []
# Iterate over each image
for image_idx in range(self.quant_config.num_of_images):
# List to store approximations for the current image for each interest point
approx_by_image_per_interest_point = []
# Iterate over each interest point to gather approximations
for target_node in self.interest_points:
# Ensure the approximation for the current interest point and image is a list
assert isinstance(compare_point_to_trace_hessian_approximations[target_node][image_idx], list)
# Ensure the approximation list contains only one element (since, granularity is per-tensor)
assert len(compare_point_to_trace_hessian_approximations[target_node][image_idx]) == 1
# Append the single approximation value to the list for the current image
approx_by_image_per_interest_point.append(compare_point_to_trace_hessian_approximations[target_node][image_idx][0])

if self.quant_config.norm_scores:
approx_by_image_per_interest_point = \
hessian_utils.normalize_scores(hessian_approximations=approx_by_image_per_interest_point)

# Append the approximations for the current image to the main list
approx_by_image.append(approx_by_image_per_interest_point)
# Create a request for trace Hessian approximation with specific configurations
# (here we use per-tensor approximation of the Hessian's trace w.r.t the node's activations)
trace_hessian_request = TraceHessianRequest(mode=HessianMode.ACTIVATION,
granularity=HessianInfoGranularity.PER_TENSOR,
target_nodes=self.interest_points)

# Fetch the trace Hessian approximations for all interest points in a single request
nodes_approximations = self.hessian_info_service.fetch_hessian(trace_hessian_request=trace_hessian_request,
required_size=self.quant_config.num_of_images,
batch_size=self.quant_config.hessian_batch_size)

# Store the approximations for each node for each image
approx_by_image = [[nodes_approximations[j][image_idx]
for j, _ in enumerate(self.interest_points)]
for image_idx in range(self.quant_config.num_of_images)]

# Return the mean approximation value across all images for each interest point
return np.mean(approx_by_image, axis=0)
return np.mean(np.stack(approx_by_image), axis=0)

def _configure_bitwidths_model(self,
mp_model_configuration: List[int],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,21 +121,21 @@ def _get_entry_node_to_score(self, entry_nodes: List[BaseNode]) -> Dict[BaseNode

# Initialize HessianInfoService for score computation.
hessian_info_service = HessianInfoService(graph=self.float_graph,
representative_dataset=self.representative_data_gen,
representative_dataset_gen=self.representative_data_gen,
fw_impl=self.fw_impl)

# Fetch and process Hessian scores for output channels of entry nodes.
nodes_scores = []
for node in entry_nodes:
_request = TraceHessianRequest(mode=HessianMode.WEIGHTS,
granularity=HessianInfoGranularity.PER_OUTPUT_CHANNEL,
target_node=node)
target_nodes=[node])
_scores_for_node = hessian_info_service.fetch_hessian(_request,
required_size=self.pruning_config.num_score_approximations)
nodes_scores.append(_scores_for_node)

# Average and map scores to nodes.
self._entry_node_to_hessian_score = {node: np.mean(scores, axis=0) for node, scores in zip(entry_nodes, nodes_scores)}
self._entry_node_to_hessian_score = {node: np.mean(scores[0], axis=0) for node, scores in zip(entry_nodes, nodes_scores)}

self._entry_node_count_oc_nparams = self._count_oc_nparams(entry_nodes=entry_nodes)
_entry_node_l2_oc_norm = self._get_squaredl2norm(entry_nodes=entry_nodes)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ def _compute_hessian_for_hmse(node,
"""
_request = TraceHessianRequest(mode=HessianMode.WEIGHTS,
granularity=HessianInfoGranularity.PER_ELEMENT,
target_node=node)
target_nodes=[node])
_scores_for_node = hessian_info_service.fetch_hessian(_request,
required_size=num_hessian_samples)

Expand Down
Loading

0 comments on commit d13319f

Please sign in to comment.