Skip to content

Commit

Permalink
Add comments to LFH importance matric
Browse files Browse the repository at this point in the history
  • Loading branch information
reuvenp committed Dec 3, 2023
1 parent a5a636e commit 521910a
Showing 1 changed file with 71 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,95 +2,127 @@

from model_compression_toolkit.core.common.framework_info import FrameworkInfo
from model_compression_toolkit.core.common import Graph, BaseNode
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
from model_compression_toolkit.core.common.hessian import HessianInfoService, HessianMode, HessianInfoGranularity
from model_compression_toolkit.core.common.pruning.importance_metrics.base_importance_metric import BaseImportanceMetric
import numpy as np

from model_compression_toolkit.core.common.pruning.pruning_config import PruningConfig
from model_compression_toolkit.core.common.pruning.pruning_framework_implementation import PruningFrameworkImplementation


class LFHImportanceMetric(BaseImportanceMetric):

def __init__(self,
graph:Graph,
graph: Graph,
representative_data_gen: Callable,
fw_impl: FrameworkImplementation,
fw_impl: PruningFrameworkImplementation,
pruning_config: PruningConfig,
fw_info: FrameworkInfo):

self.float_graph= graph
"""
Initializes the LFHImportanceMetric class for calculating Label-Free-Hessian
based importance scores of nodes in a graph.
Args:
graph (Graph): The computational graph of the neural network.
representative_data_gen (Callable): A generator function to produce representative data for the network.
fw_impl (PruningFrameworkImplementation): The specific framework implementation.
pruning_config (PruningConfig): Configuration settings for the pruning process.
fw_info (FrameworkInfo): Framework-specific information and utilities.
"""
self.float_graph = graph
self.representative_data_gen = representative_data_gen
self.fw_impl = fw_impl
self.pruning_config = pruning_config
self.fw_info = fw_info

def get_entry_node_to_score(self, sections_input_nodes:List[BaseNode]):
# Initialize services and variables for pruning process.
def get_entry_node_to_score(self, sections_input_nodes: List[BaseNode]):
"""
Calculates importance scores for each entry node in the provided list using the LFH method.
Args:
sections_input_nodes (List[BaseNode]): List of entry nodes for which to calculate importance scores.
Returns:
Dict[BaseNode, np.ndarray]: A dictionary mapping each entry node to its importance scores.
"""
# Initialize Hessian information service to calculate LFH scores
hessian_info_service = HessianInfoService(graph=self.float_graph,
representative_dataset=self.representative_data_gen,
fw_impl=self.fw_impl)

# Calculate the LFH (Label-Free Hessian) score for each prunable channel.
scores_per_prunable_node = hessian_info_service.fetch_scores_for_multiple_nodes(
mode=HessianMode.WEIGHTS,
granularity=HessianInfoGranularity.PER_OUTPUT_CHANNEL,
nodes=sections_input_nodes,
required_size=self.pruning_config.num_score_approximations)
# Fetch scores for multiple nodes in the graph
scores_per_prunable_node = hessian_info_service.fetch_scores_for_multiple_nodes(mode=HessianMode.WEIGHTS,
granularity=HessianInfoGranularity.PER_OUTPUT_CHANNEL,
nodes=sections_input_nodes,
required_size=self.pruning_config.num_score_approximations)


# Average the scores across approximations and map them to the corresponding nodes.
# Average scores across approximations and map them to nodes
entry_node_to_score = {node: np.mean(scores, axis=0) for node, scores in
zip(sections_input_nodes, scores_per_prunable_node)}

# Normalize scores using L2 norms and number of parameters
l2_oc_norm = self.get_l2_out_channel_norm(entry_nodes=sections_input_nodes)
count_oc_nparams = self.count_oc_nparams(entry_nodes=sections_input_nodes)
entry_node_to_score = self.normalize_lfh_scores(entry_node_to_score=entry_node_to_score,
entry_node_to_l2norm=l2_oc_norm,
entry_node_to_nparmas=count_oc_nparams)
entry_node_to_score = self.normalize_lfh_scores(entry_node_to_score, l2_oc_norm, count_oc_nparams)
return entry_node_to_score

def normalize_lfh_scores(self, entry_node_to_score, entry_node_to_l2norm, entry_node_to_nparmas):
"""
Normalizes the LFH scores for each node.
Args:
entry_node_to_score (Dict[BaseNode, np.ndarray]): Dictionary of nodes and their LFH scores.
entry_node_to_l2norm (Dict[BaseNode, np.ndarray]): Dictionary of nodes and their L2 norms.
entry_node_to_nparmas (Dict[BaseNode, np.ndarray]): Dictionary of nodes and their parameter counts.
def normalize_lfh_scores(self,
entry_node_to_score,
entry_node_to_l2norm,
entry_node_to_nparmas):
Returns:
Dict[BaseNode, np.ndarray]: Normalized scores for each node.
"""
new_scores = {}
# Normalize scores by multiplying with L2 norm and dividing by number of parameters
for node, trace_vector in entry_node_to_score.items():
new_scores[node] = trace_vector*entry_node_to_l2norm[node]/entry_node_to_nparmas[node]
new_scores[node] = trace_vector * entry_node_to_l2norm[node] / entry_node_to_nparmas[node]
return new_scores

def count_oc_nparams(self, entry_nodes: List[BaseNode]):
"""
Counts the number of parameters per output channel for each entry node.
Args:
entry_nodes (List[BaseNode]): List of entry nodes to count parameters for.
Returns:
Dict[BaseNode, np.ndarray]: Dictionary of nodes and their parameters count per output channel.
"""
node_channel_params = {}
for entry_node in entry_nodes:
kernel = entry_node.get_weights_by_keys('kernel')
ox_axis = self.fw_info.kernel_channels_mapping.get(entry_node.type)[0]

# Calculate the number of parameters for each output channel
params_per_channel = np.prod(kernel.shape) / kernel.shape[ox_axis]
oc_axis = self.fw_info.kernel_channels_mapping.get(entry_node.type)[0]

# Create an array with the number of parameters per channel
num_params_array = np.full(kernel.shape[ox_axis], params_per_channel)
# Calculate parameters per channel
params_per_channel = np.prod(kernel.shape) / kernel.shape[oc_axis]
num_params_array = np.full(kernel.shape[oc_axis], params_per_channel)

# Store in node_channel_params a dictionary from node to a np.array where
# each element corresponds to the number of parameters of this channel
node_channel_params[entry_node] = num_params_array

return node_channel_params


def get_l2_out_channel_norm(self, entry_nodes: List[BaseNode]):
"""
Computes the L2 norm for each output channel of the entry nodes.
Args:
entry_nodes (List[BaseNode]): List of entry nodes for L2 norm computation.
Returns:
Dict[BaseNode, np.ndarray]: Dictionary of nodes and their L2 norms for each output channel.
"""
node_l2_channel_norm = {}
for entry_node in entry_nodes:
kernel = entry_node.get_weights_by_keys('kernel')
ox_axis = self.fw_info.kernel_channels_mapping.get(entry_node.type)[0]

# Compute the l2 norm of each output channel
# Compute L2 norm for each channel
channels = np.split(kernel, indices_or_sections=kernel.shape[ox_axis], axis=ox_axis)
l2_norms = [np.linalg.norm(c.flatten(), ord=2) ** 2 for c in channels]

# Store in node_l2_channel_norm a dictionary from node to a np.array where
# each element corresponds to the l2 norm of this channel
node_l2_channel_norm[entry_node] = l2_norms

return node_l2_channel_norm
return node_l2_channel_norm

0 comments on commit 521910a

Please sign in to comment.