diff --git a/model_compression_toolkit/core/common/pruning/memory_calculator.py b/model_compression_toolkit/core/common/pruning/memory_calculator.py index 103f2b177..f39a25bf2 100644 --- a/model_compression_toolkit/core/common/pruning/memory_calculator.py +++ b/model_compression_toolkit/core/common/pruning/memory_calculator.py @@ -73,11 +73,11 @@ def get_nparams_of_shared_nodes(self, # Get the output mask for the node if it exists node_output_mask = masks.get(node) # Calculate the number of remaining parameters for the shared node after pruning - nparams += self.fw_impl.get_pruned_node_num_params(node, - node_input_mask, - node_output_mask, - self.fw_info, - include_padded_channels) + nparams += self.get_pruned_node_num_params(node, + node_input_mask, + node_output_mask, + self.fw_info, + include_padded_channels) return nparams def get_nparams_of_pruning_sections(self, masks, pruning_sections, include_padded_channels:bool): @@ -179,7 +179,7 @@ def _get_pruning_section_num_params(self, include_padded_channels: bool) -> int: # Number of params for the first node in the section. - first_node_nparams = self.fw_impl.get_pruned_node_num_params( + first_node_nparams = self.get_pruned_node_num_params( pruning_section.entry_node, pruning_section_mask.entry_node_ic_mask, pruning_section_mask.entry_node_oc_mask, @@ -188,7 +188,7 @@ def _get_pruning_section_num_params(self, # Sum number of params for all intermediate nodes in the section. total_inter_nodes_nparams = sum( - self.fw_impl.get_pruned_node_num_params( + self.get_pruned_node_num_params( inter_node, pruning_section_mask.entry_node_oc_mask, pruning_section_mask.entry_node_oc_mask, @@ -196,7 +196,7 @@ def _get_pruning_section_num_params(self, include_padded_channels) for inter_node in pruning_section.intermediate_nodes) # Number of params for the last node in the section. 
- second_node_nparams = self.fw_impl.get_pruned_node_num_params( + second_node_nparams = self.get_pruned_node_num_params( pruning_section.exit_node, pruning_section_mask.exit_node_ic_mask, pruning_section_mask.exit_node_oc_mask, @@ -204,3 +204,63 @@ def _get_pruning_section_num_params(self, include_padded_channels) return first_node_nparams + total_inter_nodes_nparams + second_node_nparams + + + def get_pruned_node_num_params(self, + node: BaseNode, + input_mask: np.ndarray, + output_mask: np.ndarray, + fw_info: FrameworkInfo, + include_padded_channels: bool): + """ + Calculates the number of parameters in a pruned node. + + Args: + node: The node whose parameters are to be counted. + input_mask: Mask to be applied to the input channels. + output_mask: Mask to be applied to the output channels. + fw_info: Framework-specific information object. + include_padded_channels: Boolean flag to include or exclude null channels in the count. + + Returns: + Integer representing the number of parameters in the pruned node. + """ + + total_params = 0 + if fw_info.is_kernel_op(node.type): + # Obtain axes info for kernel operations. + oc_axis, ic_axis = fw_info.kernel_channels_mapping.get(node.type) + kernel_attr = fw_info.get_kernel_op_attributes(node.type)[0] + for w_attr, w in node.weights.items(): + # Check if the weight attribute is the kernel. + if kernel_attr in w_attr: + # Handle input and output masks, ensuring they are boolean arrays. + input_mask = np.ones(w.shape[ic_axis], dtype=bool) if input_mask is None else input_mask.astype(bool) + output_mask = np.ones(w.shape[oc_axis], dtype=bool) if output_mask is None else output_mask.astype(bool) + + # Assert the input and output masks match the kernel dimensions. 
+ assert w.shape[ic_axis] == len(input_mask), (f"Kernel num of input channels: {w.shape[ic_axis]}, but mask len is {len(input_mask)} for node {node}") + assert w.shape[oc_axis] == len(output_mask), (f"Kernel num of output channels: {w.shape[oc_axis]}, but mask len is {len(output_mask)} for node {node}") + + # Apply masks to the kernel and calculate the remaining parameters. + pruned_w = np.take(w, np.where(input_mask)[0], axis=ic_axis) + pruned_w = np.take(pruned_w, np.where(output_mask)[0], axis=oc_axis) + total_params += len(pruned_w.flatten()) + else: + # For non-kernel weights, apply the output mask only. + total_params += len(np.take(w, np.where(output_mask)[0])) + + else: + # For non-kernel operations, apply the output mask to the last axis. + # This part assumes that for non-kernel ops, all weights output channel axis is -1. + for w_attr, w in node.weights.items(): + pruned_w = np.take(w, np.where(output_mask)[0], axis=-1) # TODO: get axis from fw-specific function + total_params += pruned_w.size + + if include_padded_channels: # TODO: remove duplicate + node_simd = node.get_simd() + nparams_per_oc = total_params / np.sum(output_mask) + num_oc_with_null_channels = np.ceil(np.sum(output_mask) / node_simd) * node_simd + total_params = num_oc_with_null_channels * nparams_per_oc + + return total_params \ No newline at end of file diff --git a/model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py b/model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py index 5a030a136..8159243e0 100644 --- a/model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +++ b/model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py @@ -123,28 +123,3 @@ def is_node_intermediate_pruning_section(self, raise NotImplemented(f'{self.__class__.__name__} have to implement the ' f'framework\'s is_node_intermediate_pruning_section method.') # pragma: no cover - @abstractmethod - def 
get_pruned_node_num_params(self, - node: BaseNode, - input_mask: np.ndarray, - output_mask: np.ndarray, - fw_info: FrameworkInfo, - include_padded_channels: bool): - """ - Abstract method to get the number of parameters of a pruned node. - - Args: - node: The node whose parameters are to be counted. - input_mask: Mask to be applied to the input channels. - output_mask: Mask to be applied to the output channels. - fw_info: Framework-specific information. - include_padded_channels: Boolean flag to include or exclude padded channels in the count. - - Returns: - int: Number of parameters after pruning. - - Raises: - NotImplemented: If the method is not implemented in the subclass. - """ - raise NotImplemented(f'{self.__class__.__name__} have to implement the ' - f'framework\'s get_pruned_node_num_params method.') # pragma: no cover diff --git a/model_compression_toolkit/core/keras/pruning/count_node_params.py b/model_compression_toolkit/core/keras/pruning/count_node_params.py index faa014d57..15d5510a3 100644 --- a/model_compression_toolkit/core/keras/pruning/count_node_params.py +++ b/model_compression_toolkit/core/keras/pruning/count_node_params.py @@ -1,65 +1,9 @@ -import copy - -import keras.layers -import numpy as np - -from model_compression_toolkit.core.common.framework_info import FrameworkInfo -from model_compression_toolkit.core.common import BaseNode - - -# Get the number of parameters for a pruned Keras node. -def get_keras_pruned_node_num_params(node: BaseNode, - input_mask: np.ndarray, - output_mask: np.ndarray, - fw_info: FrameworkInfo, - include_padded_channels: bool): # TODO: move to common - - total_params = 0 - if fw_info.is_kernel_op(node.type): - # Obtain axes info for kernel operations. - oc_axis, ic_axis = fw_info.kernel_channels_mapping.get(node.type) - kernel_attr = fw_info.get_kernel_op_attributes(node.type)[0] - for w_attr, w in node.weights.items(): - # Check if the weight attribute is the kernel. 
- if kernel_attr in w_attr: - # Handle input and output masks, ensuring they are boolean arrays. - input_mask = np.ones(w.shape[ic_axis], dtype=bool) if input_mask is None else input_mask.astype(bool) - output_mask = np.ones(w.shape[oc_axis], dtype=bool) if output_mask is None else output_mask.astype(bool) - - # # Special handling for Dense layers to align input mask with kernel shape. - # if node.type == keras.layers.Dense: - # if w.shape[ic_axis] != len(input_mask): - # num_ic_per_prev_oc_channel = w.shape[ic_axis] / len(input_mask) - # assert int(num_ic_per_prev_oc_channel) == num_ic_per_prev_oc_channel - # input_mask = np.repeat(input_mask, int(num_ic_per_prev_oc_channel)) - - # Assert the input and output masks match the kernel dimensions. - assert w.shape[ic_axis] == len(input_mask), ( - f"Kernel num of input channels: {w.shape[ic_axis]}, but mask len is {len(input_mask)} for node " - f"{node}") - assert w.shape[oc_axis] == len( - output_mask), (f"Kernel num of output channels: {w.shape[oc_axis]}, but mask len is " - f"{len(output_mask)} for node {node}") - - # Apply masks to the kernel and calculate the remaining parameters. - pruned_w = np.take(w, np.where(input_mask)[0], axis=ic_axis) - pruned_w = np.take(pruned_w, np.where(output_mask)[0], axis=oc_axis) - total_params += len(pruned_w.flatten()) - else: - # For non-kernel weights, apply the output mask only. - total_params += len(np.take(w, np.where(output_mask)[0])) - - else: - # For non-kernel operations, apply the output mask to the last axis. - # This part assumes that for non-kernel ops, all weights output channel axis is -1. 
- for w_attr, w in node.weights.items(): - pruned_w = np.take(w, np.where(output_mask)[0], axis=-1) # TODO: get axis from fw-specific function - total_params += pruned_w.size - - if include_padded_channels: # TODO: remove duplicate - node_simd = node.get_simd() - nparams_per_oc = total_params / np.sum(output_mask) - num_oc_with_null_channels = np.ceil(np.sum(output_mask) / node_simd) * node_simd - total_params = num_oc_with_null_channels * nparams_per_oc - - return total_params +# import copy +# +# import keras.layers +# import numpy as np +# +# from model_compression_toolkit.core.common.framework_info import FrameworkInfo +# from model_compression_toolkit.core.common import BaseNode +# +# diff --git a/model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py b/model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py index fe5e8166d..56557da89 100644 --- a/model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +++ b/model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py @@ -5,7 +5,6 @@ from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation from model_compression_toolkit.core.keras.pruning.check_node_role import is_keras_node_intermediate_pruning_section, \ is_keras_entry_node, is_keras_exit_node -from model_compression_toolkit.core.keras.pruning.count_node_params import get_keras_pruned_node_num_params from model_compression_toolkit.core.keras.pruning.prune_keras_node import (prune_keras_exit_node, prune_keras_entry_node, \ prune_keras_intermediate_node) @@ -98,24 +97,3 @@ def is_node_intermediate_pruning_section(self, node): """ return is_keras_node_intermediate_pruning_section(node) - def get_pruned_node_num_params(self, - node: BaseNode, - input_mask: np.ndarray, - output_mask: np.ndarray, - fw_info: FrameworkInfo, - include_padded_channels: bool): - """ - Calculates the number of parameters in a pruned node of a Keras model. 
- - Args: - node: The node whose parameters are to be counted. - input_mask: Mask to be applied to the input channels. - output_mask: Mask to be applied to the output channels. - fw_info: Framework-specific information object. - include_padded_channels: Boolean flag to include or exclude null channels in the count. - - Returns: - Integer representing the number of parameters in the pruned node. - """ - return get_keras_pruned_node_num_params(node, input_mask, output_mask, fw_info, include_padded_channels) -