Commit

Refactor pruned node params count

reuvenp committed Dec 3, 2023
1 parent e64047e commit 8dcbf85
Showing 8 changed files with 192 additions and 131 deletions.
@@ -40,7 +40,7 @@ def __init__(self,
self.tpc = tpc

# Initialize the SIMD group indices and scores dictionaries.
self.simd_groups_indices = {}
self.simd_groups_indices = {} # TODO: Take SIMD grouping out of mask calculator
self.simd_groups_scores = {}
self.mask_simd = None # Will hold SIMD group mask per node.
self.mask = None # Will hold the final mask to be applied to the nodes.
217 changes: 127 additions & 90 deletions model_compression_toolkit/core/common/pruning/memory_calculator.py
@@ -6,14 +6,11 @@
from model_compression_toolkit.core.common.pruning.pruning_framework_implementation import \
PruningFrameworkImplementation
from model_compression_toolkit.core.common.pruning.pruning_section import PruningSection, PruningSectionMask
from model_compression_toolkit.logger import Logger


class MemoryCalculator:
"""
MemoryCalculator is a class that computes the memory usage of a pruned graph.
It takes into account the pruning masks applied to each node and computes the memory
accordingly.
"""


def __init__(self,
graph: Graph,
@@ -34,13 +31,31 @@ def __init__(self,
def get_pruned_graph_memory(self,
masks: Dict[BaseNode, np.ndarray],
include_padded_channels: bool) -> float:
"""
Args:
masks:
include_padded_channels:
Returns:
"""
nparams = self.get_pruned_graph_num_params(masks=masks,
include_padded_channels=include_padded_channels)
return nparams * 4.

def get_pruned_graph_num_params(self,
masks: Dict[BaseNode, np.ndarray],
include_padded_channels: bool) -> int:
"""
Args:
masks:
include_padded_channels:
Returns:
"""

# Total number of parameters after pruning
total_nparams = 0
@@ -63,6 +78,16 @@ def get_nparams_of_shared_nodes(self,
masks: Dict[BaseNode, np.ndarray],
pruning_sections: List[PruningSection],
include_padded_channels) -> int:
"""
Args:
masks:
pruning_sections:
include_padded_channels:
Returns:
"""

nparams = 0
# Identify nodes that are at the end of one section and the start of another
@@ -76,16 +101,19 @@ def get_nparams_of_shared_nodes(self,
nparams += self.get_pruned_node_num_params(node,
node_input_mask,
node_output_mask,
self.fw_info,
include_padded_channels)
return nparams

def get_nparams_of_pruning_sections(self, masks, pruning_sections, include_padded_channels:bool):
def get_nparams_of_pruning_sections(self,
masks,
pruning_sections,
include_padded_channels: bool):
"""
Args:
masks:
pruning_sections:
include_padded_channels:
Returns:
@@ -100,17 +128,11 @@ def get_nparams_of_pruning_sections(self, masks, pruning_sections, include_padde
include_padded_channels)
return nparams

def get_section_mask_from_node_mask(self, masks, pruning_section, pruning_sections):
def get_section_mask_from_node_mask(self,
masks,
pruning_section,
pruning_sections):
"""
Builds a PruningSectionMask for a pruning section from the per-node output-channel masks.

Args:
masks: Dictionary mapping each prunable node to its output-channel pruning mask.
pruning_section: The pruning section for which the mask is built.
pruning_sections: List of all pruning sections, used to derive input-channel masks.

Returns:
A PruningSectionMask holding the input/output channel masks of the section's entry and exit nodes.
"""
# Determine masks for input channels of the first node and output channels of the second node.
first_node_input_channels_mask = self._get_node_input_mask(pruning_section.entry_node,
pruning_sections,
@@ -125,50 +147,49 @@ def get_section_mask_from_node_mask(self, masks, pruning_section, pruning_sectio

return pruning_section_mask

def get_nparams_of_nonpruned_nodes(self, pruning_sections, include_padded_channels:bool):
def get_nparams_of_nonpruned_nodes(self,
pruning_sections,
include_padded_channels: bool):
"""
Args:
pruning_sections:
include_padded_channels:
Returns:
"""

total_nparams = 0
# Collect all nodes to prune from the pruning sections.
nodes_to_prune = set([node for section in pruning_sections for node in section.get_all_nodes()])
# Calculate the num of params for non-prunable nodes.
for n in self.graph.nodes:
if n not in nodes_to_prune:
node_nparams = sum(n.get_num_parameters(self.fw_info))
if include_padded_channels:
num_oc = n.output_shape[-1]
nparams_per_oc = node_nparams/num_oc
num_oc_include_null_channels = np.ceil(num_oc/n.get_simd())*n.get_simd()
node_nparams = num_oc_include_null_channels*nparams_per_oc
# node_nparams = sum(n.get_num_parameters(self.fw_info))
node_nparams = self.get_pruned_node_num_params(node=n,
input_mask=None,
output_mask=None,
include_padded_channels=include_padded_channels)
# if include_padded_channels:
# node_nparams = self.get_node_nparams_with_padded_channels(node_nparams=node_nparams,
# num_oc=n.output_shape[-1],
# node_simd=n.get_simd())
total_nparams += node_nparams
return total_nparams

def _get_node_input_mask(self,
node: BaseNode, pruning_sections: List[PruningSection],
masks: Dict[BaseNode, np.ndarray]) -> np.ndarray:
"""

Args:
node:
pruning_sections:
masks:
Returns:
"""
for section in pruning_sections:
if node == section.exit_node:
return masks.get(section.entry_node)
return None

def _get_nodes_from_adjacent_sections(self, pruning_sections: List[PruningSection]) -> List[BaseNode]:
def _get_nodes_from_adjacent_sections(self,
pruning_sections: List[PruningSection]) -> List[BaseNode]:
"""
Finds nodes shared between adjacent pruning sections, i.e. nodes that are both the entry node
of one section and the exit node of another.

Args:
pruning_sections: List of pruning sections in the graph.

Returns:
List of nodes shared between adjacent pruning sections.
"""
input_nodes = set(section.entry_node for section in pruning_sections)
output_nodes = set(section.exit_node for section in pruning_sections)
return list(input_nodes.intersection(output_nodes))
@@ -183,7 +204,6 @@ def _get_pruning_section_num_params(self,
pruning_section.entry_node,
pruning_section_mask.entry_node_ic_mask,
pruning_section_mask.entry_node_oc_mask,
self.fw_info,
include_padded_channels)

# Sum number of params for all intermediate nodes in the section.
@@ -192,75 +212,92 @@ def _get_pruning_section_num_params(self,
inter_node,
pruning_section_mask.entry_node_oc_mask,
pruning_section_mask.entry_node_oc_mask,
self.fw_info,
include_padded_channels) for inter_node in pruning_section.intermediate_nodes)

# Number of params for the last node in the section.
second_node_nparams = self.get_pruned_node_num_params(
pruning_section.exit_node,
pruning_section_mask.exit_node_ic_mask,
pruning_section_mask.exit_node_oc_mask,
self.fw_info,
include_padded_channels)

return first_node_nparams + total_inter_nodes_nparams + second_node_nparams


def get_pruned_node_num_params(self,
node: BaseNode,
input_mask: np.ndarray,
output_mask: np.ndarray,
fw_info: FrameworkInfo,
include_padded_channels: bool):
node: BaseNode,
input_mask: np.ndarray,
output_mask: np.ndarray,
include_padded_channels: bool):
"""
Calculates the number of parameters in a pruned node of a Keras model.
Calculates the number of parameters in a pruned node of a model.
Args:
node: The node whose parameters are to be counted.
input_mask: Mask to be applied to the input channels.
output_mask: Mask to be applied to the output channels.
fw_info: Framework-specific information object.
include_padded_channels: Boolean flag to include or exclude null channels in the count.
include_padded_channels: Boolean flag to include or exclude padded channels (due to SIMD) in the count.
Returns:
Integer representing the number of parameters in the pruned node.
"""

def _prune(w, mask, axis):
mask = np.ones(w.shape[axis], dtype=bool) if mask is None else mask.astype(bool)
assert w.shape[axis] == len(mask), (
f"Kernel num of input channels: {w.shape[axis]}, but mask len is {len(mask)} for node {node}")
pruned_w = np.take(w, np.where(mask)[0], axis=axis)
return pruned_w

total_params = 0
if fw_info.is_kernel_op(node.type):
# Obtain axes info for kernel operations.
oc_axis, ic_axis = fw_info.kernel_channels_mapping.get(node.type)
kernel_attr = fw_info.get_kernel_op_attributes(node.type)[0]
for w_attr, w in node.weights.items():
# Check if the weight attribute is the kernel.
if kernel_attr in w_attr:
# Handle input and output masks, ensuring they are boolean arrays.
input_mask = np.ones(w.shape[ic_axis], dtype=bool) if input_mask is None else input_mask.astype(bool)
output_mask = np.ones(w.shape[oc_axis], dtype=bool) if output_mask is None else output_mask.astype(bool)

# Assert the input and output masks match the kernel dimensions.
assert w.shape[ic_axis] == len(input_mask), (f"Kernel num of input channels: {w.shape[ic_axis]}, but mask len is {len(input_mask)} for node {node}")
assert w.shape[oc_axis] == len(output_mask), (f"Kernel num of output channels: {w.shape[oc_axis]}, but mask len is {len(output_mask)} for node {node}")

# Apply masks to the kernel and calculate the remaining parameters.
pruned_w = np.take(w, np.where(input_mask)[0], axis=ic_axis)
pruned_w = np.take(pruned_w, np.where(output_mask)[0], axis=oc_axis)
total_params += len(pruned_w.flatten())
else:
# For non-kernel weights, apply the output mask only.
total_params += len(np.take(w, np.where(output_mask)[0]))

else:
# For non-kernel operations, apply the output mask to the last axis.
# This part assumes that for non-kernel ops, all weights output channel axis is -1.
for w_attr, w in node.weights.items():
pruned_w = np.take(w, np.where(output_mask)[0], axis=-1) # TODO: get axis from fw-specific function
total_params += pruned_w.size

if include_padded_channels: # TODO: remove duplicate
node_simd = node.get_simd()
nparams_per_oc = total_params / np.sum(output_mask)
num_oc_with_null_channels = np.ceil(np.sum(output_mask) / node_simd) * node_simd
total_params = num_oc_with_null_channels * nparams_per_oc

return total_params
attributes_and_oc_axis = self.fw_impl.get_node_attributes_with_io_axis(node, self.fw_info)
for w_attr, w in node.weights.items():
io_axis = [io_axis for attr, io_axis in attributes_and_oc_axis.items() if attr in w_attr]
assert len(io_axis) == 1
out_axis, in_axis = io_axis[0]
if in_axis is not None and input_mask is not None:
w = _prune(w, input_mask, in_axis)
if out_axis is not None and output_mask is not None:
w = _prune(w, output_mask, out_axis)
total_params += w.size

num_oc = np.sum(output_mask) if output_mask is not None else node.output_shape[-1]
if include_padded_channels:
total_params = self.get_node_nparams_with_padded_channels(node=node,
node_nparams=total_params,
num_oc=num_oc,
node_simd=node.get_simd())

return total_params

def get_node_nparams_with_padded_channels(self,
node: BaseNode,
node_nparams: int,
num_oc: int,
node_simd: int):
"""
Args:
node_nparams:
num_oc:
node_simd:
Returns:
"""
nparams_per_oc = node_nparams / num_oc

"""
Usually every layer has some number of params in each weight tensor dedicated for a single output-channel.
Sometimes not. For example: Keras Normalize layer with 3 output channels has 3 weights where 2 of them are
tensors of length 3 and a single scalar that is used for all 3 output channels.
"""
if int(nparams_per_oc)!=nparams_per_oc:
Logger.warning(
f" Found a node {node.name} with weights that are not uniformly distributed "
f"across output channels, thus memory calculation may be inaccurate due to "
f"SIMD assumptions.")
nparams_per_oc = np.ceil(nparams_per_oc)
# assert int(nparams_per_oc)==nparams_per_oc, f"Expected number of params per channel to be integer but is {nparams_per_oc}"

num_oc_with_null_channels = np.ceil(num_oc / node_simd) * node_simd
return num_oc_with_null_channels * nparams_per_oc
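
For intuition, the padded-channel accounting above can be reproduced in a standalone sketch (not part of this commit; the layer shape and SIMD size below are hypothetical). It mirrors get_node_nparams_with_padded_channels: estimate the parameters per output channel, round the channel count up to a multiple of the SIMD size, and multiply the two:

import numpy as np

# Hypothetical pruned Conv2D-like node: kernel of shape (3, 3, 16, 10) plus a bias of length 10.
node_nparams = 3 * 3 * 16 * 10 + 10     # 1450 parameters remain after pruning
num_oc = 10                             # remaining output channels
node_simd = 8                           # SIMD group size of the node

nparams_per_oc = node_nparams / num_oc                    # 145.0 parameters per output channel
num_oc_padded = np.ceil(num_oc / node_simd) * node_simd   # 10 channels padded up to 16
total_params = num_oc_padded * nparams_per_oc             # 2320.0

print(total_params * 4)  # 9280.0 bytes, since get_pruned_graph_memory assumes 4 bytes per parameter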
@@ -55,3 +55,5 @@ def __init__(self,

# The strategy to use when deciding which channels to prune based on their importance scores.
self.channels_filtering_strategy = channels_filtering_strategy

# TODO: Consider limiting ratio per layer
@@ -1,3 +1,5 @@
from typing import List, Tuple

from abc import abstractmethod

from model_compression_toolkit.core.common.framework_info import FrameworkInfo
@@ -123,3 +125,17 @@ def is_node_intermediate_pruning_section(self,
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
f'framework\'s is_node_intermediate_pruning_section method.') # pragma: no cover

def get_node_attributes_with_io_axis(self, node: BaseNode, fw_info: FrameworkInfo):
"""
Gets the attributes of a node and the axis for each attribute's output channels dimension.
Args:
node (BaseNode): The node for which attributes and their output channel axis are required.
fw_info (FrameworkInfo): Framework-specific information containing details about layers and attributes.
Returns:
List[Tuple[str, int]]: A list of tuples where each tuple contains an attribute name and the axis
of the output channels for that attribute.
"""
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
f'framework\'s get_node_attributes_with_io_axis method.') # pragma: no cover
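
To see how MemoryCalculator consumes this hook, here is a minimal sketch of a plausible return value for a Conv2D-like node and the pruning step it drives (the attribute names and axes are hypothetical; a real implementation derives them from fw_info):

import numpy as np

# Hypothetical mapping for a node whose kernel is laid out as (kh, kw, in_channels, out_channels)
# and whose bias is laid out as (out_channels,). None marks an axis the attribute does not have.
attributes_and_io_axis = {
    'kernel': (3, 2),   # (output-channel axis, input-channel axis)
    'bias': (0, None),
}

kernel = np.zeros((3, 3, 16, 10))
output_mask = np.array([1] * 6 + [0] * 4, dtype=bool)  # keep 6 of 10 output channels

out_axis, in_axis = attributes_and_io_axis['kernel']
pruned_kernel = np.take(kernel, np.where(output_mask)[0], axis=out_axis)  # same operation as _prune above
print(pruned_kernel.shape)  # (3, 3, 16, 6)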