diff --git a/model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py b/model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py
index 9963f3b5d..fb5bacf96 100644
--- a/model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py
+++ b/model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py
@@ -2,95 +2,127 @@
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.common import Graph, BaseNode
-from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
 from model_compression_toolkit.core.common.hessian import HessianInfoService, HessianMode, HessianInfoGranularity
 from model_compression_toolkit.core.common.pruning.importance_metrics.base_importance_metric import BaseImportanceMetric
 import numpy as np
 from model_compression_toolkit.core.common.pruning.pruning_config import PruningConfig
+from model_compression_toolkit.core.common.pruning.pruning_framework_implementation import PruningFrameworkImplementation
 
 
 class LFHImportanceMetric(BaseImportanceMetric):
-
     def __init__(self,
-                 graph:Graph,
+                 graph: Graph,
                  representative_data_gen: Callable,
-                 fw_impl: FrameworkImplementation,
+                 fw_impl: PruningFrameworkImplementation,
                  pruning_config: PruningConfig,
                  fw_info: FrameworkInfo):
-
-        self.float_graph= graph
+        """
+        Initializes the LFHImportanceMetric class for computing Label-Free Hessian (LFH)
+        based importance scores for nodes in a graph.
+
+        Args:
+            graph (Graph): The computational graph of the neural network.
+            representative_data_gen (Callable): A generator function that produces representative data for the network.
+            fw_impl (PruningFrameworkImplementation): The framework-specific pruning implementation.
+            pruning_config (PruningConfig): Configuration settings for the pruning process.
+            fw_info (FrameworkInfo): Framework-specific information and utilities.
+        """
+        self.float_graph = graph
         self.representative_data_gen = representative_data_gen
         self.fw_impl = fw_impl
         self.pruning_config = pruning_config
         self.fw_info = fw_info
 
-    def get_entry_node_to_score(self, sections_input_nodes:List[BaseNode]):
-        # Initialize services and variables for pruning process.
+    def get_entry_node_to_score(self, sections_input_nodes: List[BaseNode]):
+        """
+        Computes importance scores for each entry node in the provided list using the LFH method.
+
+        Args:
+            sections_input_nodes (List[BaseNode]): List of entry nodes for which to compute importance scores.
+
+        Returns:
+            Dict[BaseNode, np.ndarray]: A dictionary mapping each entry node to its importance scores.
+        """
+        # Initialize the Hessian information service used to compute LFH scores
         hessian_info_service = HessianInfoService(graph=self.float_graph,
                                                   representative_dataset=self.representative_data_gen,
                                                   fw_impl=self.fw_impl)
 
-        # Calculate the LFH (Label-Free Hessian) score for each prunable channel.
-        scores_per_prunable_node = hessian_info_service.fetch_scores_for_multiple_nodes(
-            mode=HessianMode.WEIGHTS,
-            granularity=HessianInfoGranularity.PER_OUTPUT_CHANNEL,
-            nodes=sections_input_nodes,
-            required_size=self.pruning_config.num_score_approximations)
+        # Fetch approximated Hessian-based scores for the requested nodes
+        scores_per_prunable_node = hessian_info_service.fetch_scores_for_multiple_nodes(mode=HessianMode.WEIGHTS,
+                                                                                        granularity=HessianInfoGranularity.PER_OUTPUT_CHANNEL,
+                                                                                        nodes=sections_input_nodes,
+                                                                                        required_size=self.pruning_config.num_score_approximations)
 
-
-        # Average the scores across approximations and map them to the corresponding nodes.
+        # Average scores across approximations and map them to the corresponding nodes
         entry_node_to_score = {node: np.mean(scores, axis=0) for node, scores in
                                zip(sections_input_nodes, scores_per_prunable_node)}
 
+        # Normalize scores using squared L2 norms and per-channel parameter counts
         l2_oc_norm = self.get_l2_out_channel_norm(entry_nodes=sections_input_nodes)
         count_oc_nparams = self.count_oc_nparams(entry_nodes=sections_input_nodes)
