Support multiple nodes Weights Hessian computation #1102

Merged · 7 commits · Jun 13, 2024
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from collections.abc import Iterable

import numpy as np
from functools import partial
@@ -189,6 +188,10 @@ def compute(self, trace_hessian_request: TraceHessianRequest, representative_dat
images, next_iter_remain_samples = representative_dataset_gen(num_hessian_samples=num_hessian_samples,
last_iter_remain_samples=last_iter_remain_samples)

        # Sort the request's target nodes by their topological order in the graph
topo_sorted_nodes_names = [x.name for x in self.graph.get_topo_sorted_nodes()]
trace_hessian_request.target_nodes.sort(key=lambda x: topo_sorted_nodes_names.index(x.name))

# Get the framework-specific calculator for trace Hessian approximation
fw_hessian_calculator = self.fw_impl.get_trace_hessian_calculator(graph=self.graph,
input_images=images,
@@ -197,12 +200,7 @@

trace_hessian = fw_hessian_calculator.compute()

# Store the computed approximation in the saved info
topo_sorted_nodes_names = [x.name for x in self.graph.get_topo_sorted_nodes()]
sorted_target_nodes = sorted(trace_hessian_request.target_nodes,
key=lambda x: topo_sorted_nodes_names.index(x.name))

for node, hessian in zip(sorted_target_nodes, trace_hessian):
for node, hessian in zip(trace_hessian_request.target_nodes, trace_hessian):
single_node_request = self._construct_single_node_request(trace_hessian_request.mode,
trace_hessian_request.granularity,
node)
@@ -246,6 +244,10 @@ def fetch_hessian(self,
                The inner list length depends on the granularity (1 for per-tensor,
                OC for per-output-channel when the requested node has OC output-channels, etc.)
"""

if len(trace_hessian_request.target_nodes) == 0:
return []

if required_size == 0:
return [[] for _ in trace_hessian_request.target_nodes]

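The change above moves the topological sorting of the request's target nodes to before the calculator is built, so the computed scores can be zipped directly with `trace_hessian_request.target_nodes`. A minimal sketch of the reordering idea, using hypothetical stand-in classes rather than MCT's actual `Graph` and `TraceHessianRequest` types:

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Node:
    name: str


@dataclass
class Request:
    target_nodes: List[Node] = field(default_factory=list)


# Topologically sorted node names, as graph.get_topo_sorted_nodes() would yield them.
topo_sorted_names = ["conv1", "bn1", "relu1", "conv2"]

req = Request(target_nodes=[Node("conv2"), Node("conv1")])
# Sort in place so the request's order matches the order of the computed scores.
req.target_nodes.sort(key=lambda n: topo_sorted_names.index(n.name))
assert [n.name for n in req.target_nodes] == ["conv1", "conv2"]
```

Sorting the request in place, rather than building a separate `sorted_target_nodes` list after computation, keeps the request and the returned scores aligned by construction.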
@@ -19,6 +19,7 @@
from model_compression_toolkit.core.common.hessian import TraceHessianRequest, HessianMode, HessianInfoGranularity, \
HessianInfoService
from model_compression_toolkit.core.common.similarity_analyzer import compute_mse, compute_mae, compute_lp_norm
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
from model_compression_toolkit.constants import FLOAT_32, NUM_QPARAM_HESSIAN_SAMPLES
from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import uniform_quantize_tensor, \
@@ -376,7 +377,7 @@ def _get_sliced_histogram(bins: np.ndarray,

def _compute_hessian_for_hmse(node,
hessian_info_service: HessianInfoService,
num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES) -> List[np.ndarray]:
num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES) -> List[List[np.ndarray]]:
"""
    Compute and retrieve Hessian-based scores to use during HMSE error computation.

@@ -476,7 +477,10 @@ def get_threshold_selection_tensor_error_function(quantization_method: Quantizat

if quant_error_method == qc.QuantizationErrorMethod.HMSE:
node_hessian_scores = _compute_hessian_for_hmse(node, hessian_info_service, num_hessian_samples)
node_hessian_scores = np.sqrt(np.mean(node_hessian_scores, axis=0))
if len(node_hessian_scores) != 1:
Logger.critical(f"Expecting single node Hessian score request to return a list of length 1, but got a list "
f"of length {len(node_hessian_scores)}.")
node_hessian_scores = np.sqrt(np.mean(node_hessian_scores[0], axis=0))

return lambda x, y, threshold: _hmse_error_function_wrapper(x, y, norm=norm, axis=axis,
hessian_scores=node_hessian_scores)
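A small sketch of the shape handling behind the new length check, with illustrative sample counts and kernel shape (not MCT's defaults):

```python
import numpy as np

# fetch_hessian now returns one entry per requested node; the single-node HMSE
# request therefore expects a list of length 1, whose entry stacks the
# per-sample, per-element approximations (shapes here are illustrative).
rng = np.random.default_rng(0)
node_hessian_scores = [rng.random((16, 3, 3, 8, 16))]  # 16 samples of a 3x3x8x16 kernel

assert len(node_hessian_scores) == 1  # a single target node was requested
scores = np.sqrt(np.mean(node_hessian_scores[0], axis=0))  # mean over samples, then sqrt
print(scores.shape)  # (3, 3, 8, 16): one score per kernel element
```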
@@ -20,14 +20,40 @@
from model_compression_toolkit.constants import NUM_QPARAM_HESSIAN_SAMPLES
from model_compression_toolkit.core import QuantizationErrorMethod
from model_compression_toolkit.core.common import Graph, BaseNode
from model_compression_toolkit.core.common.hessian import HessianInfoService
from model_compression_toolkit.core.common.hessian import HessianInfoService, TraceHessianRequest, HessianMode, \
HessianInfoGranularity
from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_activations_computation \
import get_activations_qparams
from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_weights_computation import \
get_weights_qparams
from model_compression_toolkit.logger import Logger


def _collect_nodes_for_hmse(nodes_list: List[BaseNode], graph: Graph) -> List[BaseNode]:
"""
    Collects nodes that are compatible with HMSE-based parameters selection search,
    that is, nodes that have a kernel attribute configured to use the HMSE error method.

Args:
nodes_list: A list of nodes to search quantization parameters for.
        graph: Graph whose nodes' quantization parameters are computed.

    Returns: A (possibly empty) list of nodes that qualify for HMSE.

"""
hmse_nodes = []
for n in nodes_list:
kernel_attr_name = graph.fw_info.get_kernel_op_attributes(n.type)
kernel_attr_name = None if kernel_attr_name is None or len(kernel_attr_name) == 0 else kernel_attr_name[0]

if kernel_attr_name is not None and n.is_weights_quantization_enabled(kernel_attr_name) and \
all([c.weights_quantization_cfg.get_attr_config(kernel_attr_name).weights_error_method ==
QuantizationErrorMethod.HMSE for c in n.candidates_quantization_cfg]):
hmse_nodes.append(n)

return hmse_nodes


def calculate_quantization_params(graph: Graph,
nodes: List[BaseNode] = [],
specific_nodes: bool = False,
@@ -58,6 +84,17 @@ def calculate_quantization_params(graph: Graph,
# Create a list of nodes to compute their thresholds
nodes_list: List[BaseNode] = nodes if specific_nodes else graph.nodes()

    # Collect nodes that are configured to search for weights quantization parameters using HMSE optimization,
    # and compute the Hessian information required for HMSE parameters selection.
# The Hessian scores are computed and stored in the hessian_info_service object.
nodes_for_hmse = _collect_nodes_for_hmse(nodes_list, graph)
if len(nodes_for_hmse) > 0:
hessian_info_service.fetch_hessian(TraceHessianRequest(mode=HessianMode.WEIGHTS,
granularity=HessianInfoGranularity.PER_ELEMENT,
target_nodes=nodes_for_hmse),
required_size=num_hessian_samples,
batch_size=1)

    for n in tqdm(nodes_list, "Calculating quantization parameters"):  # iterate only over nodes whose thresholds should be computed
for candidate_qc in n.candidates_quantization_cfg:
for attr in n.get_node_weights_attributes():
@@ -73,6 +110,8 @@ def calculate_quantization_params(graph: Graph,
mod_attr_cfg = attr_cfg

if attr_cfg.weights_error_method == QuantizationErrorMethod.HMSE:
                        # Although we collected the HMSE nodes before running the loop, we keep this verification to
                        # notify the user in case HMSE is configured for a node that is not compatible with this method.
kernel_attr_name = graph.fw_info.get_kernel_op_attributes(n.type)
if len(kernel_attr_name) > 0:
kernel_attr_name = kernel_attr_name[0]
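The eligibility rule the new helper applies can be shown standalone. A sketch with simplified stand-in types (the real code resolves the kernel attribute through `graph.fw_info` and reads each candidate's weights config):

```python
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional


class ErrorMethod(Enum):
    MSE = "mse"
    HMSE = "hmse"


@dataclass
class Candidate:
    weights_error_method: ErrorMethod


@dataclass
class Node:
    name: str
    kernel_attr: Optional[str]  # None when the op has no kernel
    candidates: List[Candidate]


def collect_hmse_nodes(nodes: List[Node]) -> List[Node]:
    # A node qualifies only if it has a kernel attribute and every one of its
    # quantization candidates is configured for the HMSE error method.
    return [n for n in nodes
            if n.kernel_attr is not None
            and all(c.weights_error_method is ErrorMethod.HMSE for c in n.candidates)]


nodes = [Node("conv", "kernel", [Candidate(ErrorMethod.HMSE)]),
         Node("dense", "kernel", [Candidate(ErrorMethod.MSE), Candidate(ErrorMethod.HMSE)]),
         Node("relu", None, [])]
print([n.name for n in collect_hmse_nodes(nodes)])  # ['conv']
```

Collecting these nodes up front is what enables the single multi-node fetch_hessian call before the parameter-search loop, instead of one Hessian request per node inside it.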
@@ -15,9 +15,10 @@

import numpy as np
import tensorflow as tf
from tqdm import tqdm
from typing import List

from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS, MIN_HESSIAN_ITER, HESSIAN_COMP_TOLERANCE, HESSIAN_EPS
from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS, MIN_HESSIAN_ITER, HESSIAN_COMP_TOLERANCE
from model_compression_toolkit.core.common import Graph
from model_compression_toolkit.core.common.hessian import TraceHessianRequest, HessianInfoGranularity
from model_compression_toolkit.core.keras.back2framework.float_model_builder import FloatKerasModelBuilder
@@ -47,11 +48,6 @@ def __init__(self,
num_iterations_for_approximation: Number of iterations to use when approximating the Hessian trace.
"""

if len(trace_hessian_request.target_nodes) > 1: # pragma: no cover
Logger.critical(f"Weights Hessian approximation is currently supported only for a single target node,"
f" but the provided request contains the following target nodes: "
f"{trace_hessian_request.target_nodes}.")

super(WeightsTraceHessianCalculatorKeras, self).__init__(graph=graph,
input_images=input_images,
fw_impl=fw_impl,
@@ -73,35 +69,12 @@ def compute(self) -> List[np.ndarray]:
The function returns a list for compatibility reasons.

"""
# Check if the target node's layer type is supported.
# We assume that weights Hessian computation is done only for a single node at each request.
target_node = self.hessian_request.target_nodes[0]
if not DEFAULT_KERAS_INFO.is_kernel_op(target_node.type):
Logger.critical(f"Hessian information with respect to weights is not supported for "
f"{target_node.type} layers.") # pragma: no cover

# Construct the Keras float model for inference
model, _ = FloatKerasModelBuilder(graph=self.graph).build_model()

# Get the weight attributes for the target node type
weight_attributes = DEFAULT_KERAS_INFO.get_kernel_op_attributes(target_node.type)

# Get the weight tensor for the target node
if len(weight_attributes) != 1: # pragma: no cover
Logger.critical(f"Hessian-based scoring with respect to weights is currently supported only for nodes with "
f"a single weight attribute. Found {len(weight_attributes)} attributes.")

weight_tensor = getattr(model.get_layer(target_node.name), weight_attributes[0])

# Get the output channel index (needed for HessianInfoGranularity.PER_OUTPUT_CHANNEL case)
output_channel_axis, _ = DEFAULT_KERAS_INFO.kernel_channels_mapping.get(target_node.type)

# Get number of scores that should be calculated by the granularity.
num_of_scores = self._get_num_scores_by_granularity(weight_tensor,
output_channel_axis)

# Initiate a gradient tape for automatic differentiation
with tf.GradientTape(persistent=True) as tape:
        with tf.GradientTape(persistent=True) as tape:
# Perform a forward pass (inference) to get the output, while watching
# the input tensor for gradient computation
tape.watch(self.input_images)
@@ -110,55 +83,97 @@ def compute(self) -> List[np.ndarray]:
# Combine outputs if the model returns multiple output tensors
output = self._concat_tensors(outputs)

approximation_per_iteration = []
for j in range(self.num_iterations_for_approximation): # Approximation iterations
ipts_hessian_trace_approx = [tf.Variable([0.0], dtype=tf.float32, trainable=True)
for _ in range(len(self.hessian_request.target_nodes))]

prev_mean_results = None
tensors_original_shape = []
for j in tqdm(range(self.num_iterations_for_approximation)): # Approximation iterations
# Getting a random vector with normal distribution and the same shape as the model output
v = tf.random.normal(shape=output.shape)
f_v = tf.reduce_sum(v * output)

            for i, ipt_node in enumerate(self.hessian_request.target_nodes):  # iterate per interest-point weights tensor

# Check if the target node's layer type is supported.
if not DEFAULT_KERAS_INFO.is_kernel_op(ipt_node.type):
Logger.critical(f"Hessian information with respect to weights is not supported for "
f"{ipt_node.type} layers.") # pragma: no cover

# Get the weight attributes for the target node type
weight_attributes = DEFAULT_KERAS_INFO.get_kernel_op_attributes(ipt_node.type)

# Get the weight tensor for the target node
if len(weight_attributes) != 1: # pragma: no cover
Logger.critical(
f"Hessian-based scoring with respect to weights is currently supported only for nodes with "
f"a single weight attribute. Found {len(weight_attributes)} attributes.")

weight_tensor = getattr(model.get_layer(ipt_node.name), weight_attributes[0])

if j == 0:
# On the first iteration we store the weight_tensor shape for later reshaping the results
# back if necessary
tensors_original_shape.append(weight_tensor.shape)

# Get the output channel index (needed for HessianInfoGranularity.PER_OUTPUT_CHANNEL case)
output_channel_axis, _ = DEFAULT_KERAS_INFO.kernel_channels_mapping.get(ipt_node.type)

# Get number of scores that should be calculated by the granularity.
num_of_scores = self._get_num_scores_by_granularity(weight_tensor,
output_channel_axis)

# Stop recording operations for automatic differentiation
with tape.stop_recording():
# Compute gradients of f_v with respect to the weights
gradients = tape.gradient(f_v, weight_tensor)
gradients = self._reshape_gradients(gradients,
output_channel_axis,
num_of_scores)

approx = tf.reduce_sum(tf.pow(gradients, 2.0), axis=1)

# Update node Hessian approximation mean over random iterations
ipts_hessian_trace_approx[i] = (j * ipts_hessian_trace_approx[i] + approx) / (j + 1)

# Free gradients
del gradients

# If the change to the mean approximation is insignificant (to all outputs)
# we stop the calculation.
with tape.stop_recording():
# Compute gradients of f_v with respect to the weights
gradients = tape.gradient(f_v, weight_tensor)
gradients = self._reshape_gradients(gradients,
output_channel_axis,
num_of_scores)
approx = tf.reduce_sum(tf.pow(gradients, 2.0), axis=1)

# Free gradients
del gradients

# If the change to the mean approximation is insignificant (to all outputs)
# we stop the calculation.
if j > MIN_HESSIAN_ITER:
# Compute new means and deltas
new_mean = tf.reduce_mean(tf.stack(approximation_per_iteration + approx), axis=0)
delta = new_mean - tf.reduce_mean(tf.stack(approximation_per_iteration), axis=0)
is_converged = np.all(np.abs(delta) / (np.abs(new_mean) + HESSIAN_EPS) < HESSIAN_COMP_TOLERANCE)
if is_converged:
approximation_per_iteration.append(approx)
break

approximation_per_iteration.append(approx)
if prev_mean_results is not None:
new_mean_res = \
tf.convert_to_tensor([tf.reduce_mean(res) for res in ipts_hessian_trace_approx])
relative_delta_per_node = (tf.abs(new_mean_res - prev_mean_results) /
(tf.abs(new_mean_res) + 1e-6))
max_delta = tf.reduce_max(relative_delta_per_node)
if max_delta < HESSIAN_COMP_TOLERANCE:
break

# Compute the mean of the approximations
final_approx = tf.reduce_mean(tf.stack(approximation_per_iteration), axis=0)
prev_mean_results = tf.convert_to_tensor([tf.reduce_mean(res) for res in ipts_hessian_trace_approx])

# Free gradient tape
del tape

if self.hessian_request.granularity == HessianInfoGranularity.PER_TENSOR:
if final_approx.shape != (1,): # pragma: no cover
Logger.critical(f"For HessianInfoGranularity.PER_TENSOR, the expected score shape is (1,), but found {final_approx.shape}.")
for final_approx in ipts_hessian_trace_approx:
if final_approx.shape != (1,): # pragma: no cover
Logger.critical(f"For HessianInfoGranularity.PER_TENSOR, the expected score shape is (1,), "
f"but found {final_approx.shape}.")
elif self.hessian_request.granularity == HessianInfoGranularity.PER_ELEMENT:
# Reshaping the scores to the original weight shape
final_approx = tf.reshape(final_approx, weight_tensor.shape)
ipts_hessian_trace_approx = \
[tf.reshape(final_approx, s) for final_approx, s in
zip(ipts_hessian_trace_approx, tensors_original_shape)]

# Add a batch axis to the Hessian approximation tensor (to align with the expected returned shape)
# We assume per-image computation, so the batch axis size is 1.
final_approx = final_approx[np.newaxis, ...]
final_approx = [r_final_approx[np.newaxis, ...].numpy()
for r_final_approx in ipts_hessian_trace_approx]

return [final_approx.numpy()]
return final_approx

def _reshape_gradients(self,
gradients: tf.Tensor,
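For reference, the estimator the Keras calculator implements, a Hutchinson-style approximation of per-weight Hessian trace scores with an early-stopping test on the running mean, can be sketched end to end on a toy model. The model, constants, and single-weight scope are illustrative; the real calculator iterates over all requested target nodes and granularities:

```python
import tensorflow as tf

# Toy stand-in for the float Keras model built from the graph.
model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])
x = tf.random.normal((1, 8))
w = model.layers[0].kernel
MAX_ITERS, MIN_ITERS, TOL = 100, 10, 1e-3  # illustrative, not MCT's constants

approx = tf.zeros_like(tf.reshape(w, (-1,)))  # running mean of squared grads, per element
prev_mean = None
with tf.GradientTape(persistent=True) as tape:
    y = model(x)
    for j in range(MAX_ITERS):
        v = tf.random.normal(shape=y.shape)  # random probe vector
        f_v = tf.reduce_sum(v * y)           # v^T y, recorded on the tape
        with tape.stop_recording():
            g = tf.reshape(tape.gradient(f_v, w), (-1,))
            approx = (j * approx + g ** 2) / (j + 1)  # update the running mean
            mean = tf.reduce_mean(approx)
            # Stop once the relative change of the mean score is small enough.
            if j > MIN_ITERS and prev_mean is not None and \
                    tf.abs(mean - prev_mean) / (tf.abs(mean) + 1e-6) < TOL:
                break
            prev_mean = mean
del tape  # free the persistent tape, as the PR does

print(approx.shape)  # (32,): one score per element of the 8x4 kernel
```

The per-node early stopping in the PR applies the same relative-delta test, but takes the maximum delta across all target nodes before deciding to break.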