Skip to content

Commit

Permalink
move node count params to common memory calc
Browse files Browse the repository at this point in the history
  • Loading branch information
reuvenp committed Dec 3, 2023
1 parent 74e4d86 commit e64047e
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 120 deletions.
76 changes: 68 additions & 8 deletions model_compression_toolkit/core/common/pruning/memory_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,11 @@ def get_nparams_of_shared_nodes(self,
# Get the output mask for the node if it exists
node_output_mask = masks.get(node)
# Calculate the number of remaining parameters for the shared node after pruning
nparams += self.fw_impl.get_pruned_node_num_params(node,
node_input_mask,
node_output_mask,
self.fw_info,
include_padded_channels)
nparams += self.get_pruned_node_num_params(node,
node_input_mask,
node_output_mask,
self.fw_info,
include_padded_channels)
return nparams

def get_nparams_of_pruning_sections(self, masks, pruning_sections, include_padded_channels:bool):
Expand Down Expand Up @@ -179,7 +179,7 @@ def _get_pruning_section_num_params(self,
include_padded_channels: bool) -> int:

# Number of params for the first node in the section.
first_node_nparams = self.fw_impl.get_pruned_node_num_params(
first_node_nparams = self.get_pruned_node_num_params(
pruning_section.entry_node,
pruning_section_mask.entry_node_ic_mask,
pruning_section_mask.entry_node_oc_mask,
Expand All @@ -188,19 +188,79 @@ def _get_pruning_section_num_params(self,

# Sum number of params for all intermediate nodes in the section.
total_inter_nodes_nparams = sum(
self.fw_impl.get_pruned_node_num_params(
self.get_pruned_node_num_params(
inter_node,
pruning_section_mask.entry_node_oc_mask,
pruning_section_mask.entry_node_oc_mask,
self.fw_info,
include_padded_channels) for inter_node in pruning_section.intermediate_nodes)

# Number of params for the last node in the section.
second_node_nparams = self.fw_impl.get_pruned_node_num_params(
second_node_nparams = self.get_pruned_node_num_params(
pruning_section.exit_node,
pruning_section_mask.exit_node_ic_mask,
pruning_section_mask.exit_node_oc_mask,
self.fw_info,
include_padded_channels)

return first_node_nparams + total_inter_nodes_nparams + second_node_nparams


def get_pruned_node_num_params(self,
                               node: BaseNode,
                               input_mask: np.ndarray,
                               output_mask: np.ndarray,
                               fw_info: FrameworkInfo,
                               include_padded_channels: bool):
    """
    Calculates the number of parameters that remain in a node after pruning.

    Args:
        node: The node whose parameters are to be counted.
        input_mask: Mask to be applied to the input channels. None keeps every input channel.
        output_mask: Mask to be applied to the output channels. None keeps every output channel.
        fw_info: Framework-specific information object.
        include_padded_channels: Boolean flag to include or exclude null channels in the count.

    Returns:
        Integer representing the number of parameters in the pruned node.
    """
    weights = node.weights
    is_kernel_op = fw_info.is_kernel_op(node.type)

    # Normalize the supplied masks once, so every weight is sliced with the same boolean mask.
    if input_mask is not None:
        input_mask = np.asarray(input_mask).astype(bool)
    if output_mask is not None:
        output_mask = np.asarray(output_mask).astype(bool)

    oc_axis = ic_axis = kernel_attr = None
    if is_kernel_op:
        # Obtain axes info for kernel operations.
        oc_axis, ic_axis = fw_info.kernel_channels_mapping.get(node.type)
        kernel_attr = fw_info.get_kernel_op_attributes(node.type)[0]
        # Resolve missing masks from the kernel *before* counting anything, so the result does not
        # depend on whether the kernel or another weight (e.g. a bias) is listed first.
        kernel = next((w for w_attr, w in weights.items() if kernel_attr in w_attr), None)
        if kernel is not None:
            if input_mask is None:
                input_mask = np.ones(kernel.shape[ic_axis], dtype=bool)
            if output_mask is None:
                output_mask = np.ones(kernel.shape[oc_axis], dtype=bool)
    elif output_mask is None and weights:
        # This part assumes that for non-kernel ops, all weights output channel axis is -1.
        output_mask = np.ones(np.shape(next(iter(weights.values())))[-1], dtype=bool)

    if output_mask is None:
        # No weight defines the number of output channels, so nothing can be pruned or padded.
        return sum(int(np.size(w)) for w in weights.values())

    oc_idx = np.flatnonzero(output_mask)
    num_oc = int(oc_idx.size)

    total_params = 0
    for w_attr, w in weights.items():
        if not is_kernel_op:
            # For non-kernel operations, apply the output mask to the last axis.
            pruned_w = np.take(w, oc_idx, axis=-1)  # TODO: get axis from fw-specific function
        elif kernel_attr in w_attr:
            # Assert the input and output masks match the kernel dimensions.
            assert w.shape[ic_axis] == len(input_mask), (f"Kernel num of input channels: {w.shape[ic_axis]}, but mask len is {len(input_mask)} for node {node}")
            assert w.shape[oc_axis] == len(output_mask), (f"Kernel num of output channels: {w.shape[oc_axis]}, but mask len is {len(output_mask)} for node {node}")

            # Apply masks to the kernel and count the remaining parameters.
            pruned_w = np.take(w, np.flatnonzero(input_mask), axis=ic_axis)
            pruned_w = np.take(pruned_w, oc_idx, axis=oc_axis)
        else:
            # For non-kernel weights of a kernel op, apply the output mask only.
            pruned_w = np.take(w, oc_idx)
        total_params += int(pruned_w.size)

    # With no remaining output channel there is nothing to pad (and no per-channel size to divide by).
    if include_padded_channels and num_oc > 0:  # TODO: remove duplicate
        node_simd = node.get_simd()
        # Round the number of output channels up to the next multiple of the node's SIMD size.
        num_oc_with_null_channels = int(np.ceil(num_oc / node_simd) * node_simd)
        # Every counted weight holds exactly num_oc entries along its output axis, so the
        # division is exact and the result stays an integer.
        total_params = total_params // num_oc * num_oc_with_null_channels

    return total_params
Original file line number Diff line number Diff line change
Expand Up @@ -123,28 +123,3 @@ def is_node_intermediate_pruning_section(self,
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
f'framework\'s is_node_intermediate_pruning_section method.') # pragma: no cover

@abstractmethod
def get_pruned_node_num_params(self,
                               node: BaseNode,
                               input_mask: np.ndarray,
                               output_mask: np.ndarray,
                               fw_info: FrameworkInfo,
                               include_padded_channels: bool):
    """
    Abstract method to get the number of parameters of a pruned node.

    Args:
        node: The node whose parameters are to be counted.
        input_mask: Mask to be applied to the input channels.
        output_mask: Mask to be applied to the output channels.
        fw_info: Framework-specific information.
        include_padded_channels: Boolean flag to include or exclude padded channels in the count.

    Returns:
        int: Number of parameters after pruning.

    Raises:
        NotImplementedError: If the method is not implemented in the subclass.
    """
    # NotImplemented is a comparison sentinel, not an exception class: calling it raises TypeError.
    # NotImplementedError is the exception meant for unimplemented abstract methods.
    raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
                              f'framework\'s get_pruned_node_num_params method.')  # pragma: no cover
74 changes: 9 additions & 65 deletions model_compression_toolkit/core/keras/pruning/count_node_params.py
Original file line number Diff line number Diff line change
@@ -1,65 +1,9 @@
import copy

import keras.layers
import numpy as np

from model_compression_toolkit.core.common.framework_info import FrameworkInfo
from model_compression_toolkit.core.common import BaseNode


# Get the number of parameters for a pruned Keras node.
def get_keras_pruned_node_num_params(node: BaseNode,
                                     input_mask: np.ndarray,
                                     output_mask: np.ndarray,
                                     fw_info: FrameworkInfo,
                                     include_padded_channels: bool):  # TODO: move to common
    """
    Get the number of parameters that remain in a Keras node after pruning.

    Args:
        node: The node whose parameters are to be counted.
        input_mask: Mask to be applied to the input channels. None keeps every input channel.
        output_mask: Mask to be applied to the output channels. None keeps every output channel.
        fw_info: Framework-specific information object.
        include_padded_channels: Boolean flag to include or exclude null channels in the count.

    Returns:
        Integer representing the number of parameters in the pruned node.
    """
    weights = node.weights
    is_kernel_op = fw_info.is_kernel_op(node.type)

    # Normalize the supplied masks once, so every weight is sliced with the same boolean mask.
    if input_mask is not None:
        input_mask = np.asarray(input_mask).astype(bool)
    if output_mask is not None:
        output_mask = np.asarray(output_mask).astype(bool)

    oc_axis = ic_axis = kernel_attr = None
    if is_kernel_op:
        # Obtain axes info for kernel operations.
        oc_axis, ic_axis = fw_info.kernel_channels_mapping.get(node.type)
        kernel_attr = fw_info.get_kernel_op_attributes(node.type)[0]
        # Resolve missing masks from the kernel *before* counting anything, so the result does not
        # depend on whether the kernel or another weight (e.g. a bias) is listed first.
        kernel = next((w for w_attr, w in weights.items() if kernel_attr in w_attr), None)
        if kernel is not None:
            if input_mask is None:
                input_mask = np.ones(kernel.shape[ic_axis], dtype=bool)
            if output_mask is None:
                output_mask = np.ones(kernel.shape[oc_axis], dtype=bool)
    elif output_mask is None and weights:
        # This part assumes that for non-kernel ops, all weights output channel axis is -1.
        output_mask = np.ones(np.shape(next(iter(weights.values())))[-1], dtype=bool)

    if output_mask is None:
        # No weight defines the number of output channels, so nothing can be pruned or padded.
        return sum(int(np.size(w)) for w in weights.values())

    oc_idx = np.flatnonzero(output_mask)
    num_oc = int(oc_idx.size)

    total_params = 0
    for w_attr, w in weights.items():
        if not is_kernel_op:
            # For non-kernel operations, apply the output mask to the last axis.
            pruned_w = np.take(w, oc_idx, axis=-1)  # TODO: get axis from fw-specific function
        elif kernel_attr in w_attr:
            # Masks must already match the kernel dimensions; no per-layer mask expansion is done here.
            assert w.shape[ic_axis] == len(input_mask), (
                f"Kernel num of input channels: {w.shape[ic_axis]}, but mask len is {len(input_mask)} for node "
                f"{node}")
            assert w.shape[oc_axis] == len(
                output_mask), (f"Kernel num of output channels: {w.shape[oc_axis]}, but mask len is "
                               f"{len(output_mask)} for node {node}")

            # Apply masks to the kernel and count the remaining parameters.
            pruned_w = np.take(w, np.flatnonzero(input_mask), axis=ic_axis)
            pruned_w = np.take(pruned_w, oc_idx, axis=oc_axis)
        else:
            # For non-kernel weights of a kernel op, apply the output mask only.
            pruned_w = np.take(w, oc_idx)
        total_params += int(pruned_w.size)

    # With no remaining output channel there is nothing to pad (and no per-channel size to divide by).
    if include_padded_channels and num_oc > 0:  # TODO: remove duplicate
        node_simd = node.get_simd()
        # Round the number of output channels up to the next multiple of the node's SIMD size.
        num_oc_with_null_channels = int(np.ceil(num_oc / node_simd) * node_simd)
        # Every counted weight holds exactly num_oc entries along its output axis, so the
        # division is exact and the result stays an integer.
        total_params = total_params // num_oc * num_oc_with_null_channels

    return total_params
# import copy
#
# import keras.layers
# import numpy as np
#
# from model_compression_toolkit.core.common.framework_info import FrameworkInfo
# from model_compression_toolkit.core.common import BaseNode
#
#
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
from model_compression_toolkit.core.keras.pruning.check_node_role import is_keras_node_intermediate_pruning_section, \
is_keras_entry_node, is_keras_exit_node
from model_compression_toolkit.core.keras.pruning.count_node_params import get_keras_pruned_node_num_params
from model_compression_toolkit.core.keras.pruning.prune_keras_node import (prune_keras_exit_node,
prune_keras_entry_node, \
prune_keras_intermediate_node)
Expand Down Expand Up @@ -98,24 +97,3 @@ def is_node_intermediate_pruning_section(self, node):
"""
return is_keras_node_intermediate_pruning_section(node)

def get_pruned_node_num_params(self,
                               node: BaseNode,
                               input_mask: np.ndarray,
                               output_mask: np.ndarray,
                               fw_info: FrameworkInfo,
                               include_padded_channels: bool):
    """
    Count the parameters left in a node of a Keras model once it is pruned.

    The computation is delegated to get_keras_pruned_node_num_params; this
    method only forwards its arguments.

    Args:
        node: The node whose parameters are to be counted.
        input_mask: Mask to be applied to the input channels.
        output_mask: Mask to be applied to the output channels.
        fw_info: Framework-specific information object.
        include_padded_channels: Boolean flag to include or exclude null channels in the count.

    Returns:
        The number of parameters in the pruned node, as returned by the Keras helper.
    """
    num_params = get_keras_pruned_node_num_params(
        node,
        input_mask,
        output_mask,
        fw_info,
        include_padded_channels,
    )
    return num_params

0 comments on commit e64047e

Please sign in to comment.