-        entry_node_to_score = self.normalize_lfh_scores(entry_node_to_score=entry_node_to_score,
-                                                        entry_node_to_l2norm=l2_oc_norm,
-                                                        entry_node_to_nparmas=count_oc_nparams)
+        entry_node_to_score = self.normalize_lfh_scores(entry_node_to_score, l2_oc_norm, count_oc_nparams)
         return entry_node_to_score
 
-    def normalize_lfh_scores(self,
-                             entry_node_to_score,
-                             entry_node_to_l2norm,
-                             entry_node_to_nparmas):
+    def normalize_lfh_scores(self, entry_node_to_score, entry_node_to_l2norm, entry_node_to_nparams):
+        """
+        Normalizes the LFH scores for each node.
+
+        Args:
+            entry_node_to_score (Dict[BaseNode, np.ndarray]): Dictionary of nodes and their LFH scores.
+            entry_node_to_l2norm (Dict[BaseNode, np.ndarray]): Dictionary of nodes and their squared L2 norms per output channel.
+            entry_node_to_nparams (Dict[BaseNode, np.ndarray]): Dictionary of nodes and their parameter counts per output channel.
+
+        Returns:
+            Dict[BaseNode, np.ndarray]: Normalized scores for each node.
+        """
         new_scores = {}
+        # Multiply each score by its channel's squared L2 norm and divide by the channel's parameter count
         for node, trace_vector in entry_node_to_score.items():
-            new_scores[node] = trace_vector*entry_node_to_l2norm[node]/entry_node_to_nparmas[node]
+            new_scores[node] = trace_vector * entry_node_to_l2norm[node] / entry_node_to_nparams[node]
         return new_scores
 
     def count_oc_nparams(self, entry_nodes: List[BaseNode]):
+        """
+        Counts the number of parameters per output channel for each entry node.
+
+        Args:
+            entry_nodes (List[BaseNode]): List of entry nodes to count parameters for.
+
+        Returns:
+            Dict[BaseNode, np.ndarray]: Dictionary of nodes and their parameter counts per output channel.
+ """ node_channel_params = {} for entry_node in entry_nodes: kernel = entry_node.get_weights_by_keys('kernel') - ox_axis = self.fw_info.kernel_channels_mapping.get(entry_node.type)[0] - - # Calculate the number of parameters for each output channel - params_per_channel = np.prod(kernel.shape) / kernel.shape[ox_axis] + oc_axis = self.fw_info.kernel_channels_mapping.get(entry_node.type)[0] - # Create an array with the number of parameters per channel - num_params_array = np.full(kernel.shape[ox_axis], params_per_channel) + # Calculate parameters per channel + params_per_channel = np.prod(kernel.shape) / kernel.shape[oc_axis] + num_params_array = np.full(kernel.shape[oc_axis], params_per_channel) - # Store in node_channel_params a dictionary from node to a np.array where - # each element corresponds to the number of parameters of this channel node_channel_params[entry_node] = num_params_array - return node_channel_params - def get_l2_out_channel_norm(self, entry_nodes: List[BaseNode]): + """ + Computes the L2 norm for each output channel of the entry nodes. + + Args: + entry_nodes (List[BaseNode]): List of entry nodes for L2 norm computation. + + Returns: + Dict[BaseNode, np.ndarray]: Dictionary of nodes and their L2 norms for each output channel. + """ node_l2_channel_norm = {} for entry_node in entry_nodes: kernel = entry_node.get_weights_by_keys('kernel') ox_axis = self.fw_info.kernel_channels_mapping.get(entry_node.type)[0] - # Compute the l2 norm of each output channel + # Compute L2 norm for each channel channels = np.split(kernel, indices_or_sections=kernel.shape[ox_axis], axis=ox_axis) l2_norms = [np.linalg.norm(c.flatten(), ord=2) ** 2 for c in channels] - # Store in node_l2_channel_norm a dictionary from node to a np.array where - # each element corresponds to the l2 norm of this channel node_l2_channel_norm[entry_node] = l2_norms - - return node_l2_channel_norm \ No newline at end of file + return node_l2_channel_norm