From 5306a8d5fe2e6215bea20f348c15f81463c73ed9 Mon Sep 17 00:00:00 2001 From: Reuven <44209964+reuvenperetz@users.noreply.github.com> Date: Thu, 28 Dec 2023 14:08:49 +0200 Subject: [PATCH] Keras structured SIMD pruning (#871) This commit introduces structured and hardware-aware model pruning to MCT. It optimizes models for specific hardware architectures by considering their SIMD capabilities. This approach prunes groups of channels (SIMD groups) to efficiently reduce model size and complexity while aligning with the hardware's SIMD structure. Added functionalities include calculating pruning masks, handling SIMD group-based pruning, and updating model architecture according to pruning sections. Unit tests and example usage in a tutorial notebook are included for demonstration and validation. --------- Co-authored-by: reuvenp --- .github/workflows/run_keras_tests.yml | 1 + README.md | 25 +- .../api/experimental_api_docs/index.rst | 1 + .../methods/keras_pruning_experimental.rst | 25 + docsrc/source/index.rst | 1 + model_compression_toolkit/__init__.py | 1 + model_compression_toolkit/constants.py | 4 + .../core/common/graph/base_graph.py | 114 +++++ .../core/common/graph/base_node.py | 30 +- .../core/common/pruning/__init__.py | 16 + .../core/common/pruning/channels_grouping.py | 93 ++++ .../common/pruning/greedy_mask_calculator.py | 147 ++++++ .../pruning/importance_metrics/__init__.py | 15 + .../base_importance_metric.py | 43 ++ .../importance_metric_factory.py | 39 ++ .../lfh_importance_metric.py | 285 ++++++++++++ .../core/common/pruning/mask/__init__.py | 14 + .../common/pruning/mask/per_channel_mask.py | 114 +++++ .../pruning/mask/per_simd_group_mask.py | 122 +++++ .../core/common/pruning/memory_calculator.py | 366 +++++++++++++++ .../core/common/pruning/prune_graph.py | 74 +++ .../core/common/pruning/pruner.py | 132 ++++++ .../core/common/pruning/pruning_config.py | 73 +++ .../pruning_framework_implementation.py | 156 +++++++ 
.../core/common/pruning/pruning_info.py | 97 ++++ .../core/common/pruning/pruning_section.py | 127 +++++ .../core/keras/pruning/__init__.py | 15 + .../pruning/pruning_keras_implementation.py | 285 ++++++++++++ model_compression_toolkit/pruning/__init__.py | 19 + .../pruning/keras/__init__.py | 15 + .../pruning/keras/pruning_facade.py | 148 ++++++ .../target_platform/op_quantization_config.py | 3 +- .../target_platform/target_platform_model.py | 13 + .../target_platform_capabilities.py | 9 + .../tpc_models/imx500_tpc/v1/tp_model.py | 2 + tests/keras_tests/pruning_tests/__init__.py | 15 + .../feature_networks/__init__.py | 14 + .../constant_importance_metric.py | 83 ++++ .../networks_tests/__init__.py | 14 + .../conv2d_conv2dtranspose_pruning_test.py | 98 ++++ .../networks_tests/conv2d_pruning_test.py | 98 ++++ .../conv2dtranspose_conv2d_pruning_test.py | 99 ++++ .../conv2dtranspose_pruning_test.py | 103 ++++ .../networks_tests/dense_pruning_test.py | 100 ++++ .../pruning_keras_feature_test.py | 73 +++ .../test_pruning_feature_networks.py | 112 +++++ .../pruning_tests/random_importance_metric.py | 63 +++ .../pruning_tests/test_memory_calculator.py | 87 ++++ .../pruning_tests/test_pretrained_models.py | 189 ++++++++ .../pruning_tests/test_pruning_info.py | 58 +++ tests/test_suite.py | 7 +- tutorials/notebooks/example_keras_pruning.py | 98 ++++ .../example_keras_pruning_mnist.ipynb | 439 ++++++++++++++++++ 53 files changed, 4365 insertions(+), 9 deletions(-) create mode 100644 docsrc/source/api/experimental_api_docs/methods/keras_pruning_experimental.rst create mode 100644 model_compression_toolkit/core/common/pruning/__init__.py create mode 100644 model_compression_toolkit/core/common/pruning/channels_grouping.py create mode 100644 model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py create mode 100644 model_compression_toolkit/core/common/pruning/importance_metrics/__init__.py create mode 100644 
model_compression_toolkit/core/common/pruning/importance_metrics/base_importance_metric.py create mode 100644 model_compression_toolkit/core/common/pruning/importance_metrics/importance_metric_factory.py create mode 100644 model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py create mode 100644 model_compression_toolkit/core/common/pruning/mask/__init__.py create mode 100644 model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py create mode 100644 model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py create mode 100644 model_compression_toolkit/core/common/pruning/memory_calculator.py create mode 100644 model_compression_toolkit/core/common/pruning/prune_graph.py create mode 100644 model_compression_toolkit/core/common/pruning/pruner.py create mode 100644 model_compression_toolkit/core/common/pruning/pruning_config.py create mode 100644 model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py create mode 100644 model_compression_toolkit/core/common/pruning/pruning_info.py create mode 100644 model_compression_toolkit/core/common/pruning/pruning_section.py create mode 100644 model_compression_toolkit/core/keras/pruning/__init__.py create mode 100644 model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py create mode 100644 model_compression_toolkit/pruning/__init__.py create mode 100644 model_compression_toolkit/pruning/keras/__init__.py create mode 100644 model_compression_toolkit/pruning/keras/pruning_facade.py create mode 100644 tests/keras_tests/pruning_tests/__init__.py create mode 100644 tests/keras_tests/pruning_tests/feature_networks/__init__.py create mode 100644 tests/keras_tests/pruning_tests/feature_networks/constant_importance_metric.py create mode 100644 tests/keras_tests/pruning_tests/feature_networks/networks_tests/__init__.py create mode 100644 
tests/keras_tests/pruning_tests/feature_networks/networks_tests/conv2d_conv2dtranspose_pruning_test.py create mode 100644 tests/keras_tests/pruning_tests/feature_networks/networks_tests/conv2d_pruning_test.py create mode 100644 tests/keras_tests/pruning_tests/feature_networks/networks_tests/conv2dtranspose_conv2d_pruning_test.py create mode 100644 tests/keras_tests/pruning_tests/feature_networks/networks_tests/conv2dtranspose_pruning_test.py create mode 100644 tests/keras_tests/pruning_tests/feature_networks/networks_tests/dense_pruning_test.py create mode 100644 tests/keras_tests/pruning_tests/feature_networks/pruning_keras_feature_test.py create mode 100644 tests/keras_tests/pruning_tests/feature_networks/test_pruning_feature_networks.py create mode 100644 tests/keras_tests/pruning_tests/random_importance_metric.py create mode 100644 tests/keras_tests/pruning_tests/test_memory_calculator.py create mode 100644 tests/keras_tests/pruning_tests/test_pretrained_models.py create mode 100644 tests/keras_tests/pruning_tests/test_pruning_info.py create mode 100644 tutorials/notebooks/example_keras_pruning.py create mode 100644 tutorials/notebooks/example_keras_pruning_mnist.ipynb diff --git a/.github/workflows/run_keras_tests.yml b/.github/workflows/run_keras_tests.yml index 35324eeb5..91a2565d2 100644 --- a/.github/workflows/run_keras_tests.yml +++ b/.github/workflows/run_keras_tests.yml @@ -29,6 +29,7 @@ jobs: # CPU environment (https://github.com/tensorflow/tensorflow/issues/41718). # For this reason, if we run them in such an environment, we need to run them first non-parallel separately. 
run: | + python -m unittest discover tests/keras_tests/pruning_tests -v python -m unittest discover tests/keras_tests/non_parallel_tests -v for script in tests/keras_tests/exporter_tests tests/keras_tests/feature_networks_tests tests/keras_tests/graph_tests tests/keras_tests/layer_tests; do python -m unittest discover $script -v & pids+=($!); done; for pid in ${pids[@]}; do wait $pid || exit 1; done diff --git a/README.md b/README.md index b14410f2a..b0869675f 100644 --- a/README.md +++ b/README.md @@ -140,6 +140,29 @@ In the following table we present the ImageNet validation results for these mode For more results, please refer to [quick start](https://github.com/sony/model_optimization/tree/main/tutorials/quick_start). +### Structured Pruning +MCT introduces structured, hardware-aware model pruning. +This pruning technique is designed to compress models for specific hardware architectures, +taking into account the target platform's Single Instruction, Multiple Data (SIMD) capabilities. +By pruning groups of channels (SIMD groups), our approach not only reduces model size +and complexity, but also ensures that channel utilization is aligned with the SIMD architecture +for a target KPI of weights memory footprint. + + +_Note: Currently, pruning is supported only for Keras models._ + +#### Results + +Results for applying pruning to reduce the parameters of the following models by 50%: + +| Model | Dense Model Accuracy | Pruned Model Accuracy | +|-----------------|----------------------|-----------------------| +| ResNet50 [2] | 75.1 | 72.4 | +| DenseNet121 [2] | 75.0 | 71.15 | + + + + ## Contributions MCT aims at keeping a more up-to-date fork and welcomes contributions from anyone. @@ -153,7 +176,7 @@ MCT aims at keeping a more up-to-date fork and welcomes contributions from anyon [1] Habi, H.V., Peretz, R., Cohen, E., Dikstein, L., Dror, O., Diamant, I., Jennings, R.H. and Netzer, A., 2021. [HPTQ: Hardware-Friendly Post Training Quantization. 
arXiv preprint](https://arxiv.org/abs/2109.09113). -[2] [MobilNet](https://keras.io/api/applications/mobilenet/#mobilenet-function) from Keras applications. +[2] [Keras Applications](https://keras.io/api/applications/) [3] [TORCHVISION.MODELS](https://pytorch.org/vision/stable/models.html) diff --git a/docsrc/source/api/experimental_api_docs/index.rst b/docsrc/source/api/experimental_api_docs/index.rst index 1df126747..e608daf13 100644 --- a/docsrc/source/api/experimental_api_docs/index.rst +++ b/docsrc/source/api/experimental_api_docs/index.rst @@ -38,6 +38,7 @@ Functions - :ref:`get_tensorflow_data_generation_config`: A function to generate a DataGenerationConfig for Tensorflow data generation(experimental). - :ref:`pytorch_data_generation_experimental`: A function to generate data for a Pytorch model (experimental). - :ref:`get_pytorch_data_generation_config`: A function to load a DataGenerationConfig for Pytorch data generation (experimental). +- :ref:`keras_pruning_experimental`: A function to apply structured pruning for Keras models (experimental). Modules diff --git a/docsrc/source/api/experimental_api_docs/methods/keras_pruning_experimental.rst b/docsrc/source/api/experimental_api_docs/methods/keras_pruning_experimental.rst new file mode 100644 index 000000000..787bb2452 --- /dev/null +++ b/docsrc/source/api/experimental_api_docs/methods/keras_pruning_experimental.rst @@ -0,0 +1,25 @@ +:orphan: + +.. _ug-keras_pruning_experimental: + + +================================================ +Keras Structured Pruning +================================================ + +.. autofunction:: model_compression_toolkit.pruning.keras_pruning_experimental + +================================================ +Pruning Configuration +================================================ + +.. 
autofunction:: model_compression_toolkit.pruning.PruningConfig + + + +================================================ +Pruning Information +================================================ + +.. autofunction:: model_compression_toolkit.pruning.PruningInfo + diff --git a/docsrc/source/index.rst b/docsrc/source/index.rst index 96382c25c..bd91fe81f 100644 --- a/docsrc/source/index.rst +++ b/docsrc/source/index.rst @@ -57,6 +57,7 @@ Keras: * :ref:`Mixed-precision post training quantization` * :ref:`Init model for Quantization Aware Training` (Experimental) * :ref:`Finalize model after Quantization Aware Training` (Experimental) +* :ref:`Structured Pruning` (Experimental) Pytorch: diff --git a/model_compression_toolkit/__init__.py b/model_compression_toolkit/__init__.py index 937739bae..77a60b7aa 100644 --- a/model_compression_toolkit/__init__.py +++ b/model_compression_toolkit/__init__.py @@ -25,6 +25,7 @@ from model_compression_toolkit import exporter from model_compression_toolkit import gptq from model_compression_toolkit import data_generation +from model_compression_toolkit import pruning from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model diff --git a/model_compression_toolkit/constants.py b/model_compression_toolkit/constants.py index ba9b07a77..93e821fac 100644 --- a/model_compression_toolkit/constants.py +++ b/model_compression_toolkit/constants.py @@ -29,6 +29,7 @@ MIN_THRESHOLD = (2 ** -16) EPS = 1e-8 LUT_VALUES_BITWIDTH = 8 +FP32_BYTES_PER_PARAMETER = 4. 
# Quantization attributes: OUTPUT_SCALE = 'output_scale' @@ -127,3 +128,6 @@ HESSIAN_OUTPUT_ALPHA = 0.3 HESSIAN_NUM_ITERATIONS = 50 HESSIAN_EPS = 1e-6 + +# Pruning constants +PRUNING_NUM_SCORE_APPROXIMATIONS = 32 \ No newline at end of file diff --git a/model_compression_toolkit/core/common/graph/base_graph.py b/model_compression_toolkit/core/common/graph/base_graph.py index b18b767fe..bd7b84ba9 100644 --- a/model_compression_toolkit/core/common/graph/base_graph.py +++ b/model_compression_toolkit/core/common/graph/base_graph.py @@ -29,6 +29,7 @@ from model_compression_toolkit.core.common.graph.base_node import BaseNode from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector from model_compression_toolkit.core.common.collectors.statistics_collector import scale_statistics, shift_statistics +from model_compression_toolkit.core.common.pruning.pruning_section import PruningSection from model_compression_toolkit.core.common.user_info import UserInformation from model_compression_toolkit.logger import Logger from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import \ @@ -726,3 +727,116 @@ def replace_node(self, node_to_replace: BaseNode, new_node: BaseNode): self.replace_output_node(node_to_replace, new_node) self.replace_input_node(node_to_replace, new_node) self.remove_node(node_to_replace) + + def get_pruning_sections(self, + fw_impl: Any) -> List[PruningSection]: + """ + Constructs pruning sections for a given computational graph. + Each section is created starting from an entry node and includes intermediate and exit nodes. + + Args: + fw_impl (PruningFrameworkImplementation): Implementation of specific framework methods required for pruning. + + Returns: List of PruningSection in the graph. 
+ """ + entry_nodes = self.get_pruning_sections_entry_nodes(fw_impl) + return [self._create_pruning_section(entry_node, fw_impl) for entry_node in entry_nodes] + + def get_pruning_sections_entry_nodes(self, fw_impl: Any) -> List[BaseNode]: + """ + Identifies entry nodes for pruning sections within the graph. + Traverses the graph in a topological order, checking each node for prunability criteria. + Returns a list of nodes that mark the beginning of a prunable section in the graph. + + Args: + fw_impl (PruningFrameworkImplementation): Implementation of specific framework methods required for pruning. + + Returns: List of nodes that are entry nodes in the pruning sections of the graph. + + """ + prunable_nodes = [] + for n in list(topological_sort(self)): + if fw_impl.is_node_entry_node(n) and self._is_node_topology_prunable(n, fw_impl): + prunable_nodes.append(n) + return prunable_nodes + + def _is_node_topology_prunable(self, entry_node: BaseNode, fw_impl: Any) -> bool: + """ + Determines if the topology starting from a given entry node is suitable for pruning. + Iteratively examines the graph structure, focusing on node connectivity and pruning criteria. + Returns True if the topology is prunable, False otherwise. + + Args: + entry_node (BaseNode): The node to start the topology check from. + fw_impl (PruningFrameworkImplementation): Implementation of specific framework methods required for pruning. + + Returns: Whether this node is a start of a pruning section according to the graph topology or not. + + """ + next_node = entry_node + + # Continue iterating until the conditions for prunability are no longer met + while len(self.out_edges(next_node)) == 1: + next_node = self.out_edges(next_node)[0].sink_node + + # If next_node is an exit node and has only one incoming edge, the topology is prunable. 
+ if fw_impl.is_node_exit_node(next_node, entry_node, self.fw_info) and len(self.in_edges(next_node)) == 1: + return True + + # If the next node is not an intermediate node or has more than one incoming/outgoing edge, + # stop the check. + if not fw_impl.is_node_intermediate_pruning_section(next_node) or len(self.in_edges(next_node)) != 1 or len(self.out_edges(next_node)) != 1: + return False + + # If the loop exits normally, it implies that the topology is not prunable + return False + + + def _create_pruning_section(self, entry_node: BaseNode, fw_impl: Any) -> PruningSection: + """ + Creates a PruningSection object starting from a given entry node. + Includes logic to find intermediate and exit nodes to complete the section. + Ensures the provided entry node is a valid starting point for pruning. + + Args: + entry_node (BaseNode): The entry node to create the section it starts. + fw_impl (PruningFrameworkImplementation): Implementation of specific framework methods required for pruning. + + Returns: The pruning section that starts with node entry_node. + + """ + if not fw_impl.is_node_entry_node(entry_node): + Logger.error(f"Expected to find an entry node to create its pruning section," + f"but node {entry_node} is not an entry node.") + + intermediate_nodes, exit_node = self._find_intermediate_and_exit_nodes(entry_node, fw_impl) + + if not fw_impl.is_node_exit_node(exit_node, entry_node, self.fw_info): + Logger.error(f"Expected to find exit node when creating a pruning section," + f"but node {exit_node} is not an exit node.") + + return PruningSection(entry_node=entry_node, + intermediate_nodes=intermediate_nodes, + exit_node=exit_node) + + def _find_intermediate_and_exit_nodes(self, entry_node: BaseNode, fw_impl: Any) -> Tuple[List[BaseNode], BaseNode]: + """ + Identifies intermediate and exit nodes for a pruning section starting from an entry node. + Iterates through connected nodes to build the complete structure of the pruning section. 
+ + Args: + entry_node (BaseNode): An entry node to find the intermediate and exit nodes of its section. + fw_impl (PruningFrameworkImplementation): Implementation of specific framework methods required for pruning. + + Returns: A tuple containing a list of intermediate nodes and the exit node. + + """ + intermediate_nodes = [] + next_node = self.out_edges(entry_node)[0].sink_node + while not fw_impl.is_node_exit_node(next_node, entry_node, self.fw_info): + intermediate_nodes.append(next_node) + next_node = self.out_edges(next_node)[0].sink_node + + return intermediate_nodes, next_node + + diff --git a/model_compression_toolkit/core/common/graph/base_node.py b/model_compression_toolkit/core/common/graph/base_node.py index 1feceafad..4b8c662cd 100644 --- a/model_compression_toolkit/core/common/graph/base_node.py +++ b/model_compression_toolkit/core/common/graph/base_node.py @@ -19,7 +19,7 @@ import numpy as np from model_compression_toolkit.constants import WEIGHTS_NBITS_ATTRIBUTE, CORRECTED_BIAS_ATTRIBUTE, \ - ACTIVATION_NBITS_ATTRIBUTE + ACTIVATION_NBITS_ATTRIBUTE, FP32_BYTES_PER_PARAMETER from model_compression_toolkit.logger import Logger from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationConfigOptions, \ TargetPlatformCapabilities, LayerFilterParams @@ -222,9 +222,9 @@ def get_memory_bytes(self, fw_info) -> float: """ q_params, f_params = self.get_num_parameters(fw_info) if self.final_weights_quantization_cfg is None: # float coefficients - memory = (f_params+q_params) * 4 + memory = (f_params+q_params) * FP32_BYTES_PER_PARAMETER else: - memory = (f_params*4)+ (q_params * self.final_weights_quantization_cfg.weights_n_bits / 8) # in bytes + memory = (f_params * FP32_BYTES_PER_PARAMETER) + (q_params * self.final_weights_quantization_cfg.weights_n_bits / 8) # in bytes return memory @@ -239,7 +239,7 @@ def get_float_memory_bytes(self, fw_info) -> float: """ q_params, f_params = self.get_num_parameters(fw_info) - return 
(f_params + q_params) * 32 / 8 # in bytes + return (f_params + q_params) * FP32_BYTES_PER_PARAMETER def get_unified_weights_candidates_dict(self): """ @@ -499,4 +499,24 @@ def is_match_filter_params(self, layer_filter_params: LayerFilterParams) -> bool if not c.match(layer_config): return False - return True \ No newline at end of file + return True + + def get_simd(self) -> int: + """ + Retrieves the SIMD size used for this node. It collects the SIMD sizes from all candidate + configurations and returns the minimum SIMD size. + + Returns: + int: The node's SIMD size. + + """ + simd_list = [qc.weights_quantization_cfg.simd_size for qc in self.candidates_quantization_cfg] + if len(simd_list) > 1: + Logger.warning(f"More than one pruning SIMD option is available." + f" Min SIMD is used: {min(simd_list)}") + if len(simd_list) == 0: + Logger.error(f"No SIMD option is available for {self}") + _simd = min(simd_list) + if _simd <= 0 or int(_simd) != _simd: + Logger.error(f"SIMD is expected to be a non-positive integer but found: {_simd}") + return _simd diff --git a/model_compression_toolkit/core/common/pruning/__init__.py b/model_compression_toolkit/core/common/pruning/__init__.py new file mode 100644 index 000000000..8da969d03 --- /dev/null +++ b/model_compression_toolkit/core/common/pruning/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + + diff --git a/model_compression_toolkit/core/common/pruning/channels_grouping.py b/model_compression_toolkit/core/common/pruning/channels_grouping.py new file mode 100644 index 000000000..36c5e4617 --- /dev/null +++ b/model_compression_toolkit/core/common/pruning/channels_grouping.py @@ -0,0 +1,93 @@ +# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import List, Dict, Tuple + +from model_compression_toolkit.core.common.framework_info import FrameworkInfo +from model_compression_toolkit.core.common import BaseNode +import numpy as np + + +class ChannelGrouping: + """ + ChannelGrouping handles the sorting and grouping of channel indices for prunable nodes in a graph, + based on their importance scores and SIMD group sizes. + """ + + def __init__(self, + prunable_nodes: List[BaseNode], + fw_info: FrameworkInfo): + """ + Initializes the ChannelGrouping with necessary information. + + Args: + prunable_nodes: List of nodes that can be pruned. + fw_info: Framework-specific information and utilities. + """ + self.prunable_nodes = prunable_nodes + self.fw_info = fw_info + # Store for each node a list of numpy arrays. Each numpy array represents the + # indices of the channels in an SIMD group. 
+ self._simd_groups_indices = {} + + @property + def simd_groups_indices(self) -> Dict[BaseNode, List[np.ndarray]]: + """ + Returns the grouped indices for each prunable node. + + Returns: + Dict[BaseNode, List[np.ndarray]]: Grouped indices for each node. + """ + return self._simd_groups_indices + + def group_scores_by_simd_groups(self, + score_by_node: Dict[BaseNode, np.ndarray]): + """ + Groups importance scores of each prunable node by their respective SIMD group sizes. + This function processes the importance scores of each prunable node and divides them into + groups based on the SIMD width of the node. Grouping scores by SIMD size helps in identifying + which groups of channels can be pruned together based on their collective importance. + + Args: + score_by_node: A dictionary mapping nodes to their importance scores. + """ + for prunable_node, node_scores in score_by_node.items(): + self._simd_groups_indices[prunable_node] = self._group_node_scores(node_scores, + prunable_node.get_simd()) + + def _group_node_scores(self, + scores: np.ndarray, + simd: int) -> List[np.ndarray]: + """ + Groups the scores and their corresponding indices based on SIMD size. + + Args: + scores: An array of scores to be grouped. + simd: Size of the SIMD group. + + Returns: + List[np.ndarray]: Channel indices of each SIMD group, ordered by descending score. 
+ """ + sorted_indices = np.argsort(-scores) + num_complete_groups = len(scores) // simd + scores_groups = [scores[sorted_indices[i * simd:(i + 1) * simd]] for i in range(num_complete_groups)] + indices_groups = [sorted_indices[i * simd:(i + 1) * simd] for i in range(num_complete_groups)] + remainder = len(scores) % simd + if remainder != 0: + scores_groups.append(scores[sorted_indices[-remainder:]]) + indices_groups.append(sorted_indices[-remainder:]) + return indices_groups + + diff --git a/model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py b/model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py new file mode 100644 index 000000000..d63c70e5c --- /dev/null +++ b/model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py @@ -0,0 +1,147 @@ +# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +import numpy as np +from typing import List, Dict, Tuple + +from model_compression_toolkit.core.common import BaseNode, Graph +from model_compression_toolkit.core.common.framework_info import FrameworkInfo +from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI +from model_compression_toolkit.core.common.pruning.mask.per_channel_mask import MaskIndicator +from model_compression_toolkit.core.common.pruning.memory_calculator import MemoryCalculator +from model_compression_toolkit.core.common.pruning.pruning_framework_implementation import PruningFrameworkImplementation +from model_compression_toolkit.core.common.pruning.mask.per_simd_group_mask import PerSIMDGroupMask +from model_compression_toolkit.logger import Logger +from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities + + +class GreedyMaskCalculator: + """ + GreedyMaskCalculator calculates pruning masks for prunable nodes to meet a + specified target KPI. It employs a greedy approach to selectively unprune channel + groups (SIMD groups) based on their importance scores. Initially, all channels are + pruned (mask set to zero), and the calculator iteratively adds back the most significant + channel groups until the memory footprint meets the target KPI or all channels are unpruned. + """ + def __init__(self, + prunable_nodes: List[BaseNode], + fw_info: FrameworkInfo, + simd_groups_scores: Dict[BaseNode, np.ndarray], + target_kpi: KPI, + graph: Graph, + fw_impl: PruningFrameworkImplementation, + tpc: TargetPlatformCapabilities, + simd_groups_indices: Dict[BaseNode, List[List[int]]]): + """ + Args: + prunable_nodes (List[BaseNode]): Nodes that are eligible for pruning. + fw_info (FrameworkInfo): Framework-specific information and utilities. + simd_groups_scores (Dict[BaseNode, np.ndarray]): Importance scores for each SIMD group in a prunable node. 
+ target_kpi (KPI): The target KPI to achieve. + graph (Graph): The computational graph of the model. + fw_impl (PruningFrameworkImplementation): Framework-specific implementation details. + tpc (TargetPlatformCapabilities): Platform-specific constraints and capabilities. + simd_groups_indices (Dict[BaseNode, List[List[int]]]): Indices of SIMD groups in each node. + """ + self.prunable_nodes = prunable_nodes + self.fw_info = fw_info + self.target_kpi = target_kpi + self.graph = graph + self.fw_impl = fw_impl + self.tpc = tpc + + self.simd_groups_indices = simd_groups_indices + self.simd_groups_scores = simd_groups_scores + + self.oc_pruning_mask = PerSIMDGroupMask(prunable_nodes=prunable_nodes, + fw_info=fw_info, + simd_groups_indices=simd_groups_indices) + + self.memory_calculator = MemoryCalculator(graph=graph, + fw_info=fw_info, + fw_impl=fw_impl) + + + def get_mask(self) -> Dict[BaseNode, np.ndarray]: + """ + Retrieves the current pruning mask for each prunable node. + + Returns: + Dict[BaseNode, np.ndarray]: The current pruning mask for each node. + """ + return self.oc_pruning_mask.get_mask() + + def compute_mask(self): + """ + Computes the pruning mask by iteratively adding SIMD groups to unpruned state + based on their importance and the target KPI. + """ + # Iteratively unprune the graph while monitoring the memory footprint. + current_memory = self.memory_calculator.get_pruned_graph_memory(masks=self.oc_pruning_mask.get_mask(), + include_padded_channels=self.tpc.is_simd_padding) + if current_memory > self.target_kpi.weights_memory: + Logger.error(f"Minimal required memory is {current_memory}, " + f"but target KPI is {self.target_kpi.weights_memory}") + + # Greedily unprune groups (by setting their mask to 1) until the memory target is met + # or all channels unpruned. 
+ while current_memory < self.target_kpi.weights_memory and self.oc_pruning_mask.has_pruned_channel(): + # Select the best SIMD group (best means highest score which means most sensitive group) + # to add based on the scores. + node_to_remain, group_to_remain_idx = self._get_most_sensitive_simd_group_candidate() + self.oc_pruning_mask.set_mask_value_for_simd_group(node=node_to_remain, + group_index=group_to_remain_idx, + mask_indicator=MaskIndicator.REMAINED) + current_memory = self.memory_calculator.get_pruned_graph_memory(masks=self.oc_pruning_mask.get_mask(), + include_padded_channels=self.tpc.is_simd_padding) + + # If the target memory is exceeded, revert the last addition. + if current_memory > self.target_kpi.weights_memory: + self.oc_pruning_mask.set_mask_value_for_simd_group(node=node_to_remain, + group_index=group_to_remain_idx, + mask_indicator=MaskIndicator.PRUNED) + + + + def _get_most_sensitive_simd_group_candidate(self) -> Tuple[BaseNode, int]: + """ + Identifies the most sensitive SIMD group for pruning based on the importance scores. + + Returns: + Tuple[BaseNode, int]: The node and group index of the most sensitive SIMD group. + """ + + best_score = -np.inf + best_node = None + best_group_idx = -1 + + for node, mask in self.oc_pruning_mask.get_mask_simd().items(): + # Get the index of the first zero in the mask. A zero indicates a prunable channel group. + group_idx = int(np.argmax(mask == 0)) + + # If group_idx is 0, it means there are no zeros in the mask, so this group is not prunable. + if group_idx != 0: + score = self.simd_groups_scores[node][group_idx] + # If the score for this group is better than the best score found so far, update the best score. 
+ if score > best_score: + best_score = score + best_node = node + best_group_idx = group_idx + + if best_node is None: + Logger.error("No prunable SIMD group found.") + + return best_node, best_group_idx + diff --git a/model_compression_toolkit/core/common/pruning/importance_metrics/__init__.py b/model_compression_toolkit/core/common/pruning/importance_metrics/__init__.py new file mode 100644 index 000000000..cb2075e68 --- /dev/null +++ b/model_compression_toolkit/core/common/pruning/importance_metrics/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + diff --git a/model_compression_toolkit/core/common/pruning/importance_metrics/base_importance_metric.py b/model_compression_toolkit/core/common/pruning/importance_metrics/base_importance_metric.py new file mode 100644 index 000000000..c6290bfcd --- /dev/null +++ b/model_compression_toolkit/core/common/pruning/importance_metrics/base_importance_metric.py @@ -0,0 +1,43 @@ +# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import List, Tuple, Dict +from abc import abstractmethod, ABC +from model_compression_toolkit.core.common import BaseNode +import numpy as np + + +class BaseImportanceMetric(ABC): + """ + Interface for implementing importance metrics used for pruning SIMD groups. + """ + @abstractmethod + def get_entry_node_to_simd_score(self, entry_nodes: List[BaseNode]) -> Tuple[ + Dict[BaseNode, np.ndarray], Dict[BaseNode, List[np.ndarray]]]: + """ + Compute SIMD scores for each group of channels for a list of entry nodes. + Group the channels into SIMD groups, and compute a score for each SIMD group. + + Args: + entry_nodes (List[BaseNode]): Entry nodes of pruning sections in the graph. + + Returns: + Tuple[Dict, Dict]: Tuple of two dictionaries. The first is a dictionary of entry nodes to + numpy arrays where each element is an importance score for the SIMD group. The second + dictionary maps each node to a list of numpy arrays where each numpy array is the indices + of channels in a group. 
+ """ + raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + f'framework\'s get_entry_node_to_simd_score method.') # pragma: no cover diff --git a/model_compression_toolkit/core/common/pruning/importance_metrics/importance_metric_factory.py b/model_compression_toolkit/core/common/pruning/importance_metrics/importance_metric_factory.py new file mode 100644 index 000000000..d3063a970 --- /dev/null +++ b/model_compression_toolkit/core/common/pruning/importance_metrics/importance_metric_factory.py @@ -0,0 +1,39 @@ +# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from model_compression_toolkit.core.common.pruning.pruning_config import ImportanceMetric +from model_compression_toolkit.core.common.pruning.importance_metrics.base_importance_metric import BaseImportanceMetric +from model_compression_toolkit.core.common.pruning.importance_metrics.lfh_importance_metric import LFHImportanceMetric + +# A dictionary mapping each importance metric enum to its corresponding class. +IMPORTANCE_METRIC_DICT = {ImportanceMetric.LFH: LFHImportanceMetric} + +def get_importance_metric(im: ImportanceMetric, **kwargs) -> BaseImportanceMetric: + """ + Retrieves an instance of the importance metric class based on the specified importance metric enum.
+ + Args: + im (ImportanceMetric): An enum value representing the desired importance metric. + **kwargs: Additional keyword arguments to be passed to the constructor of the importance metric class. + + Returns: + BaseImportanceMetric: An instance of a class derived from BaseImportanceMetric corresponding to the provided enum. + """ + # Retrieve the corresponding class for the provided importance metric enum from the dictionary. + im = IMPORTANCE_METRIC_DICT.get(im) + + # Create and return an instance of the importance metric class with the provided keyword arguments. + return im(**kwargs) + diff --git a/model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py b/model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py new file mode 100644 index 000000000..701297c64 --- /dev/null +++ b/model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py @@ -0,0 +1,285 @@ +# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + + +import numpy as np +from typing import Callable, List, Dict, Tuple + +from model_compression_toolkit.core.common import Graph, BaseNode +from model_compression_toolkit.core.common.framework_info import FrameworkInfo +from model_compression_toolkit.core.common.hessian import HessianInfoService, HessianMode, HessianInfoGranularity, \ + TraceHessianRequest +from model_compression_toolkit.core.common.pruning.channels_grouping import ChannelGrouping +from model_compression_toolkit.core.common.pruning.importance_metrics.base_importance_metric import BaseImportanceMetric +from model_compression_toolkit.core.common.pruning.pruning_config import PruningConfig +from model_compression_toolkit.core.common.pruning.pruning_framework_implementation import PruningFrameworkImplementation +from model_compression_toolkit.logger import Logger + + +class LFHImportanceMetric(BaseImportanceMetric): + """ + LFHImportanceMetric implements an importance metric based on the Hessian of the + loss function w.r.t weights of each SIMD group. + """ + + def __init__(self, + graph: Graph, + representative_data_gen: Callable, + fw_impl: PruningFrameworkImplementation, + pruning_config: PruningConfig, + fw_info: FrameworkInfo): + """ + Initialize the LFHImportanceMetric instance. + + Args: + graph (Graph): Computational graph of the model. + representative_data_gen (Callable): Function to generate representative data. + fw_impl (PruningFrameworkImplementation): Implementation of pruning for the framework. + pruning_config (PruningConfig): Configuration for pruning. + fw_info (FrameworkInfo): Framework-specific information. + """ + self.float_graph = graph + self.representative_data_gen = representative_data_gen + self.fw_impl = fw_impl + self.pruning_config = pruning_config + self.fw_info = fw_info + + # Initialize internal dictionaries for storing intermediate computations. 
+ self._entry_node_to_hessian_score = {} + self._entry_node_count_oc_nparams = {} + self._entry_node_to_simd_score = {} + + def get_entry_node_to_simd_score(self, entry_nodes: List[BaseNode]) -> Tuple[Dict[BaseNode, np.ndarray], Dict[BaseNode, List[np.ndarray]]]: + """ + Compute SIMD scores for each group of channels for a list of entry nodes. + The function first compute a score for each channel in the node. Then, and based on the scores + computed, we group the channels into SIMD groups (by simply sorting the scores and grouping them). + Eventually, we compute a score for each group of channels using LFH score, squared L2-norm, + and number of parameters in a group. + + Ci_Score = Trace(H_Ci) * SqL2Norm(Ci) / |Ci| + + Where Trace(H_Ci) is the trace of the hessian of the loss function (w.r.t weights in Ci), + SqL2Norm is squared l2-norm of the weights in Ci, and |Ci| is the number of parameters in Ci. + + Args: + entry_nodes (List[BaseNode]): Entry nodes in the graph. + + Returns: + Tuple[Dict, Dict]: Tuple of dictionaries containing SIMD scores and grouped indices. + """ + + # Compute initial scores for entry nodes. + entry_node_to_score = self._get_entry_node_to_score(entry_nodes) + + # Group indices based on SIMD configurations. + grouped_indices = self._compute_simd_groups_indices(entry_node_to_score) + + # Compute squared L2 norms for the groups. + _squared_l2_norm_by_groups = self._get_squaredl2norm(entry_nodes, grouped_indices) + + # Initialize dictionary for storing SIMD scores. + entry_node_to_simd_score = {} + + # Compute SIMD scores for each group. 
+ for node, hessian_score in self._entry_node_to_hessian_score.items(): + group_hessian_score = [np.sum(hessian_score[g]) for g in grouped_indices[node]] + nparams_by_group = np.asarray([np.sum(self._entry_node_count_oc_nparams[node][g]) for g in grouped_indices[node]]) + entry_node_to_simd_score[node] = np.asarray(group_hessian_score) * _squared_l2_norm_by_groups[node] / nparams_by_group + + return entry_node_to_simd_score, grouped_indices + + def _get_entry_node_to_score(self, entry_nodes: List[BaseNode]) -> Dict[BaseNode, np.ndarray]: + """ + Compute score for each channel for a list of entry nodes. + We compute a score for each channel using LFH score, squared L2-norm, + and number of parameters in the channel. + + Ci_Score = Trace(H_Ci) * SqL2Norm(Ci) / |Ci| + + Where Trace(H_Ci) is the trace of the hessian of the loss function (w.r.t weights in Ci), + SqL2Norm is squared l2-norm of the weights in Ci, and |Ci| is the number of parameters in Ci. + + Args: + entry_nodes (List[BaseNode]): Entry nodes of pruning sections in the graph. + + Returns: + Dict[BaseNode, np.ndarray]: Dictionary containing channel scores for each entry node. + """ + + # Initialize HessianInfoService for score computation. + hessian_info_service = HessianInfoService(graph=self.float_graph, + representative_dataset=self.representative_data_gen, + fw_impl=self.fw_impl) + + # Fetch and process Hessian scores for output channels of entry nodes. + nodes_scores = [] + for node in entry_nodes: + _request = TraceHessianRequest(mode=HessianMode.WEIGHTS, + granularity=HessianInfoGranularity.PER_OUTPUT_CHANNEL, + target_node=node) + _scores_for_node = hessian_info_service.fetch_hessian(_request, + required_size=self.pruning_config.num_score_approximations) + nodes_scores.append(_scores_for_node) + + # Average and map scores to nodes. 
+ self._entry_node_to_hessian_score = {node: np.mean(scores, axis=0) for node, scores in zip(entry_nodes, nodes_scores)} + + self._entry_node_count_oc_nparams = self._count_oc_nparams(entry_nodes=entry_nodes) + _entry_node_l2_oc_norm = self._get_squaredl2norm(entry_nodes=entry_nodes) + + # Normalize scores using squared L2 norms and number of parameters. + _entry_node_to_score = self._normalize_lfh_scores(_entry_node_l2_oc_norm) + return _entry_node_to_score + + def _compute_simd_groups_indices(self, + entry_node_to_score: Dict[BaseNode, np.ndarray]) -> Dict[BaseNode, List[np.ndarray]]: + """ + Compute SIMD group indices for each entry node. + + Args: + entry_node_to_score (Dict[BaseNode, np.ndarray]): Scores for entry nodes. + + Returns: + Dict[BaseNode, List[np.ndarray]]: Dictionary of entry nodes mapped to their SIMD group indices. + """ + # Initialize channel grouping utility. + channel_grouping = ChannelGrouping(prunable_nodes=list(entry_node_to_score.keys()), + fw_info=self.fw_info) + + channel_grouping.group_scores_by_simd_groups(entry_node_to_score) + grouped_indices = channel_grouping.simd_groups_indices + + return grouped_indices + + def _normalize_lfh_scores(self, + entry_node_to_squaredl2norm: Dict[BaseNode, np.ndarray]) -> Dict[BaseNode, np.ndarray]: + """ + Normalizes the LFH scores using the squared L2 norms. + + Args: + entry_node_to_squaredl2norm (Dict[BaseNode, np.ndarray]): Squared L2 norms for each entry node. + + Returns: + Dict[BaseNode, np.ndarray]: Normalized LFH scores for each entry node. + """ + new_scores = {} + for node, hessian_score_vector in self._entry_node_to_hessian_score.items(): + # Normalize the hessian score vector using squared L2 norm and the count of output channel parameters. 
+ new_scores[node] = hessian_score_vector * entry_node_to_squaredl2norm[node] / self._entry_node_count_oc_nparams[node] + return new_scores + + def _count_oc_nparams(self, entry_nodes: List[BaseNode]) -> Dict[BaseNode, np.ndarray]: + """ + Counts the number of parameters per output channel for each entry node. + + Args: + entry_nodes (List[BaseNode]): List of entry nodes to count parameters for. + + Returns: + Dict[BaseNode, np.ndarray]: Dictionary of nodes and their parameters count per output channel. + """ + node_channel_params = {} + for entry_node in entry_nodes: + kernel_attr, num_oc, oc_axis = self._get_kernel_node_oc_info(entry_node) + kernel = entry_node.get_weights_by_keys(kernel_attr) + + # Calculate parameters per channel + params_per_channel = np.prod(kernel.shape) / kernel.shape[oc_axis] + # Create an array filled with the count of parameters per output channel. + num_params_array = np.full(kernel.shape[oc_axis], params_per_channel) + + # Map each node to its array of parameters count per output channel. + node_channel_params[entry_node] = num_params_array + return node_channel_params + + def _get_squaredl2norm(self, + entry_nodes: List[BaseNode], + grouped_indices: Dict[BaseNode, List[np.ndarray]] = None) -> Dict[BaseNode, np.ndarray]: + """ + Computes the squared L2 norm for each output channel (or group of channels) of the entry nodes. + + Args: + entry_nodes (List[BaseNode]): List of entry nodes for L2 norm computation. + grouped_indices (Dict[BaseNode, List[List[int]]], optional): Indices of channel groups. Defaults to None. + + Returns: + Dict[BaseNode, np.ndarray]: Dictionary of nodes and their squared L2 norms for each output channel (or group). + """ + node_l2_channel_norm = {} + for entry_node in entry_nodes: + kernel_attr, num_oc, oc_axis = self._get_kernel_node_oc_info(entry_node) + # Retrieve the kernel tensor of the node. 
+ kernel = entry_node.get_weights_by_keys(kernel_attr) + # Split the kernel tensor into individual channels (or groups if provided). + channels = np.split(kernel, indices_or_sections=num_oc, axis=oc_axis) + + # If grouped_indices are provided, concatenate tensors based on grouped indices. + if grouped_indices: + concatenated_tensors = self._concatenate_tensors_by_indices(channels, grouped_indices[entry_node]) + channels = concatenated_tensors + + # Compute the squared L2 norm for each channel (or group). + l2_norms = np.asarray([np.linalg.norm(c.flatten(), ord=2) ** 2 for c in channels]) + node_l2_channel_norm[entry_node] = l2_norms + + return node_l2_channel_norm + + def _get_kernel_node_oc_info(self, entry_node: BaseNode) -> Tuple[str, int, int]: + """ + Retrieves information about the output channels (oc) for a given kernel node. + + Args: + entry_node (BaseNode): The node whose output channel information is needed. + + Returns: + tuple: A tuple containing the kernel attribute, the number of output channels, and the axis of the output channels. + """ + kernel_attr = self.fw_info.get_kernel_op_attributes(entry_node.type) + # Ensure only one kernel attribute exists for the given node. + if len(kernel_attr) != 1: + Logger.error(f"Expected to found a single attribute but found {len(kernel_attr)} for node {entry_node}") + kernel_attr = kernel_attr[0] + + # Retrieve and validate the axis index for the output channels. + oc_axis, _ = self.fw_info.kernel_channels_mapping.get(entry_node.type) + if oc_axis is None or int(oc_axis) != oc_axis: + Logger.error(f"Expected output channel axis to be an integer but is {oc_axis} for node {entry_node}") + + # Get the number of output channels based on the kernel attribute and axis. 
+ num_oc = entry_node.get_weights_by_keys(kernel_attr).shape[oc_axis] + return kernel_attr, num_oc, oc_axis + + def _concatenate_tensors_by_indices(self, + channels: List[np.ndarray], + index_list: List[np.ndarray]) -> List[np.ndarray]: + """ + Concatenates tensors based on provided indices. + + Args: + channels (List[np.ndarray]): List of channel tensors. + index_list (List[np.ndarray]): Indices of channels to be concatenated. + + Returns: + List[np.ndarray]: List of concatenated tensors. + """ + concatenated_tensors = [] + for index_array in index_list: + # Gather tensors based on indices. + tensors_to_concatenate = [channels[i] for i in index_array] + # Concatenate the gathered tensors. + concatenated_tensor = np.concatenate(tensors_to_concatenate) + concatenated_tensors.append(concatenated_tensor) + return concatenated_tensors diff --git a/model_compression_toolkit/core/common/pruning/mask/__init__.py b/model_compression_toolkit/core/common/pruning/mask/__init__.py new file mode 100644 index 000000000..807f5e384 --- /dev/null +++ b/model_compression_toolkit/core/common/pruning/mask/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# ============================================================================== \ No newline at end of file diff --git a/model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py b/model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py new file mode 100644 index 000000000..ef3e5ecff --- /dev/null +++ b/model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py @@ -0,0 +1,114 @@ +# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from enum import Enum + +import numpy as np +from typing import List, Dict, Tuple + +from model_compression_toolkit.core.common import BaseNode, Graph +from model_compression_toolkit.core.common.framework_info import FrameworkInfo +from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI +from model_compression_toolkit.core.common.pruning.memory_calculator import MemoryCalculator +from model_compression_toolkit.core.common.pruning.pruning_framework_implementation import PruningFrameworkImplementation +from model_compression_toolkit.logger import Logger +from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities + +class MaskIndicator(Enum): + """ + Enum class for indicating the status of channels in a pruning mask. 
+ + PRUNED: Represents channels that are removed or pruned from the model. + REMAINED: Represents channels that are kept or remain unpruned in the model. + """ + PRUNED = 0 + REMAINED = 1 + + + +class PerChannelMask: + def __init__(self, prunable_nodes: List[BaseNode], fw_info: FrameworkInfo): + """ + Initializes the PerChannelMask with prunable nodes and framework information. + This class is responsible for maintaining and updating the pruning masks for each + prunable node in the model. The mask is an array indicating whether each output channel + of a node is pruned (0) or remained (1). + + Args: + prunable_nodes: List of nodes in the model that are subject to pruning. + fw_info: Framework-specific information required for pruning operations. + """ + self.prunable_nodes = prunable_nodes + self.fw_info = fw_info + self._mask = None # Initialize the mask dictionary + self._init_masks() # Call to initialize masks for each prunable node + + def get_mask(self) -> Dict[BaseNode, np.ndarray]: + """ + Retrieves the current pruning masks for all prunable nodes in the model. + + Returns: + A dictionary mapping each prunable node to its corresponding pruning mask. + """ + return self._mask + + def set_mask_value_for_simd_group(self, node: BaseNode, channel_idx: int, mask_indicator: MaskIndicator): + """ + Sets the mask value for a specific channel of a prunable node. + + Args: + node: The prunable node to update the mask for. + channel_idx: The index of the channel to update in the mask. + mask_indicator: The new value to set in the mask (either PRUNED or REMAINED). + """ + if mask_indicator not in [MaskIndicator.PRUNED, MaskIndicator.REMAINED]: + Logger.error("Mask value must be either MaskIndicator.PRUNED or MaskIndicator.REMAINED") + self._mask[node][channel_idx] = mask_indicator.value + + def has_pruned_channel(self) -> bool: + """ + Determines if there is at least one pruned channel across all nodes in the model. 
+ + Returns: + True if there is at least one pruned channel, False otherwise. + """ + return any(MaskIndicator.PRUNED.value in mask for mask in self._mask.values()) + + def _init_masks(self): + """ + Initializes the pruning masks for each prunable node in the model. + Sets the initial mask for each node as an array of zeros (indicating all channels are + initially pruned). + """ + self._mask = {} # Initialize the dictionary for pruning masks. + for prunable_node in self.prunable_nodes: + num_oc = self._compute_num_of_out_channels(prunable_node) # Number of output channels for the node. + layer_mask = np.full(num_oc, MaskIndicator.PRUNED.value) # Initialize the mask with zeros. + self._mask[prunable_node] = layer_mask + + def _compute_num_of_out_channels(self, node: BaseNode) -> int: + """ + Computes the number of output channels for a given node. + + Args: + node (BaseNode): The node whose output channels are to be counted. + + Returns: + int: Number of output channels for the node. + """ + kernel_attr = self.fw_info.get_kernel_op_attributes(node.type)[0] + oc_axis = self.fw_info.kernel_channels_mapping.get(node.type)[0] + num_oc = node.get_weights_by_keys(kernel_attr).shape[oc_axis] + return num_oc + diff --git a/model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py b/model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py new file mode 100644 index 000000000..e8e970eb2 --- /dev/null +++ b/model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py @@ -0,0 +1,122 @@ +# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import numpy as np +from typing import List, Dict, Tuple + +from model_compression_toolkit.core.common import BaseNode, Graph +from model_compression_toolkit.core.common.framework_info import FrameworkInfo +from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI +from model_compression_toolkit.core.common.pruning.mask.per_channel_mask import PerChannelMask, MaskIndicator +from model_compression_toolkit.core.common.pruning.memory_calculator import MemoryCalculator +from model_compression_toolkit.core.common.pruning.pruning_framework_implementation import PruningFrameworkImplementation +from model_compression_toolkit.logger import Logger +from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities + +class PerSIMDGroupMask: + def __init__(self, + prunable_nodes: List[BaseNode], + fw_info: FrameworkInfo, + simd_groups_indices: Dict[BaseNode, List[List[int]]]): + """ + Initializes a mask calculator for SIMD groups in prunable nodes. + Manages both per-channel and per-SIMD-group masks. + + Args: + prunable_nodes: List of nodes that can be pruned. + fw_info: Framework-specific information. + simd_groups_indices: A dictionary mapping each node to its SIMD groups' indices. 
+ """ + # Initialize the per-channel mask + self.per_channel_mask = PerChannelMask(prunable_nodes=prunable_nodes, fw_info=fw_info) + self.prunable_nodes = prunable_nodes + self.fw_info = fw_info + self.simd_groups_indices = simd_groups_indices + self._mask_simd = None # Initialize the SIMD group mask dictionary + self._init_masks() # Initialize masks for each prunable node + self._update_mandatory_mask() # Ensure at least one SIMD group remains unpruned + + def get_mask_simd(self) -> Dict[BaseNode, np.ndarray]: + """ + Retrieves the current SIMD group masks for all prunable nodes. + + Returns: + A dictionary mapping each prunable node to its corresponding SIMD group mask. + """ + return self._mask_simd + + def get_mask(self) -> Dict[BaseNode, np.ndarray]: + """ + Retrieves the current per-channel masks for all prunable nodes. + + Returns: + A dictionary mapping each prunable node to its corresponding per-channel mask. + """ + return self.per_channel_mask.get_mask() + + def set_mask_value_for_simd_group(self, + node: BaseNode, + group_index: int, + mask_indicator: MaskIndicator): + """ + Sets the mask value for a specific SIMD group of a prunable node. + + Args: + node: The prunable node to update the mask for. + group_index: The index of the SIMD group to update in the mask. + mask_indicator: The new value to set in the mask (either PRUNED or REMAINED). 
+ """ + if mask_indicator not in [MaskIndicator.PRUNED, MaskIndicator.REMAINED]: + Logger.error("Mask value must be either MaskIndicator.PRUNED or MaskIndicator.REMAINED") + + # Update the SIMD group mask and corresponding per-channel mask + self._mask_simd[node][group_index] = mask_indicator.value + node_mask_indices = self.simd_groups_indices[node][group_index] + for idx in node_mask_indices: + self.per_channel_mask.set_mask_value_for_simd_group(node=node, + channel_idx=idx, + mask_indicator=mask_indicator) + def has_pruned_channel(self) -> bool: + """ + Checks if there is at least one channel marked for pruning in any node mask. + + Returns: + True if there is at least one channel to be pruned, False otherwise. + """ + return self.per_channel_mask.has_pruned_channel() + + def _init_masks(self): + """ + Initializes the SIMD group masks for each prunable node. + Sets the initial mask for each node as an array of zeros (indicating + all groups are initially pruned). + """ + self._mask_simd = {} # Initialize the dictionary for SIMD group masks. + for prunable_node in self.prunable_nodes: + num_groups = len(self.simd_groups_indices[prunable_node]) # Number of SIMD groups for the node. + layer_mask_per_simd_group = np.full(num_groups, MaskIndicator.PRUNED.value) # Initialize the mask with zeros. + self._mask_simd[prunable_node] = layer_mask_per_simd_group + + def _update_mandatory_mask(self): + """ + Updates the mandatory masks for each prunable node to ensure at least one SIMD + group remains unpruned. + """ + for prunable_node in self.prunable_nodes: + # Mark the first SIMD group as mandatory (unpruned). 
class MemoryCalculator:
    """
    MemoryCalculator estimates the memory usage of a graph under a given pruning mask.

    It takes into account the pruning masks applied to the entry nodes of the pruning sections,
    avoids double counting of nodes that are shared between adjacent pruning sections (a node
    that is the exit node of one section and the entry node of another), and can optionally
    count the channels that are padded in order to align with a node's SIMD size.
    """

    def __init__(self, graph: Graph, fw_info: FrameworkInfo, fw_impl: PruningFrameworkImplementation):
        """
        Initializes the MemoryCalculator with the model's graph, framework-specific information
        and the pruning implementation.

        Args:
            graph (Graph): Computational graph of the model.
            fw_info (FrameworkInfo): Contains framework-specific information.
            fw_impl (PruningFrameworkImplementation): Implementation details for pruning.
        """
        self.graph = graph
        self.fw_info = fw_info
        self.fw_impl = fw_impl

    def get_pruned_graph_memory(self,
                                masks: Dict[BaseNode, np.ndarray],
                                include_padded_channels: bool) -> float:
        """
        Calculates the memory usage of the pruned graph.

        Args:
            masks (Dict[BaseNode, np.ndarray]): Dictionary mapping nodes to their pruning masks.
            include_padded_channels (bool): Whether to include padded channels in the memory calculation.

        Returns:
            float: Estimated memory usage of the pruned graph in bytes.
        """
        nparams = self.get_pruned_graph_num_params(masks, include_padded_channels)
        # Every parameter is assumed to be stored as float32.
        return nparams * FP32_BYTES_PER_PARAMETER

    def get_pruned_graph_num_params(self,
                                    masks: Dict[BaseNode, np.ndarray],
                                    include_padded_channels: bool) -> int:
        """
        Calculates the total number of parameters in the pruned graph.

        Args:
            masks (Dict[BaseNode, np.ndarray]): Pruning masks for each node.
            include_padded_channels (bool): Flag to include SIMD-padded channels in the count.

        Returns:
            int: Total number of parameters in the pruned graph.
        """
        pruning_sections = self.graph.get_pruning_sections(self.fw_impl)

        total_nparams = 0
        total_nparams += self.get_nparams_of_nonpruned_nodes(pruning_sections, include_padded_channels)
        total_nparams += self.get_nparams_of_pruning_sections(masks, pruning_sections, include_padded_channels)
        # A node that is both the exit node of one section and the entry node of another
        # is counted once per section above, so its parameters are subtracted once here.
        total_nparams -= self.get_nparams_of_shared_nodes(masks, pruning_sections, include_padded_channels)

        return total_nparams

    def get_nparams_of_shared_nodes(self,
                                    masks: Dict[BaseNode, np.ndarray],
                                    pruning_sections: List[PruningSection],
                                    include_padded_channels: bool) -> int:
        """
        Calculates the number of parameters for nodes shared between adjacent pruning sections.

        Args:
            masks (Dict[BaseNode, np.ndarray]): Pruning masks for each node.
            pruning_sections (List[PruningSection]): A list of pruning sections.
            include_padded_channels (bool): Flag to include padded channels in the count.

        Returns:
            int: Total number of parameters for shared nodes.
        """
        nparams = 0
        shared_nodes = self._get_nodes_from_adjacent_sections(pruning_sections)
        for node in shared_nodes:
            node_input_mask = self._get_exit_node_input_mask(node, pruning_sections, masks)
            node_output_mask = masks.get(node)
            nparams += self.get_pruned_node_num_params(node, node_input_mask, node_output_mask,
                                                       include_padded_channels)
        return nparams

    def get_nparams_of_pruning_sections(self,
                                        masks: Dict[BaseNode, np.ndarray],
                                        pruning_sections: List[PruningSection],
                                        include_padded_channels: bool) -> int:
        """
        Calculates the number of parameters for all pruning sections.

        Args:
            masks (Dict[BaseNode, np.ndarray]): Pruning masks for each node.
            pruning_sections (List[PruningSection]): A list of pruning sections.
            include_padded_channels (bool): Flag to include padded channels in the count.

        Returns:
            int: Total number of parameters for all pruning sections.
        """
        nparams = 0
        for pruning_section in pruning_sections:
            pruning_section_mask = self.get_section_mask_from_node_mask(masks, pruning_section, pruning_sections)
            nparams += self._get_pruning_section_num_params(pruning_section, pruning_section_mask,
                                                            include_padded_channels)
        return nparams

    def get_section_mask_from_node_mask(self,
                                        masks: Dict[BaseNode, np.ndarray],
                                        pruning_section: PruningSection,
                                        pruning_sections: List[PruningSection]) -> PruningSectionMask:
        """
        Creates a pruning section mask from individual node masks.

        Args:
            masks (Dict[BaseNode, np.ndarray]): Pruning masks for each node.
            pruning_section (PruningSection): The current pruning section.
            pruning_sections (List[PruningSection]): A list of pruning sections.

        Returns:
            PruningSectionMask: The combined pruning mask for the section.
        """
        # The entry node may itself be the exit node of a previous section, in which case
        # its input channels are pruned according to that section's entry node mask.
        first_node_input_channels_mask = self._get_exit_node_input_mask(pruning_section.entry_node,
                                                                        pruning_sections,
                                                                        masks)
        second_node_output_mask = masks.get(pruning_section.exit_node)

        # The exit node's input channels follow the entry node's output-channels mask.
        return PruningSectionMask(
            entry_node_ic_mask=first_node_input_channels_mask,
            entry_node_oc_mask=masks.get(pruning_section.entry_node),
            exit_node_ic_mask=masks.get(pruning_section.entry_node),
            exit_node_oc_mask=second_node_output_mask
        )

    def get_nparams_of_nonpruned_nodes(self,
                                       pruning_sections: List[PruningSection],
                                       include_padded_channels: bool) -> int:
        """
        Calculates the number of parameters for nodes that are not part of any pruning section.

        Args:
            pruning_sections (List[PruningSection]): A list of pruning sections.
            include_padded_channels (bool): Flag to include padded channels in the count.

        Returns:
            int: Total number of parameters for non-pruned nodes.
        """
        total_nparams = 0
        nodes_to_prune = set(node for section in pruning_sections for node in section.get_all_section_nodes())
        for n in self.graph.nodes:
            if n not in nodes_to_prune:
                # No masks are passed, so all the node's parameters are counted.
                total_nparams += self.get_pruned_node_num_params(n,
                                                                 None,
                                                                 None,
                                                                 include_padded_channels)
        return total_nparams

    def _get_exit_node_input_mask(self,
                                  node: BaseNode,
                                  pruning_sections: List[PruningSection],
                                  masks: Dict[BaseNode, np.ndarray]) -> np.ndarray:
        """
        Retrieves the input-channels mask of a node based on the pruning sections.
        If the node is the exit node of a pruning section, its input-channels mask is the
        output-channels mask of that section's entry node. Otherwise, a mask of 1s with
        the length of the node's kernel input-channels axis is returned.

        Args:
            node (BaseNode): The node for which the input mask is required.
            pruning_sections (List[PruningSection]): A list of pruning sections in the graph.
            masks (Dict[BaseNode, np.ndarray]): A dictionary mapping nodes to their respective pruning masks.

        Returns:
            np.ndarray: The input mask for the specified node, or 1s mask if not found.
        """
        for section in pruning_sections:
            # If the node is the exit node of a pruning section, return the entry node's mask.
            if node == section.exit_node:
                return masks.get(section.entry_node)

        kernel_attr = self.fw_info.get_kernel_op_attributes(node.type)
        # Ensure only one kernel attribute exists for the given node.
        if len(kernel_attr) != 1:
            Logger.error(f"Expected to find a single attribute but found {len(kernel_attr)} for node {node}")
        kernel_attr = kernel_attr[0]

        # Retrieve and validate the axis index of the input channels.
        _, ic_axis = self.fw_info.kernel_channels_mapping.get(node.type)
        if ic_axis is None or int(ic_axis) != ic_axis:
            Logger.error(f"Expected input channel axis to be an integer but is {ic_axis} for node {node}")

        # Get the number of input channels based on the kernel attribute and axis.
        num_ic = node.get_weights_by_keys(kernel_attr).shape[ic_axis]
        return np.ones(num_ic, dtype=bool)

    def _get_nodes_from_adjacent_sections(self,
                                          pruning_sections: List[PruningSection]) -> List[BaseNode]:
        """
        Identifies nodes that are shared between adjacent pruning sections.

        Args:
            pruning_sections (List[PruningSection]): A list of pruning sections in the graph.

        Returns:
            List[BaseNode]: Nodes that are both an entry node of one section and an exit node of a section.
        """
        input_nodes = set(section.entry_node for section in pruning_sections)
        output_nodes = set(section.exit_node for section in pruning_sections)
        # The intersection of entry and exit nodes represents the shared nodes.
        return list(input_nodes.intersection(output_nodes))

    def _get_pruning_section_num_params(self,
                                        pruning_section: PruningSection,
                                        pruning_section_mask: PruningSectionMask,
                                        include_padded_channels: bool) -> int:
        """
        Calculates the total number of parameters in a pruning section after applying the pruning mask.

        Args:
            pruning_section (PruningSection): The pruning section to be considered.
            pruning_section_mask (PruningSectionMask): The pruning mask applied to the section.
            include_padded_channels (bool): Flag to include padded channels in the count.

        Returns:
            int: The total number of parameters in the pruning section after pruning.
        """
        # Number of parameters of the entry node.
        first_node_nparams = self.get_pruned_node_num_params(pruning_section.entry_node,
                                                             pruning_section_mask.entry_node_ic_mask,
                                                             pruning_section_mask.entry_node_oc_mask,
                                                             include_padded_channels)

        # Intermediate nodes use the entry node's output-channels mask for both their input and output.
        total_inter_nodes_nparams = sum(
            self.get_pruned_node_num_params(inter_node,
                                            pruning_section_mask.entry_node_oc_mask,
                                            pruning_section_mask.entry_node_oc_mask,
                                            include_padded_channels)
            for inter_node in pruning_section.intermediate_nodes)

        # Number of parameters of the exit node.
        second_node_nparams = self.get_pruned_node_num_params(pruning_section.exit_node,
                                                              pruning_section_mask.exit_node_ic_mask,
                                                              pruning_section_mask.exit_node_oc_mask,
                                                              include_padded_channels)

        return first_node_nparams + total_inter_nodes_nparams + second_node_nparams

    def get_pruned_node_num_params(self,
                                   node: BaseNode,
                                   input_mask: np.ndarray,
                                   output_mask: np.ndarray,
                                   include_padded_channels: bool) -> int:
        """
        Calculates the number of parameters in a node after applying input and output pruning masks.

        Args:
            node (BaseNode): The node whose parameters are to be calculated.
            input_mask (np.ndarray): The mask applied to the input channels of the node (None for no pruning).
            output_mask (np.ndarray): The mask applied to the output channels of the node (None for no pruning).
            include_padded_channels (bool): Flag to include padded channels in the count due to SIMD.

        Returns:
            int: The total number of parameters in the node after pruning.
        """
        total_params = 0
        attributes_and_oc_axis = self.fw_impl.attrs_oi_channels_info_for_pruning(node, self.fw_info)

        # Iterate over the node's weights and apply pruning based on the masks.
        for w_attr, w in node.weights.items():
            # An attribute matches a weight if its name is contained in the weight's key.
            matching_axes = [axes for attr, axes in attributes_and_oc_axis.items() if attr in w_attr]
            if len(matching_axes) != 1:
                Logger.error(f"Each weight should have exactly one corresponding IO axis, but is {matching_axes} ")
            out_axis, in_axis = matching_axes[0]

            # Apply input and output masks to the weight tensor.
            if in_axis is not None and input_mask is not None:
                w = self._prune_tensor(w, input_mask, in_axis)
            if out_axis is not None and output_mask is not None:
                w = self._prune_tensor(w, output_mask, out_axis)

            total_params += w.size

        # Adjust the total parameter count if padded channels are to be included.
        num_oc = np.sum(output_mask) if output_mask is not None else node.output_shape[-1]
        if include_padded_channels:
            total_params = self.get_node_nparams_with_padded_channels(node, total_params, num_oc, node.get_simd())

        return total_params

    def _prune_tensor(self,
                      w: np.ndarray,
                      mask: np.ndarray,
                      axis: int) -> np.ndarray:
        """
        Prunes a tensor along a specified axis using a provided mask.

        Args:
            w (np.ndarray): The weight tensor to be pruned.
            mask (np.ndarray): The pruning mask to apply (None keeps all entries along the axis).
            axis (int): The axis along which to apply the pruning mask.

        Returns:
            np.ndarray: The pruned tensor.
        """
        mask = np.ones(w.shape[axis], dtype=bool) if mask is None else mask.astype(bool)
        if w.shape[axis] != len(mask):
            # The tensor's axis length is the expected mask length.
            Logger.error(f"Expected mask length {w.shape[axis]}, found {len(mask)}.")
        # Keep only the entries whose mask value is set.
        pruned_w = np.take(w, np.where(mask)[0], axis=axis)
        return pruned_w

    def get_node_nparams_with_padded_channels(self,
                                              node: BaseNode,
                                              node_nparams: int,
                                              num_oc: int,
                                              node_simd: int) -> int:
        """
        Adjusts the number of parameters of a node by considering padded channels due to SIMD.

        Args:
            node (BaseNode): The node whose parameters are being adjusted.
            node_nparams (int): The original number of parameters in the node.
            num_oc (int): The number of output channels in the node.
            node_simd (int): The SIMD width used in the node.

        Returns:
            int: The adjusted number of parameters considering padded channels.
        """
        if not (num_oc >= 1 and int(num_oc) == num_oc):
            Logger.error(f"Expected number of output channels to be a positive integer but is {num_oc}")

        nparams_per_oc = node_nparams / num_oc
        if int(nparams_per_oc) != nparams_per_oc:
            Logger.warning(
                f"Found a layer {node.name} with weights not uniformly distributed "
                f"across output channels; memory calculation may be inaccurate due to "
                f"SIMD assumptions.")
            nparams_per_oc = np.ceil(nparams_per_oc)

        # Round the number of output channels up to the nearest multiple of the SIMD size.
        num_oc_with_null_channels = np.ceil(num_oc / node_simd) * node_simd
        # np.ceil returns a float; both factors hold integral values, so convert
        # to match the declared integer return type.
        return int(num_oc_with_null_channels * nparams_per_oc)
def build_pruned_graph(graph: Graph,
                       masks: Dict[BaseNode, np.ndarray],
                       fw_info: FrameworkInfo,
                       fw_impl: FrameworkImplementation) -> Graph:
    """
    Builds a pruned copy of a graph from per-node output-channel masks.

    The input graph is left untouched: all pruning is applied to a deep copy of it.

    Args:
        graph: The original computational graph to be pruned.
        masks: A dictionary mapping each prunable (entry) node to its output-channels mask.
        fw_info: Framework-specific information object.
        fw_impl: Framework-specific implementation object.

    Returns:
        A pruned copy of the original computational graph.
    """
    # Work on a copy so the caller's graph is not modified.
    pruned_graph = copy.deepcopy(graph)
    sections = pruned_graph.get_pruning_sections(fw_impl=fw_impl)

    # Exactly one mask is expected per pruning section (one per entry node).
    num_masks = len(masks)
    num_sections = len(sections)
    if num_sections != num_masks:
        Logger.error(f"Expected to find same number of masks as number of pruning sections,"
                     f"but {num_masks} masks were given and found {num_sections} pruning sections.")

    for section in sections:
        # Nodes of the copied graph are different objects than the keys of the masks
        # dictionary, so the mask is looked up by the entry node's name.
        entry_name = section.entry_node.name
        matching_masks = [node_mask for mask_node, node_mask in masks.items() if mask_node.name == entry_name]
        if len(matching_masks) != 1:
            Logger.error(f"Expected to find a single node with name {entry_name} in masks dictionary but found {len(matching_masks)}")
        entry_mask = matching_masks[0]

        # Nothing to do for a section whose mask keeps every channel.
        if not np.any(entry_mask == 0):
            continue

        # The entry node's output channels and the exit node's input channels share the same mask.
        section.apply_inner_section_mask(PruningSectionMask(entry_node_oc_mask=entry_mask,
                                                            exit_node_ic_mask=entry_mask),
                                         fw_impl,
                                         fw_info)

    return pruned_graph
class Pruner:
    """
    Applies pruning to a computational graph in order to meet a target KPI.

    Importance scores are computed per SIMD group of output channels for every pruning-section
    entry node. A pruning mask is then derived from these scores and a pruned graph is built from it.
    """
    def __init__(self,
                 float_graph: Graph,
                 fw_info: FrameworkInfo,
                 fw_impl: PruningFrameworkImplementation,
                 target_kpi: KPI,
                 representative_data_gen: Callable,
                 pruning_config: PruningConfig,
                 target_platform_capabilities: TargetPlatformCapabilities):
        """
        Args:
            float_graph (Graph): The floating-point representation of the model's computation graph.
            fw_info (FrameworkInfo): Contains metadata and helper functions for the framework.
            fw_impl (PruningFrameworkImplementation): Implementation of specific framework methods required for pruning.
            target_kpi (KPI): The target KPIs to be achieved after pruning.
            representative_data_gen (Callable): Generator function for representative dataset used in pruning analysis.
            pruning_config (PruningConfig): Configuration object specifying how pruning should be performed.
            target_platform_capabilities (TargetPlatformCapabilities): Object encapsulating the capabilities of the target hardware platform.
        """
        self.float_graph = float_graph
        self.fw_info = fw_info
        self.fw_impl = fw_impl
        self.target_kpi = target_kpi
        self.representative_data_gen = representative_data_gen
        self.pruning_config = pruning_config
        self.target_platform_capabilities = target_platform_capabilities

        # Results that are filled in by prune_graph():
        self.per_oc_mask = None  # Output-channel mask for each entry node.
        self.simd_scores = None  # Importance scores per SIMD group.
        self.simd_groups_indices = None  # Channel indices of each SIMD group, per node.

    def prune_graph(self):
        """
        Runs the pruning flow: scores the SIMD groups of every entry node, computes the
        pruning masks and builds the pruned graph according to these masks.

        Returns:
            The pruned graph built by build_pruned_graph.
        """
        # Score the SIMD groups of each pruning-section entry node.
        entry_nodes = self.float_graph.get_pruning_sections_entry_nodes(self.fw_impl)
        scores, groups_indices = self.get_score_per_entry_point(entry_nodes)
        self.simd_scores = scores
        self.simd_groups_indices = groups_indices

        Logger.info("Calculating the pruning mask. Please note that this process might take some time,"
                    " especially for large models or when using a small SIMD size.")

        strategy = self.pruning_config.channels_filtering_strategy
        if strategy != ChannelsFilteringStrategy.GREEDY:
            Logger.error("Only GREEDY ChannelsFilteringStrategy is currently supported.")
        else:
            # Greedy strategy: masks are computed from the importance scores.
            greedy_calculator = GreedyMaskCalculator(entry_nodes,
                                                     self.fw_info,
                                                     self.simd_scores,
                                                     self.target_kpi,
                                                     self.float_graph,
                                                     self.fw_impl,
                                                     self.target_platform_capabilities,
                                                     self.simd_groups_indices)
            greedy_calculator.compute_mask()
            self.per_oc_mask = greedy_calculator.get_mask()

        Logger.info("Start pruning graph...")
        return build_pruned_graph(self.float_graph,
                                  self.per_oc_mask,
                                  self.fw_info,
                                  self.fw_impl)

    def get_score_per_entry_point(self, entry_nodes: List[BaseNode]) -> Tuple[Dict[BaseNode, np.ndarray], Dict[BaseNode, List[np.ndarray]]]:
        """
        Computes the importance scores of the given entry nodes using the configured importance metric.

        Args:
            entry_nodes (List[BaseNode]): List of entry nodes in the graph.

        Returns:
            Tuple: Tuple containing importance scores and group indices.
        """
        # Build the importance metric object that is set in the pruning configuration.
        metric = get_importance_metric(self.pruning_config.importance_metric,
                                       graph=self.float_graph,
                                       representative_data_gen=self.representative_data_gen,
                                       fw_impl=self.fw_impl,
                                       pruning_config=self.pruning_config,
                                       fw_info=self.fw_info)
        scores_per_node, groups_per_node = metric.get_entry_node_to_simd_score(entry_nodes)
        return scores_per_node, groups_per_node

    def get_pruning_info(self) -> PruningInfo:
        """
        Collects the pruning masks and the per-channel importance scores into a PruningInfo object.

        Returns:
            PruningInfo: Object containing detailed pruning data.
        """
        # Each SIMD group's score is expanded to the channels of that group.
        per_channel_scores = unroll_simd_scores_to_per_channel_scores(self.simd_scores,
                                                                      self.simd_groups_indices)
        return PruningInfo(pruning_masks=self.per_oc_mask,
                           importance_scores=per_channel_scores)
+ + Attributes: + num_score_approximations (int): The number of score approximations to perform + when calculating channel importance. + importance_metric (ImportanceMetric): The metric used to calculate channel importance. + channels_filtering_strategy (ChannelsFilteringStrategy): The strategy used to filter out channels. + """ + + def __init__(self, + num_score_approximations: int = PRUNING_NUM_SCORE_APPROXIMATIONS, + importance_metric: ImportanceMetric = ImportanceMetric.LFH, + channels_filtering_strategy: ChannelsFilteringStrategy = ChannelsFilteringStrategy.GREEDY): + """ + Initializes a PruningConfig object with default or specified parameters. + + Args: + num_score_approximations (int): The number of times to approximate the scoring + for channel importance. Defaults to a predefined + constant value. + importance_metric (ImportanceMetric): The method used for calculating the importance + of channels in a network. Defaults to label-free + Hessian (LFH) approximation. + channels_filtering_strategy (ChannelsFilteringStrategy): The strategy for selecting + which channels to prune. + Defaults to a greedy approach. + """ + + # The number of times the importance score is approximated. + self.num_score_approximations = num_score_approximations + + # The metric used to assess the importance of each channel in a layer. + self.importance_metric = importance_metric + + # The strategy to use when deciding which channels to prune based on their importance scores. + self.channels_filtering_strategy = channels_filtering_strategy + diff --git a/model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py b/model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py new file mode 100644 index 000000000..6344f1841 --- /dev/null +++ b/model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py @@ -0,0 +1,156 @@ +# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved. 
class PruningFrameworkImplementation(FrameworkImplementation):
    """
    Interface of the framework-specific operations that are required for pruning.

    Each method raises NotImplementedError unless it is overridden by a framework-specific subclass.
    """

    @abstractmethod
    def prune_entry_node(self,
                         node: BaseNode,
                         output_mask: np.ndarray,
                         fw_info: FrameworkInfo):
        """
        Abstract method to prune an entry node in the model.

        Args:
            node: The node to be pruned.
            output_mask: A numpy array representing the mask to be applied to the output channels.
            fw_info: Framework-specific information.

        Raises:
            NotImplementedError: If the method is not implemented in the subclass.
        """
        # NotImplemented is a non-callable constant; NotImplementedError is the exception to raise.
        raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
                                  f'framework\'s prune_entry_node method.')  # pragma: no cover

    @abstractmethod
    def prune_intermediate_node(self,
                                node: BaseNode,
                                input_mask: np.ndarray,
                                output_mask: np.ndarray,
                                fw_info: FrameworkInfo):
        """
        Abstract method to prune an intermediate node in the model.

        Args:
            node: The node to be pruned.
            input_mask: Mask to be applied to the input channels.
            output_mask: Mask to be applied to the output channels.
            fw_info: Framework-specific information.

        Raises:
            NotImplementedError: If the method is not implemented in the subclass.
        """
        raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
                                  f'framework\'s prune_intermediate_node method.')  # pragma: no cover

    @abstractmethod
    def prune_exit_node(self,
                        node: BaseNode,
                        input_mask: np.ndarray,
                        fw_info: FrameworkInfo):
        """
        Abstract method to prune an exit node in the model.

        Args:
            node: The node to be pruned.
            input_mask: Mask to be applied to the input channels.
            fw_info: Framework-specific information.

        Raises:
            NotImplementedError: If the method is not implemented in the subclass.
        """
        raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
                                  f'framework\'s prune_exit_node method.')  # pragma: no cover

    @abstractmethod
    def is_node_entry_node(self,
                           node: BaseNode) -> bool:
        """
        Abstract method to determine if a given node is an entry node.

        Args:
            node: The node to be checked.

        Returns:
            bool: True if the node is an entry node, False otherwise.

        Raises:
            NotImplementedError: If the method is not implemented in the subclass.
        """
        raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
                                  f'framework\'s is_node_entry_node method.')  # pragma: no cover

    @abstractmethod
    def is_node_exit_node(self,
                          node: BaseNode,
                          corresponding_entry_node: BaseNode,
                          fw_info: FrameworkInfo) -> bool:
        """
        Abstract method to determine if a given node is an exit node.

        Args:
            node: The node to be checked.
            corresponding_entry_node: An entry node that the check relates to (presumably the
                entry node of the section that the node may close; confirm against the
                framework-specific implementation).
            fw_info: Framework-specific information.

        Returns:
            bool: True if the node is an exit node, False otherwise.

        Raises:
            NotImplementedError: If the method is not implemented in the subclass.
        """
        raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
                                  f'framework\'s is_node_exit_node method.')  # pragma: no cover

    @abstractmethod
    def is_node_intermediate_pruning_section(self,
                                             node: BaseNode) -> bool:
        """
        Abstract method to determine if a given node is in the intermediate section of pruning.

        Args:
            node: The node to be checked.

        Returns:
            bool: True if the node is in the intermediate pruning section, False otherwise.

        Raises:
            NotImplementedError: If the method is not implemented in the subclass.
        """
        raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
                                  f'framework\'s is_node_intermediate_pruning_section method.')  # pragma: no cover

    def attrs_oi_channels_info_for_pruning(self, node: BaseNode, fw_info: FrameworkInfo) -> Dict[str, Tuple[int, int]]:
        """
        Retrieves the attributes of a given node along with the output/input (OI) channel axis
        for each attribute used to prune these attributes.

        Not all attributes of a node are directly associated with both input and output channels.
        For example, bias vectors in convolutional layers are solely related to the number of output
        channels and do not have a corresponding input channel dimension.
        In cases like that, None is returned in the tuple of axis for such attributes.

        For kernel operations (like convolutions), the function identifies the output and input
        channel axis based on framework-specific information.
        For non-kernel operations, it defaults to setting the last axis as the output
        channel axis, assuming no specific input channel axis.

        Args:
            node (BaseNode): The node from the computational graph.
            fw_info (FrameworkInfo): Contains framework-specific information and utilities.

        Returns:
            Dict[str, Tuple[int, int]]: A dictionary where each key is an attribute name (like 'kernel' or 'bias')
            and each value is a tuple representing the output and input channel axis indices respectively.

        Raises:
            NotImplementedError: If the method is not implemented in the subclass.
        """
        raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
                                  f'framework\'s attrs_oi_channels_info_for_pruning method.')  # pragma: no cover
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import Dict, List
import numpy as np

from model_compression_toolkit.core.common import BaseNode
from model_compression_toolkit.logger import Logger


class PruningInfo:
    """
    PruningInfo stores information about a pruned model, including the pruning masks
    and importance scores for each layer. This class acts as a read-only container for
    accessing pruning-related metadata.

    Attributes:
        pruning_masks (Dict[BaseNode, np.ndarray]): Stores the pruning masks for each layer.
            A pruning mask is an array where each element indicates whether the corresponding
            channel or neuron has been pruned (0) or kept (1).
        importance_scores (Dict[BaseNode, np.ndarray]): Stores the importance scores for each layer.
            Importance scores quantify the significance of each channel in the layer.
    """

    def __init__(self,
                 pruning_masks: Dict[BaseNode, np.ndarray],
                 importance_scores: Dict[BaseNode, np.ndarray]):
        """
        Initializes the PruningInfo with pruning masks and importance scores.

        The dictionaries are stored by reference (no copy is made).

        Args:
            pruning_masks (Dict[BaseNode, np.ndarray]): Pruning masks for each layer.
            importance_scores (Dict[BaseNode, np.ndarray]): Importance scores for each layer.
        """
        self._pruning_masks = pruning_masks
        self._importance_scores = importance_scores

    @property
    def pruning_masks(self) -> Dict[BaseNode, np.ndarray]:
        """
        The pruning masks for each layer.

        Returns:
            Dict[BaseNode, np.ndarray]: The pruning masks.
        """
        return self._pruning_masks

    @property
    def importance_scores(self) -> Dict[BaseNode, np.ndarray]:
        """
        The importance scores for each layer.

        Returns:
            Dict[BaseNode, np.ndarray]: The importance scores.
        """
        return self._importance_scores


def unroll_simd_scores_to_per_channel_scores(simd_scores: Dict[BaseNode, np.ndarray],
                                             simd_groups_indices: Dict[BaseNode, List[np.ndarray]]) -> Dict[BaseNode, np.ndarray]:
    """
    Expands SIMD group scores into per-channel scores. This is necessary when channels
    are grouped in SIMD groups, and a single score is assigned to each group. The function
    duplicates the group score to each channel in that group.

    For every node, the result has one entry per channel index that appears in the node's
    groups; the entry at a channel index holds the score of the group containing that index.

    Args:
        simd_scores (Dict[BaseNode, np.ndarray]): The scores assigned to each SIMD group.
        simd_groups_indices (Dict[BaseNode, List[np.ndarray]]): The indices of channels in each SIMD group.

    Returns:
        Dict[BaseNode, np.ndarray]: Expanded scores for each individual channel.
    """
    if simd_scores is None or simd_groups_indices is None:
        # NOTE(review): presumably Logger.error aborts by raising; if it only logs, the loop
        # below fails on the None argument - confirm against the Logger implementation.
        Logger.error(f"Expected to find scores and indices to create unrolled scores for pruning info, "
                     f"but scores is {simd_scores} and groups indices are {simd_groups_indices}")

    per_channel_scores = {}
    for node, groups_indices in simd_groups_indices.items():
        node_scores = simd_scores[node]
        # The number of channels of the node is the total size of all of its groups.
        num_channels = sum(len(group) for group in groups_indices)
        unrolled_scores = np.zeros(num_channels)

        # Broadcast each group's single score to all the channel positions of that group.
        for group_score, group_indices in zip(node_scores, groups_indices):
            unrolled_scores[group_indices] = group_score

        per_channel_scores[node] = unrolled_scores
    return per_channel_scores


# [patch boundary] The following git patch headers are preserved verbatim as comments.
# diff --git a/model_compression_toolkit/core/common/pruning/pruning_section.py b/model_compression_toolkit/core/common/pruning/pruning_section.py
# new file mode 100644
# index 000000000..269e62fb6
# --- /dev/null
# +++ b/model_compression_toolkit/core/common/pruning/pruning_section.py
# @@ -0,0 +1,127 @@
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import List, Any

import numpy as np

from model_compression_toolkit.core.common.framework_info import FrameworkInfo
from model_compression_toolkit.core.common.graph.base_node import BaseNode


class PruningSectionMask:
    """
    Represents the masks to be applied to a pruning section of a neural network.
    This includes masks for both input and output channels at the entry and exit nodes of the section.

    Attributes:
        entry_node_ic_mask (np.ndarray): Mask for input channels of the entry node.
        entry_node_oc_mask (np.ndarray): Mask for output channels of the entry node.
        exit_node_ic_mask (np.ndarray): Mask for input channels of the exit node.
        exit_node_oc_mask (np.ndarray): Mask for output channels of the exit node.
    """

    def __init__(self,
                 entry_node_ic_mask: np.ndarray = None,
                 entry_node_oc_mask: np.ndarray = None,
                 exit_node_ic_mask: np.ndarray = None,
                 exit_node_oc_mask: np.ndarray = None):
        """
        Stores the four masks as given; each of them defaults to None.

        Args:
            entry_node_ic_mask (np.ndarray): Mask for input channels of the entry node.
            entry_node_oc_mask (np.ndarray): Mask for output channels of the entry node.
            exit_node_ic_mask (np.ndarray): Mask for input channels of the exit node.
            exit_node_oc_mask (np.ndarray): Mask for output channels of the exit node.
        """
        self.entry_node_ic_mask = entry_node_ic_mask
        self.entry_node_oc_mask = entry_node_oc_mask
        self.exit_node_ic_mask = exit_node_ic_mask
        self.exit_node_oc_mask = exit_node_oc_mask


class PruningSection:
    """
    Represents a section in a graph to be pruned, consisting of an entry node,
    intermediate nodes, and an exit node.

    Attributes:
        entry_node (BaseNode): The first node in the pruning section.
        intermediate_nodes (List[BaseNode]): List of nodes between the entry and exit nodes.
        exit_node (BaseNode): The last node in the pruning section.
    """

    def __init__(self,
                 entry_node: BaseNode,
                 intermediate_nodes: List[BaseNode],
                 exit_node: BaseNode):
        """
        Args:
            entry_node (BaseNode): The first node in the pruning section.
            intermediate_nodes (List[BaseNode]): List of nodes between the entry and exit nodes.
            exit_node (BaseNode): The last node in the pruning section.
        """
        self.entry_node = entry_node
        self.intermediate_nodes = intermediate_nodes
        self.exit_node = exit_node

    def get_all_section_nodes(self) -> List[BaseNode]:
        """
        Returns a new list of all nodes in the pruning section, ordered as entry node,
        intermediate nodes, and exit node.

        Returns:
            List[BaseNode]: List of all nodes in the pruning section.
        """
        return [self.entry_node] + self.intermediate_nodes + [self.exit_node]

    def apply_inner_section_mask(self,
                                 pruning_section_mask: PruningSectionMask,
                                 fw_impl: Any,
                                 fw_info: FrameworkInfo):
        """
        Apply the provided pruning section mask to all nodes within the pruning section.

        The entry node is pruned by its output-channels mask, the exit node by its
        input-channels mask, and every intermediate node by the entry node's
        output-channels mask (used both as its input mask and as its output mask).

        Args:
            pruning_section_mask (PruningSectionMask): The mask to be applied to the pruning section.
            fw_impl (PruningFrameworkImplementation): Framework-specific implementation for applying the mask.
            fw_info (FrameworkInfo): Framework-specific information needed to apply the mask.
        """
        entry_oc_mask = pruning_section_mask.entry_node_oc_mask

        fw_impl.prune_entry_node(node=self.entry_node,
                                 output_mask=entry_oc_mask,
                                 fw_info=fw_info)

        # Intermediate nodes see the channels that the entry node outputs, on both of their sides.
        for inter_node in self.intermediate_nodes:
            fw_impl.prune_intermediate_node(node=inter_node,
                                            input_mask=entry_oc_mask,
                                            output_mask=entry_oc_mask,
                                            fw_info=fw_info)

        fw_impl.prune_exit_node(self.exit_node,
                                input_mask=pruning_section_mask.exit_node_ic_mask,
                                fw_info=fw_info)

    @staticmethod
    def has_matching_channel_count(exit_node: BaseNode,
                                   corresponding_entry_node: BaseNode,
                                   fw_info: FrameworkInfo) -> bool:
        """
        Checks if the number of input channels of the exit node matches the number of output channels
        of its corresponding entry node.

        Both counts are read from the shape of each node's first kernel attribute, along the
        channel axis that the framework info maps to the node's type.

        Args:
            exit_node (BaseNode): The node exiting a pruning section.
            corresponding_entry_node (BaseNode): The entry node of the subsequent pruning section.
            fw_info (FrameworkInfo): Framework-specific information; provides the kernel attribute
                names and the (output, input) channel axes per node type.

        Returns:
            bool: True if the channel counts match, False otherwise.
        """
        # The mapping yields an (output axis, input axis) pair per node type.
        _, exit_input_channel_axis = fw_info.kernel_channels_mapping.get(exit_node.type)
        entry_output_channel_axis, _ = fw_info.kernel_channels_mapping.get(corresponding_entry_node.type)

        exit_node_attr = fw_info.get_kernel_op_attributes(exit_node.type)[0]
        entry_node_attr = fw_info.get_kernel_op_attributes(corresponding_entry_node.type)[0]

        exit_input_channels = exit_node.get_weights_by_keys(exit_node_attr).shape[exit_input_channel_axis]
        entry_output_channels = corresponding_entry_node.get_weights_by_keys(entry_node_attr).shape[entry_output_channel_axis]

        return exit_input_channels == entry_output_channels


# [patch boundary] The following git patch headers and licence text are preserved verbatim as comments.
# diff --git a/model_compression_toolkit/core/keras/pruning/__init__.py b/model_compression_toolkit/core/keras/pruning/__init__.py
# new file mode 100644
# index 000000000..cb2075e68
# --- /dev/null
# +++ b/model_compression_toolkit/core/keras/pruning/__init__.py
# @@ -0,0 +1,15 @@
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# diff --git a/model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py b/model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py
# new file mode 100644
# index 000000000..34764589f
# --- /dev/null
# +++ b/model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py
# @@ -0,0 +1,285 @@
# Copyright 2023 Sony Semiconductor Israel, Inc.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import List, Tuple, Dict

from model_compression_toolkit.core.common.pruning.pruning_framework_implementation import \
    PruningFrameworkImplementation
from model_compression_toolkit.core.common.pruning.pruning_section import PruningSection
from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
from model_compression_toolkit.core.common import BaseNode
from model_compression_toolkit.core.keras.constants import BIAS, GROUPS, FILTERS, UNITS, USE_BIAS
import keras

import numpy as np

from model_compression_toolkit.logger import Logger


class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementation):
    """
    Keras flavour of the pruning framework interface: supplies the Keras-specific
    versions of the methods declared by PruningFrameworkImplementation.
    """

    def prune_entry_node(self,
                         node: BaseNode,
                         output_mask: np.ndarray,
                         fw_info: FrameworkInfo):
        """
        Prune the entry node of a pruning section along its output channels.

        Args:
            node: The entry node to be pruned.
            output_mask: Numpy array with the mask to apply to the output channels.
            fw_info: Framework-specific information object.

        """
        return _prune_keras_edge_node(node=node,
                                      mask=output_mask,
                                      fw_info=fw_info,
                                      is_exit_node=False)

    def prune_intermediate_node(self,
                                node: BaseNode,
                                input_mask: np.ndarray,
                                output_mask: np.ndarray,
                                fw_info: FrameworkInfo):
        """
        Prune an intermediate node of a pruning section.

        The node's input shape is updated from the input mask, and every weight tensor
        of the node is filtered along its last axis by the output mask.

        Args:
            node: The intermediate node to be pruned.
            input_mask: Numpy array with the mask to apply to the input channels.
            output_mask: Numpy array with the mask to apply to the output channels.
            fw_info: Framework-specific information object.

        """
        _edit_node_input_shape(input_mask, node)
        keep_channels = output_mask.astype(bool)
        # Build the complete pruned dictionary first, then swap it in at once.
        node.weights = {weight_name: weight.compress(keep_channels, axis=-1)
                        for weight_name, weight in node.weights.items()}

    def prune_exit_node(self,
                        node: BaseNode,
                        input_mask: np.ndarray,
                        fw_info: FrameworkInfo):
        """
        Prune the exit node of a pruning section along its input channels.

        Args:
            node: The exit node to be pruned.
            input_mask: Numpy array with the mask to apply to the input channels.
            fw_info: Framework-specific information object.

        """
        return _prune_keras_edge_node(node=node,
                                      mask=input_mask,
                                      fw_info=fw_info,
                                      is_exit_node=True)

    def is_node_entry_node(self, node: BaseNode) -> bool:
        """
        Tell whether a node can open a pruning section in a Keras model.

        Args:
            node: The node to be checked.

        Returns:
            Boolean indicating if the node is an entry node.
        """
        return _is_keras_node_pruning_section_edge(node)

    def is_node_exit_node(self,
                          node: BaseNode,
                          corresponding_entry_node: BaseNode,
                          fw_info: FrameworkInfo) -> bool:
        """
        Tell whether a node can close the pruning section opened by a given entry node.

        The node has to be a pruning-section edge, and its input channel count has to
        match the output channel count of the entry node.

        Args:
            node: The node to be checked.
            corresponding_entry_node: The entry node of the pruning section that is checked.
            fw_info: Framework-specific information object.

        Returns:
            Boolean indicating if the node is an exit node.
        """
        is_section_edge = _is_keras_node_pruning_section_edge(node)
        # The channel counts are compared only for nodes that are section edges.
        return is_section_edge and PruningSection.has_matching_channel_count(node,
                                                                            corresponding_entry_node,
                                                                            fw_info)

    def is_node_intermediate_pruning_section(self, node) -> bool:
        """
        Tell whether a node belongs to the intermediate part of a pruning section.

        Args:
            node: The node to be checked.

        Returns:
            Boolean indicating if the node is part of the intermediate pruning section.
        """
        # Every layer type other than these four counts as intermediate.
        non_intermediate_types = (keras.layers.DepthwiseConv2D,
                                  keras.layers.Conv2D,
                                  keras.layers.Conv2DTranspose,
                                  keras.layers.Dense)
        return node.type not in non_intermediate_types

    def attrs_oi_channels_info_for_pruning(self,
                                           node: BaseNode,
                                           fw_info: FrameworkInfo) -> Dict[str, Tuple[int, int]]:
        """
        Map each prunable attribute of a node to its (output, input) channel axes.

        Some attributes are tied to only one kind of channels. A bias vector, for
        instance, depends only on the number of output channels, so None is placed
        in the input-axis slot of its tuple.

        For kernel operations, the kernel attributes get the axes that the framework
        info maps to the node's type, and the bias gets (0, None).
        For any other operation, every weight of the node gets (-1, None): the last
        axis is taken as the output channel axis and there is no input channel axis.

        Args:
            node (BaseNode): The node from the computational graph.
            fw_info (FrameworkInfo): Contains framework-specific information and utilities.

        Returns:
            Dict[str, Tuple[int, int]]: A dictionary where each key is an attribute name (like 'kernel' or 'bias')
            and each value is a tuple representing the output and input channel axis indices respectively.
        """
        if not fw_info.is_kernel_op(node.type):
            # Assumptions for non-kernel (intermediate) nodes:
            # 1. All of the node's weights are pruned.
            # 2. The output channel axis is the last axis of the attribute.
            # 3. The input channel axis is irrelevant, since these attributes are pruned
            #    only by their output channels.
            return {weight_name: (-1, None) for weight_name in list(node.weights.keys())}

        kernel_attributes = fw_info.get_kernel_op_attributes(node.type)
        if kernel_attributes is None or len(kernel_attributes) == 0:
            Logger.error(f"Expected to find attributes but found {kernel_attributes}")

        axes_per_attribute = {kernel_attr: fw_info.kernel_channels_mapping.get(node.type)
                              for kernel_attr in kernel_attributes}

        # The bias is a vector whose length is the number of output channels,
        # so the input channel axis does not apply to it.
        axes_per_attribute[BIAS] = (0, None)
        return axes_per_attribute


def _is_keras_node_pruning_section_edge(node: BaseNode) -> bool:
    """
    Tell whether a Keras node can serve as an edge (entry or exit) of a pruning section.

    Edges are Conv2D and Conv2DTranspose layers whose 'groups' attribute equals 1,
    and Dense layers.

    Args:
        node (BaseNode): The node to be evaluated.

    Returns:
        bool: True if the node is an edge of a pruning section, False otherwise.
    """
    conv_types = (keras.layers.Conv2D, keras.layers.Conv2DTranspose)
    if node.type in conv_types:
        # Convolutions qualify only when they are not grouped.
        return node.framework_attr[GROUPS] == 1
    return node.type == keras.layers.Dense


def _prune_keras_edge_node(node: BaseNode,
                           mask: np.ndarray,
                           fw_info: FrameworkInfo,
                           is_exit_node: bool):
    """
    Prune an entry or exit node of a pruning section by filtering its weights with a mask.

    The kernel is always filtered: along its input channel axis for an exit node,
    along its output channel axis for an entry node. An entry node also has its bias
    filtered (when it uses one) and its 'filters'/'units' attribute updated. An exit
    node has its input shape updated instead.

    Args:
        node: The node to be pruned.
        mask: The pruning mask to be applied.
        fw_info: Framework-specific information object.
        is_exit_node: A boolean indicating whether the node is an exit node.

    """
    kernel_name = fw_info.get_kernel_op_attributes(node.type)[0]
    # The mapping holds (output axis, input axis); an exit node is pruned on index 1.
    channel_axes = fw_info.kernel_channels_mapping.get(node.type)
    pruned_axis = channel_axes[int(is_exit_node)]
    original_kernel = node.get_weights_by_keys(kernel_name)
    keep_channels = mask.astype(bool)

    node.set_weights_by_keys(name=kernel_name,
                             tensor=original_kernel.compress(keep_channels, axis=pruned_axis))

    if is_exit_node:
        # Only the input shape of the section's last node needs adjusting.
        _edit_node_input_shape(keep_channels, node)
        return

    # From here on the node is an entry node.
    if node.framework_attr[USE_BIAS]:
        original_bias = node.get_weights_by_keys(BIAS)
        node.set_weights_by_keys(name=BIAS, tensor=original_bias.compress(keep_channels))

    # Keep the layer configuration in sync with the number of remaining output channels.
    if node.type in [keras.layers.Conv2D, keras.layers.Conv2DTranspose]:
        node.framework_attr[FILTERS] = int(np.sum(mask))
    elif node.type == keras.layers.Dense:
        node.framework_attr[UNITS] = int(np.sum(mask))


def _edit_node_input_shape(input_mask: np.ndarray,
                           node: BaseNode):
    """
    Update a node's input shape after its input channels were pruned.

    The last dimension of the node's input shape is replaced by the sum of the mask,
    which equals the number of retained channels for a 0/1 mask. The other
    dimensions are left as they are.

    Args:
        input_mask (np.ndarray): A binary array where 1 indicates the channel is kept and 0 means pruned.
        node (BaseNode): The node whose input shape needs to be adjusted.
    """
    updated_dims = list(node.input_shape)
    updated_dims[-1] = int(np.sum(input_mask))
    node.input_shape = tuple(updated_dims)


# [patch boundary] The following git patch headers and licence text are preserved verbatim as comments.
# diff --git a/model_compression_toolkit/pruning/__init__.py b/model_compression_toolkit/pruning/__init__.py
# new file mode 100644
# index 000000000..9ce528a81
# --- /dev/null
# +++ b/model_compression_toolkit/pruning/__init__.py
# @@ -0,0 +1,19 @@
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+# ============================================================================== + +from model_compression_toolkit.core.common.pruning.pruning_info import PruningInfo +from model_compression_toolkit.core.common.pruning.pruning_config import ImportanceMetric, PruningConfig, ChannelsFilteringStrategy +from model_compression_toolkit.pruning.keras.pruning_facade import keras_pruning_experimental + diff --git a/model_compression_toolkit/pruning/keras/__init__.py b/model_compression_toolkit/pruning/keras/__init__.py new file mode 100644 index 000000000..cb2075e68 --- /dev/null +++ b/model_compression_toolkit/pruning/keras/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + diff --git a/model_compression_toolkit/pruning/keras/pruning_facade.py b/model_compression_toolkit/pruning/keras/pruning_facade.py new file mode 100644 index 000000000..14d9da260 --- /dev/null +++ b/model_compression_toolkit/pruning/keras/pruning_facade.py @@ -0,0 +1,148 @@ +# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import Callable, Tuple

from model_compression_toolkit import get_target_platform_capabilities
from model_compression_toolkit.constants import TENSORFLOW, FOUND_TF
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
from model_compression_toolkit.core.common.pruning.pruner import Pruner
from model_compression_toolkit.core.common.pruning.pruning_config import PruningConfig
from model_compression_toolkit.core.common.pruning.pruning_info import PruningInfo
from model_compression_toolkit.core.common.quantization.set_node_quantization_config import set_quantization_configuration_to_graph
from model_compression_toolkit.core.graph_prep_runner import read_model_to_graph
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
from model_compression_toolkit.core.common.quantization.quantization_config import DEFAULTCONFIG
from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL

if FOUND_TF:
    from model_compression_toolkit.core.keras.back2framework.float_model_builder import FloatKerasModelBuilder
    from model_compression_toolkit.core.keras.pruning.pruning_keras_implementation import PruningKerasImplementation
    from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
    from tensorflow.keras.models import Model

    DEFAULT_KERAS_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)

    def keras_pruning_experimental(model: Model,
                                   target_kpi: KPI,
                                   representative_data_gen: Callable,
                                   pruning_config: PruningConfig = PruningConfig(),
                                   target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC) -> \
            Tuple[Model, PruningInfo]:
        """
        Perform structured pruning on a Keras model to meet a specified target KPI.
        This function prunes the provided model according to the target KPI by grouping and pruning
        channels based on each layer's SIMD configuration in the Target Platform Capabilities (TPC).
        By default, the importance of each channel group is determined using the Label-Free Hessian
        (LFH) method, assessing each channel's sensitivity to the Hessian of the loss function.
        This pruning strategy considers groups of channels together for a more hardware-friendly
        architecture. The process involves analyzing the model with a representative dataset to
        identify groups of channels that can be removed with minimal impact on performance.

        Notice that the pruned model must be retrained to recover the compressed model's performance.

        Note that the default `pruning_config` object is created once, when this function is
        defined, and that same instance is used by every call that omits the argument.

        Args:
            model (Model): The original Keras model to be pruned.
            target_kpi (KPI): The target Key Performance Indicators to be achieved through pruning.
            representative_data_gen (Callable): A function to generate representative data for pruning analysis.
            pruning_config (PruningConfig): Configuration settings for the pruning process. Defaults to standard config.
            target_platform_capabilities (TargetPlatformCapabilities): Platform-specific constraints and capabilities.
                Defaults to DEFAULT_KERAS_TPC.

        Returns:
            Tuple[Model, PruningInfo]: A tuple containing the pruned Keras model and associated pruning information.

        Examples:

            Import MCT:

            >>> import model_compression_toolkit as mct

            Import a Keras model:

            >>> from tensorflow.keras.applications.resnet50 import ResNet50
            >>> model = ResNet50()

            Create a random dataset generator:

            >>> import numpy as np
            >>> def repr_datagen(): yield [np.random.random((1, 224, 224, 3))]

            Define a target KPI for pruning.
            Here, we aim to reduce the memory footprint of weights by 50%, assuming the model weights
            are represented in float32 data type (thus, each parameter is represented using 4 bytes):

            >>> dense_nparams = sum([l.count_params() for l in model.layers])
            >>> target_kpi = mct.KPI(weights_memory=dense_nparams * 4 * 0.5)

            Optionally, define a pruning configuration. num_score_approximations can be passed
            to configure the number of importance scores that will be calculated for each channel.
            A higher value for this parameter yields more precise score approximations but also
            extends the duration of the pruning process:

            >>> pruning_config = mct.pruning.PruningConfig(num_score_approximations=1)

            Perform pruning:

            >>> pruned_model, pruning_info = mct.pruning.keras_pruning_experimental(model=model, target_kpi=target_kpi, representative_data_gen=repr_datagen, pruning_config=pruning_config)

        """

        # Keras-specific implementation that drives the pruning.
        keras_pruning_impl = PruningKerasImplementation()

        # Read the Keras model into the internal graph representation.
        graph_from_model = read_model_to_graph(model,
                                               representative_data_gen,
                                               target_platform_capabilities,
                                               DEFAULT_KERAS_INFO,
                                               keras_pruning_impl)

        # The quantization configuration is set on the graph even though nothing is quantized here,
        # since this step prepares the graph for the pruning process.
        configured_graph = set_quantization_configuration_to_graph(graph_from_model,
                                                                   quant_config=DEFAULTCONFIG,
                                                                   mixed_precision_enable=False)

        # Set up the pruner with the prepared graph and the user's configuration.
        graph_pruner = Pruner(configured_graph,
                              DEFAULT_KERAS_INFO,
                              keras_pruning_impl,
                              target_kpi,
                              representative_data_gen,
                              pruning_config,
                              target_platform_capabilities)

        # Run the pruning itself.
        graph_after_pruning = graph_pruner.prune_graph()

        # Collect the pruning masks and scores once the graph has been pruned.
        info = graph_pruner.get_pruning_info()

        # Build a Keras model back from the pruned graph and mark it as trainable.
        keras_model, _ = FloatKerasModelBuilder(graph=graph_after_pruning).build_model()
        keras_model.trainable = True

        return keras_model, info

else:
    # Without tensorflow installed, the function is replaced by a stub
    # that reports the missing package whenever it is called.
    def keras_pruning_experimental(*args, **kwargs):
        Logger.critical('Installing tensorflow is mandatory '
                        'when using keras_pruning_experimental. '
                        'Could not find Tensorflow package.')  # pragma: no cover


# [patch boundary] The following git patch headers and modification hunk are preserved verbatim as comments.
# diff --git a/model_compression_toolkit/target_platform_capabilities/target_platform/op_quantization_config.py b/model_compression_toolkit/target_platform_capabilities/target_platform/op_quantization_config.py
# index 45d0d9269..6b4629616 100644
# --- a/model_compression_toolkit/target_platform_capabilities/target_platform/op_quantization_config.py
# +++ b/model_compression_toolkit/target_platform_capabilities/target_platform/op_quantization_config.py
# @@ -36,8 +36,7 @@ def __init__(self,
#                   fixed_scale: float,
#                   fixed_zero_point: int,
#                   weights_multiplier_nbits: int, # If None - set 8 in hptq, o.w use it
# -                 simd_size: int
# -                 ):
# +                 simd_size: int):
#          """
#
#          Args:
# diff --git a/model_compression_toolkit/target_platform_capabilities/target_platform/target_platform_model.py b/model_compression_toolkit/target_platform_capabilities/target_platform/target_platform_model.py
# index 8eb5d5c6c..2d9f95680 100644
# --- a/model_compression_toolkit/target_platform_capabilities/target_platform/target_platform_model.py
# +++
b/model_compression_toolkit/target_platform_capabilities/target_platform/target_platform_model.py @@ -76,6 +76,7 @@ def __init__(self, f'Default QuantizationConfigOptions must contain only one option' self.default_qco = default_qco self.fusing_patterns = [] + self.is_simd_padding = False def get_config_options_by_operators_set(self, operators_set_name: str) -> QuantizationConfigOptions: @@ -224,3 +225,15 @@ def show(self): """ pprint.pprint(self.get_info(), sort_dicts=False) + def set_simd_padding(self, + is_simd_padding: bool): + """ + Set flag is_simd_padding to indicate whether this TP model defines + that padding due to SIMD constrains occurs. + + Args: + is_simd_padding: Whether this TP model defines that padding due to SIMD constrains occurs. + + """ + self.is_simd_padding = is_simd_padding + diff --git a/model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py b/model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py index 402671872..304a25526 100644 --- a/model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py +++ b/model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py @@ -223,3 +223,12 @@ def raise_warnings(self): """ for op in self.__tp_model_opsets_not_used: Logger.warning(f'{op} is defined in TargetPlatformModel, but is not used in TargetPlatformCapabilities.') + + @property + def is_simd_padding(self) -> bool: + """ + + Returns: Check if the TP model defines that padding due to SIMD constrains occurs. 
+ + """ + return self.tp_model.is_simd_padding \ No newline at end of file diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tp_model.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tp_model.py index d52c64359..c4e5e52d0 100644 --- a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tp_model.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tp_model.py @@ -118,6 +118,8 @@ def generate_tp_model(default_config: OpQuantizationConfig, # be used for operations that will be attached to this set's label. # Otherwise, it will be a configure-less set (used in fusing): + generated_tpc.set_simd_padding(is_simd_padding=True) + # May suit for operations like: Dropout, Reshape, etc. tp.OperatorsSet("NoQuantization", tp.get_default_quantization_config_options().clone_and_edit( diff --git a/tests/keras_tests/pruning_tests/__init__.py b/tests/keras_tests/pruning_tests/__init__.py new file mode 100644 index 000000000..cb2075e68 --- /dev/null +++ b/tests/keras_tests/pruning_tests/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + diff --git a/tests/keras_tests/pruning_tests/feature_networks/__init__.py b/tests/keras_tests/pruning_tests/feature_networks/__init__.py new file mode 100644 index 000000000..2147ec284 --- /dev/null +++ b/tests/keras_tests/pruning_tests/feature_networks/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== diff --git a/tests/keras_tests/pruning_tests/feature_networks/constant_importance_metric.py b/tests/keras_tests/pruning_tests/feature_networks/constant_importance_metric.py new file mode 100644 index 000000000..47793bf92 --- /dev/null +++ b/tests/keras_tests/pruning_tests/feature_networks/constant_importance_metric.py @@ -0,0 +1,83 @@ +# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from enum import Enum + +from typing import List + +from model_compression_toolkit.core.common import BaseNode +from model_compression_toolkit.core.common.pruning.importance_metrics.base_importance_metric import BaseImportanceMetric +import numpy as np + +from model_compression_toolkit.core.common.pruning.importance_metrics.importance_metric_factory import IMPORTANCE_METRIC_DICT + + +class ConstantImportanceMetric(BaseImportanceMetric): + """ + ConstantImportanceMetric is used for testing architectures with three linear layers in a row. + It assigns scores in reverse order of the channel index. It generates constant scores and + grouped indices for the first two layers based on predefined numbers of output channels. + """ + + # Static attributes to hold the predefined number of output channels for the first two layers. + first_num_oc = None + second_num_oc = None + simd = 1 + + def __init__(self, **kwargs): + pass + + def get_entry_node_to_simd_score(self, entry_nodes: List[BaseNode]): + """ + Generates the scores and group indices for the provided entry nodes. + + Args: + entry_nodes (List[BaseNode]): The entry nodes for which scores are to be generated. + + Returns: + A tuple containing the generated scores and group indices. 
+ """ + grouped_indices = { + entry_nodes[0]: [np.arange(i, min(i + ConstantImportanceMetric.simd, ConstantImportanceMetric.first_num_oc)) for i in range(0, ConstantImportanceMetric.first_num_oc, ConstantImportanceMetric.simd)], + entry_nodes[1]: [np.arange(i, min(i + ConstantImportanceMetric.simd, ConstantImportanceMetric.second_num_oc)) for i in range(0, ConstantImportanceMetric.second_num_oc, ConstantImportanceMetric.simd)] + } + + entry_node_to_simd_score = { + entry_nodes[0]: [-np.min(np.arange(i, min(i + ConstantImportanceMetric.simd, ConstantImportanceMetric.first_num_oc))) for i in range(0, ConstantImportanceMetric.first_num_oc, ConstantImportanceMetric.simd)], + entry_nodes[1]: [-np.min(np.arange(i, min(i + ConstantImportanceMetric.simd, ConstantImportanceMetric.second_num_oc))) for i in range(0, ConstantImportanceMetric.second_num_oc, ConstantImportanceMetric.simd)] + } + + return entry_node_to_simd_score, grouped_indices + + +class ConstImportanceMetric(Enum): + CONST = 'const' + + +def add_const_importance_metric(first_num_oc, second_num_oc, simd=1): + """ + Adds the constant importance metric to the global importance metrics dictionary. + + Args: + first_num_oc (int): Number of output channels for the first layer. + second_num_oc (int): Number of output channels for the second layer. + """ + # Set the static attributes for the number of output channels. + ConstantImportanceMetric.first_num_oc = first_num_oc + ConstantImportanceMetric.second_num_oc = second_num_oc + ConstantImportanceMetric.simd = simd + + # Update the global dictionary mapping importance metrics to their corresponding classes. 
+ IMPORTANCE_METRIC_DICT.update({ConstImportanceMetric.CONST: ConstantImportanceMetric}) diff --git a/tests/keras_tests/pruning_tests/feature_networks/networks_tests/__init__.py b/tests/keras_tests/pruning_tests/feature_networks/networks_tests/__init__.py new file mode 100644 index 000000000..807f5e384 --- /dev/null +++ b/tests/keras_tests/pruning_tests/feature_networks/networks_tests/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== \ No newline at end of file diff --git a/tests/keras_tests/pruning_tests/feature_networks/networks_tests/conv2d_conv2dtranspose_pruning_test.py b/tests/keras_tests/pruning_tests/feature_networks/networks_tests/conv2d_conv2dtranspose_pruning_test.py new file mode 100644 index 000000000..498361943 --- /dev/null +++ b/tests/keras_tests/pruning_tests/feature_networks/networks_tests/conv2d_conv2dtranspose_pruning_test.py @@ -0,0 +1,98 @@ +# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import tensorflow as tf + +import model_compression_toolkit as mct +from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import generate_keras_tpc +from tests.common_tests.helpers.generate_test_tp_model import generate_test_tp_model +from tests.keras_tests.pruning_tests.feature_networks.constant_importance_metric import add_const_importance_metric, \ + ConstImportanceMetric + +from tests.keras_tests.pruning_tests.feature_networks.pruning_keras_feature_test import PruningKerasFeatureTest +from tests.keras_tests.utils import get_layers_from_model_by_type +import numpy as np + +keras = tf.keras +layers = keras.layers + + +class Conv2DtoConv2DTransposePruningTest(PruningKerasFeatureTest): + + def __init__(self, + unit_test, + use_bn=False, + activation_layer=None, + simd=1, + use_constant_importance_metric=True): + super().__init__(unit_test, + input_shape=(8, 8, 3)) + self.use_bn = use_bn + self.activation_layer = activation_layer + self.simd = simd + self.use_constant_importance_metric = use_constant_importance_metric + + + def create_networks(self): + inputs = layers.Input(shape=self.get_input_shapes()[0][1:]) + x = layers.Conv2D(filters=6, kernel_size=1)(inputs) + if self.use_bn: + x = layers.BatchNormalization()(x) + if self.activation_layer: + x = self.activation_layer(x) + x = layers.Conv2DTranspose(filters=4, kernel_size=1)(x) + x = layers.Conv2D(filters=4, kernel_size=1)(x) + model = keras.Model(inputs=inputs, outputs=x) + return 
model + + + def get_tpc(self): + tp = generate_test_tp_model({'simd_size': self.simd}) + return generate_keras_tpc(name="simd_test", tp_model=tp) + + def get_pruning_config(self): + if self.use_constant_importance_metric: + add_const_importance_metric(first_num_oc=6, second_num_oc=4, simd=self.simd) + return mct.pruning.PruningConfig(importance_metric=ConstImportanceMetric.CONST) + return super().get_pruning_config() + def get_kpi(self): + # Remove only one group of channels only one parameter should be pruned + return mct.KPI(weights_memory=(self.dense_model_num_params - 1) * 4) + + def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): + dense_convtrans_layers = get_layers_from_model_by_type(float_model, layers.Conv2DTranspose) + dense_conv_layers = get_layers_from_model_by_type(float_model, layers.Conv2D) + + prunable_convtrans_layers = get_layers_from_model_by_type(quantized_model, layers.Conv2DTranspose) + prunable_conv_layers = get_layers_from_model_by_type(quantized_model, layers.Conv2D) + + is_first_layer_pruned = prunable_conv_layers[0].filters == 6 - self.simd + is_second_layer_pruned = prunable_convtrans_layers[0].filters == 4 - self.simd + + # Make sure only one of layers has been pruned + self.unit_test.assertTrue(is_first_layer_pruned != is_second_layer_pruned) + + # In constant case, the last SIMD channels of the first layer should be pruned: + if self.use_constant_importance_metric: + self.unit_test.assertTrue(is_first_layer_pruned) + self.unit_test.assertTrue(np.all(prunable_conv_layers[0].kernel.numpy() == dense_conv_layers[0].kernel.numpy()[:, :, :, :-self.simd])) + self.unit_test.assertTrue(np.all(prunable_conv_layers[0].bias.numpy() == dense_conv_layers[0].bias.numpy()[:-self.simd])) + # Make sure the only in channel removed is the last channel of the second conv layer + self.unit_test.assertTrue(np.all(prunable_convtrans_layers[0].kernel.numpy() == dense_convtrans_layers[0].kernel.numpy()[:, :, :, :-self.simd])) 
+ self.unit_test.assertTrue(np.all(prunable_convtrans_layers[0].bias.numpy() == dense_convtrans_layers[0].bias.numpy())) + + if is_first_layer_pruned: + self.unit_test.assertTrue(np.all(prunable_conv_layers[1].kernel.numpy() == dense_conv_layers[1].kernel.numpy())) + self.unit_test.assertTrue(np.all(prunable_conv_layers[1].bias.numpy() == dense_conv_layers[1].bias.numpy())) diff --git a/tests/keras_tests/pruning_tests/feature_networks/networks_tests/conv2d_pruning_test.py b/tests/keras_tests/pruning_tests/feature_networks/networks_tests/conv2d_pruning_test.py new file mode 100644 index 000000000..73915630e --- /dev/null +++ b/tests/keras_tests/pruning_tests/feature_networks/networks_tests/conv2d_pruning_test.py @@ -0,0 +1,98 @@ +# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +import tensorflow as tf + +import model_compression_toolkit as mct +from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import generate_keras_tpc +from tests.common_tests.helpers.generate_test_tp_model import generate_test_tp_model +from tests.keras_tests.pruning_tests.feature_networks.constant_importance_metric import ConstImportanceMetric, \ + add_const_importance_metric +import numpy as np +from tests.keras_tests.pruning_tests.feature_networks.pruning_keras_feature_test import PruningKerasFeatureTest +from tests.keras_tests.utils import get_layers_from_model_by_type + +keras = tf.keras +layers = keras.layers + + +class Conv2DPruningTest(PruningKerasFeatureTest): + """ + Test a network with two adjacent Conv2D layers and check that a single group of channels is pruned. + """ + + def __init__(self, + unit_test, + use_bn=False, + activation_layer=None, + simd=1, + use_constant_importance_metric=True): + + super().__init__(unit_test, + input_shape=(8, 8, 3)) + self.use_bn = use_bn + self.activation_layer = activation_layer + self.simd = simd + self.use_constant_importance_metric = use_constant_importance_metric + + def get_tpc(self): + tp = generate_test_tp_model({'simd_size': self.simd}) + return generate_keras_tpc(name="simd_test", tp_model=tp) + + def get_pruning_config(self): + if self.use_constant_importance_metric: + add_const_importance_metric(first_num_oc=8, second_num_oc=6, simd=self.simd) + return mct.pruning.PruningConfig(importance_metric=ConstImportanceMetric.CONST) + return super().get_pruning_config() + + def create_networks(self): + inputs = layers.Input(shape=self.get_input_shapes()[0][1:]) + x = layers.Conv2D(filters=8, kernel_size=1)(inputs) + if self.use_bn: + x = layers.BatchNormalization()(x) + if self.activation_layer: + x = self.activation_layer(x) + x = layers.Conv2D(filters=6, kernel_size=2)(x) + outputs = layers.Conv2D(filters=1,
kernel_size=3)(x) + model = keras.Model(inputs=inputs, outputs=outputs) + return model + + def get_kpi(self): + # Remove only one group of channels only one parameter should be pruned + return mct.KPI(weights_memory=(self.dense_model_num_params-1) * 4) + + def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): + dense_layers = get_layers_from_model_by_type(float_model, layers.Conv2D) + prunable_layers = get_layers_from_model_by_type(quantized_model, layers.Conv2D) + + is_first_layer_pruned = prunable_layers[0].filters == 8 - self.simd + is_second_layer_pruned = prunable_layers[1].filters == 6 - self.simd + + # Make sure only one of layers has been pruned + self.unit_test.assertTrue(is_first_layer_pruned != is_second_layer_pruned) + + # In constant case, the last SIMD channels of the first layer should be pruned: + if self.use_constant_importance_metric: + self.unit_test.assertTrue(is_first_layer_pruned) + self.unit_test.assertTrue(np.all(prunable_layers[0].kernel.numpy()==dense_layers[0].kernel.numpy()[:,:,:,:-self.simd])) + self.unit_test.assertTrue(np.all(prunable_layers[0].bias.numpy()==dense_layers[0].bias.numpy()[:-self.simd])) + self.unit_test.assertTrue(np.all(prunable_layers[1].kernel.numpy() == dense_layers[1].kernel.numpy()[:, :, :-self.simd, :])) + self.unit_test.assertTrue(np.all(prunable_layers[1].bias.numpy()==dense_layers[1].bias.numpy())) + + if is_first_layer_pruned: + self.unit_test.assertTrue(np.all(prunable_layers[2].kernel.numpy() == dense_layers[2].kernel.numpy())) + self.unit_test.assertTrue(np.all(prunable_layers[2].bias.numpy() == dense_layers[2].bias.numpy())) + + diff --git a/tests/keras_tests/pruning_tests/feature_networks/networks_tests/conv2dtranspose_conv2d_pruning_test.py b/tests/keras_tests/pruning_tests/feature_networks/networks_tests/conv2dtranspose_conv2d_pruning_test.py new file mode 100644 index 000000000..ea9afef40 --- /dev/null +++ 
b/tests/keras_tests/pruning_tests/feature_networks/networks_tests/conv2dtranspose_conv2d_pruning_test.py @@ -0,0 +1,99 @@ +# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + + +import tensorflow as tf +import numpy as np +import model_compression_toolkit as mct +from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import generate_keras_tpc +from tests.common_tests.helpers.generate_test_tp_model import generate_test_tp_model +from tests.keras_tests.pruning_tests.feature_networks.constant_importance_metric import add_const_importance_metric, \ + ConstImportanceMetric + +from tests.keras_tests.pruning_tests.feature_networks.pruning_keras_feature_test import PruningKerasFeatureTest +from tests.keras_tests.utils import get_layers_from_model_by_type + +keras = tf.keras +layers = keras.layers + + +class Conv2DTransposetoConv2DPruningTest(PruningKerasFeatureTest): + + def __init__(self, + unit_test, + use_bn=False, + activation_layer=None, + simd=1, + use_constant_importance_metric=True): + super().__init__(unit_test, + input_shape=(8, 8, 3)) + self.use_bn = use_bn + self.activation_layer = activation_layer + self.simd = simd + self.use_constant_importance_metric = use_constant_importance_metric + + def get_tpc(self): + tp = generate_test_tp_model({'simd_size': self.simd}) + return 
generate_keras_tpc(name="simd_test", tp_model=tp) + + def get_pruning_config(self): + if self.use_constant_importance_metric: + add_const_importance_metric(first_num_oc=6, second_num_oc=4, simd=self.simd) + return mct.pruning.PruningConfig(importance_metric=ConstImportanceMetric.CONST) + return super().get_pruning_config() + + def create_networks(self): + inputs = layers.Input(shape=self.get_input_shapes()[0][1:]) + x = layers.Conv2DTranspose(filters=6, kernel_size=1)(inputs) + if self.use_bn: + x = layers.BatchNormalization()(x) + if self.activation_layer: + x = self.activation_layer(x) + x = layers.Conv2D(filters=4, kernel_size=1)(x) + x = layers.Conv2DTranspose(filters=4, kernel_size=1)(x) + model = keras.Model(inputs=inputs, outputs=x) + return model + + def get_kpi(self): + # Remove only one group of channels only one parameter should be pruned + return mct.KPI(weights_memory=(self.dense_model_num_params - 1) * 4) + + def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): + dense_convtrans_layers = get_layers_from_model_by_type(float_model, layers.Conv2DTranspose) + dense_conv_layers = get_layers_from_model_by_type(float_model, layers.Conv2D) + + prunable_convtrans_layers = get_layers_from_model_by_type(quantized_model, layers.Conv2DTranspose) + prunable_conv_layers = get_layers_from_model_by_type(quantized_model, layers.Conv2D) + + is_first_layer_pruned = prunable_convtrans_layers[0].filters == 6 - self.simd + is_second_layer_pruned = prunable_conv_layers[0].filters == 4 - self.simd + + # Make sure only one of layers has been pruned + self.unit_test.assertTrue(is_first_layer_pruned != is_second_layer_pruned) + + # In constant case, the last SIMD channels of the first layer should be pruned: + if self.use_constant_importance_metric: + self.unit_test.assertTrue(is_first_layer_pruned) + self.unit_test.assertTrue(np.all(prunable_convtrans_layers[0].kernel.numpy() == dense_convtrans_layers[0].kernel.numpy()[:, :, :-self.simd, :])) + 
self.unit_test.assertTrue(np.all(prunable_convtrans_layers[0].bias.numpy() == dense_convtrans_layers[0].bias.numpy()[:-self.simd])) + + # Make sure the only in channel removed is the last channel of the second conv layer + self.unit_test.assertTrue(np.all(prunable_conv_layers[0].kernel.numpy() == dense_conv_layers[0].kernel.numpy()[:, :, :-self.simd, :])) + self.unit_test.assertTrue(np.all(prunable_conv_layers[0].bias.numpy() == dense_conv_layers[0].bias.numpy())) + + if is_first_layer_pruned: + self.unit_test.assertTrue(np.all(prunable_convtrans_layers[1].kernel.numpy() == dense_convtrans_layers[1].kernel.numpy())) + self.unit_test.assertTrue(np.all(prunable_convtrans_layers[1].bias.numpy() == dense_convtrans_layers[1].bias.numpy())) + diff --git a/tests/keras_tests/pruning_tests/feature_networks/networks_tests/conv2dtranspose_pruning_test.py b/tests/keras_tests/pruning_tests/feature_networks/networks_tests/conv2dtranspose_pruning_test.py new file mode 100644 index 000000000..14e1a9a79 --- /dev/null +++ b/tests/keras_tests/pruning_tests/feature_networks/networks_tests/conv2dtranspose_pruning_test.py @@ -0,0 +1,103 @@ +# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + + +import tensorflow as tf + +import model_compression_toolkit as mct +from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import generate_keras_tpc +from tests.common_tests.helpers.generate_test_tp_model import generate_test_tp_model +from tests.keras_tests.pruning_tests.feature_networks.constant_importance_metric import add_const_importance_metric, \ + ConstImportanceMetric + +from tests.keras_tests.pruning_tests.feature_networks.pruning_keras_feature_test import PruningKerasFeatureTest +from tests.keras_tests.utils import get_layers_from_model_by_type +import numpy as np + +keras = tf.keras +layers = keras.layers + + +class Conv2DTransposePruningTest(PruningKerasFeatureTest): + """ + Test a network with two adjacent Conv2DTranspose layers and check it's pruned for a target compression ratio. + """ + + def __init__(self, + unit_test, + use_bn=False, + activation_layer=None, + simd=1, + use_constant_importance_metric=True): + super().__init__(unit_test, + input_shape=(8, 8, 3)) + self.use_bn = use_bn + self.activation_layer = activation_layer + self.simd = simd + self.use_constant_importance_metric = use_constant_importance_metric + + def create_networks(self): + inputs = layers.Input(shape=self.get_input_shapes()[0][1:]) + x = layers.Conv2DTranspose(filters=6, kernel_size=1)(inputs) + if self.use_bn: + x = layers.BatchNormalization()(x) + if self.activation_layer: + x = self.activation_layer(x) + x = layers.Conv2DTranspose(filters=4, kernel_size=1)(x) + x = layers.Conv2DTranspose(filters=4, kernel_size=1)(x) + model = keras.Model(inputs=inputs, outputs=x) + return model + + def get_tpc(self): + tp = generate_test_tp_model({'simd_size': self.simd}) + return generate_keras_tpc(name="simd_test", tp_model=tp) + + def get_pruning_config(self): + if self.use_constant_importance_metric: + add_const_importance_metric(first_num_oc=6, second_num_oc=4, simd=self.simd) +
return mct.pruning.PruningConfig(importance_metric=ConstImportanceMetric.CONST) + return super().get_pruning_config() + + def get_kpi(self): + # Remove only one group of channels only one parameter should be pruned + return mct.KPI(weights_memory=(self.dense_model_num_params-1) * 4) + + def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): + dense_layers = get_layers_from_model_by_type(float_model, layers.Conv2DTranspose) + prunable_layers = get_layers_from_model_by_type(quantized_model, layers.Conv2DTranspose) + + is_first_layer_pruned = prunable_layers[0].filters == 6 - self.simd + is_second_layer_pruned = prunable_layers[1].filters == 4 - self.simd + + # Make sure only one of layers has been pruned + self.unit_test.assertTrue(is_first_layer_pruned != is_second_layer_pruned) + + # In constant case, the last SIMD channels of the first layer should be pruned: + if self.use_constant_importance_metric: + self.unit_test.assertTrue(is_first_layer_pruned) + self.unit_test.assertTrue(np.all(prunable_layers[0].kernel.numpy()==dense_layers[0].kernel.numpy()[:,:,:-self.simd,:])) + self.unit_test.assertTrue(np.all(prunable_layers[0].bias.numpy()==dense_layers[0].bias.numpy()[:-self.simd])) + + # Make sure the only in channel removed is the last channel of the second conv layer + self.unit_test.assertTrue(np.all(prunable_layers[1].kernel.numpy() == dense_layers[1].kernel.numpy()[:, :, :, :-self.simd])) + self.unit_test.assertTrue(np.all(prunable_layers[1].bias.numpy()==dense_layers[1].bias.numpy())) + + if is_first_layer_pruned: + self.unit_test.assertTrue(np.all(prunable_layers[2].kernel.numpy() == dense_layers[2].kernel.numpy())) + self.unit_test.assertTrue(np.all(prunable_layers[2].bias.numpy() == dense_layers[2].bias.numpy())) + + + + diff --git a/tests/keras_tests/pruning_tests/feature_networks/networks_tests/dense_pruning_test.py b/tests/keras_tests/pruning_tests/feature_networks/networks_tests/dense_pruning_test.py new file mode 100644 
index 000000000..41c27643d --- /dev/null +++ b/tests/keras_tests/pruning_tests/feature_networks/networks_tests/dense_pruning_test.py @@ -0,0 +1,100 @@ +# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + + +import tensorflow as tf + +import model_compression_toolkit as mct +from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import generate_keras_tpc +from tests.common_tests.helpers.generate_test_tp_model import generate_test_tp_model +from tests.keras_tests.pruning_tests.feature_networks.constant_importance_metric import add_const_importance_metric, \ + ConstImportanceMetric + +from tests.keras_tests.pruning_tests.feature_networks.pruning_keras_feature_test import PruningKerasFeatureTest +from tests.keras_tests.utils import get_layers_from_model_by_type +import numpy as np + +keras = tf.keras +layers = keras.layers + + +class DensePruningTest(PruningKerasFeatureTest): + """ + Test a network with two adjacent dense and check it's pruned for a target compression ratio. 
+ """ + + def __init__(self, + unit_test, + use_bn=False, + activation_layer=None, + simd=1, + use_constant_importance_metric=True): + + super().__init__(unit_test, + input_shape=(8, 8, 3)) + self.use_bn = use_bn + self.activation_layer = activation_layer + self.simd = simd + self.use_constant_importance_metric = use_constant_importance_metric + + def get_tpc(self): + tp = generate_test_tp_model({'simd_size': self.simd}) + return generate_keras_tpc(name="simd_test", tp_model=tp) + + def get_pruning_config(self): + if self.use_constant_importance_metric: + add_const_importance_metric(first_num_oc=10, second_num_oc=6, simd=self.simd) + return mct.pruning.PruningConfig(importance_metric=ConstImportanceMetric.CONST) + return super().get_pruning_config() + + def create_networks(self): + inputs = layers.Input(shape=self.get_input_shapes()[0][1:]) + x = layers.Dense(units=10)(inputs) + if self.use_bn: + x = layers.BatchNormalization()(x) + if self.activation_layer: + x = self.activation_layer(x) + x = layers.Dense(units=6)(x) + outputs = layers.Dense(units=6)(x) + model = keras.Model(inputs=inputs, outputs=outputs) + return model + + def get_kpi(self): + # Remove only one group of channels only one parameter should be pruned + return mct.KPI(weights_memory=(self.dense_model_num_params - 1) * 4) + + def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): + dense_layers = get_layers_from_model_by_type(float_model, layers.Dense) + prunable_layers = get_layers_from_model_by_type(quantized_model, layers.Dense) + + is_first_layer_pruned = prunable_layers[0].units == 10 - self.simd + is_second_layer_pruned = prunable_layers[1].units == 6 - self.simd + + # Make sure only one of layers has been pruned + self.unit_test.assertTrue(is_first_layer_pruned != is_second_layer_pruned) + + # In constant case, the last SIMD channels of the first layer should be pruned: + if self.use_constant_importance_metric: + self.unit_test.assertTrue(is_first_layer_pruned) 
+ self.unit_test.assertTrue(np.all(prunable_layers[0].kernel.numpy() == dense_layers[0].kernel.numpy()[:, :-self.simd])) + self.unit_test.assertTrue(np.all(prunable_layers[0].bias.numpy() == dense_layers[0].bias.numpy()[:-self.simd])) + + # Make sure the only in channel removed is the last channel of the second dense layer + self.unit_test.assertTrue(np.all(prunable_layers[1].kernel.numpy() == dense_layers[1].kernel.numpy()[:-self.simd, :])) + self.unit_test.assertTrue(np.all(prunable_layers[1].bias.numpy() == dense_layers[1].bias.numpy())) + + if is_first_layer_pruned: + self.unit_test.assertTrue(np.all(prunable_layers[2].kernel.numpy() == dense_layers[2].kernel.numpy())) + self.unit_test.assertTrue(np.all(prunable_layers[2].bias.numpy() == dense_layers[2].bias.numpy())) diff --git a/tests/keras_tests/pruning_tests/feature_networks/pruning_keras_feature_test.py b/tests/keras_tests/pruning_tests/feature_networks/pruning_keras_feature_test.py new file mode 100644 index 000000000..f39c93575 --- /dev/null +++ b/tests/keras_tests/pruning_tests/feature_networks/pruning_keras_feature_test.py @@ -0,0 +1,73 @@ +# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +import model_compression_toolkit as mct +from model_compression_toolkit.core.common.pruning.pruning_config import PruningConfig +from tests.keras_tests.feature_networks_tests.base_keras_feature_test import BaseKerasFeatureNetworkTest +import numpy as np + +class PruningKerasFeatureTest(BaseKerasFeatureNetworkTest): + def __init__(self, + unit_test, + num_calibration_iter=1, + val_batch_size=1, + num_of_inputs=1, + input_shape=(8, 8, 3)): + + super().__init__(unit_test=unit_test, + val_batch_size=val_batch_size, + num_calibration_iter=num_calibration_iter, + num_of_inputs=num_of_inputs, + input_shape=input_shape) + + self.dense_model_num_params = None + + def get_pruning_config(self): + return PruningConfig(num_score_approximations=2) + + def run_test(self): + feature_networks = self.create_networks() + feature_networks = feature_networks if isinstance(feature_networks, list) else [feature_networks] + for model_float in feature_networks: + self.dense_model_num_params=sum([l.count_params() for l in model_float.layers]) + pruned_model, pruning_info = mct.pruning.keras_pruning_experimental(model=model_float, + target_kpi=self.get_kpi(), + representative_data_gen=self.representative_data_gen_experimental, + pruning_config=self.get_pruning_config(), + target_platform_capabilities=self.get_tpc()) + + self.pruned_model_num_params=sum([l.count_params() for l in pruned_model.layers]) + + ### Test inference ## + input_tensor = self.representative_data_gen() + pruned_outputs = pruned_model(input_tensor) + if self.pruned_model_num_params == self.dense_model_num_params: + dense_outputs = model_float(input_tensor) + self.unit_test.assertTrue(np.sum(np.abs(dense_outputs-pruned_outputs)) == 0, f"If model is not pruned, " + f"predictions should be identical, but found difference between predictions") + + self.unit_test.assertTrue(pruned_model.output_shape == model_float.output_shape, + f"Pruned model 
should have the same output shape as dense model," + f"but dense model output shape is {model_float.output_shape}," + f"and pruned model output shape is {pruned_model.output_shape}") + + for dense_layer, pruned_layer in zip(model_float.layers, pruned_model.layers): + self.unit_test.assertTrue(type(pruned_layer)==type(dense_layer), f"type of layers and their orders should be the same," + f"but {type(dense_layer)} is not {type(pruned_layer)}") + + self.compare(pruned_model, + model_float, + input_x=input_tensor, + quantization_info=pruning_info) + diff --git a/tests/keras_tests/pruning_tests/feature_networks/test_pruning_feature_networks.py b/tests/keras_tests/pruning_tests/feature_networks/test_pruning_feature_networks.py new file mode 100644 index 000000000..b5c2dff04 --- /dev/null +++ b/tests/keras_tests/pruning_tests/feature_networks/test_pruning_feature_networks.py @@ -0,0 +1,112 @@ +# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + + +import unittest + +from tests.keras_tests.pruning_tests.feature_networks.networks_tests.conv2d_conv2dtranspose_pruning_test import \ + Conv2DtoConv2DTransposePruningTest +from tests.keras_tests.pruning_tests.feature_networks.networks_tests.conv2d_pruning_test import Conv2DPruningTest + +from tests.keras_tests.pruning_tests.feature_networks.networks_tests.conv2dtranspose_conv2d_pruning_test import \ + Conv2DTransposetoConv2DPruningTest +from tests.keras_tests.pruning_tests.feature_networks.networks_tests.conv2dtranspose_pruning_test import \ + Conv2DTransposePruningTest +from tests.keras_tests.pruning_tests.feature_networks.networks_tests.dense_pruning_test import DensePruningTest +import keras + +layers = keras.layers + + +class PruningFeatureNetworksTest(unittest.TestCase): + + def test_conv2d_pruning(self): + Conv2DPruningTest(self).run_test() + Conv2DPruningTest(self, use_bn=True).run_test() + Conv2DPruningTest(self, use_bn=True, activation_layer=layers.ReLU()).run_test() + Conv2DPruningTest(self, use_bn=True, activation_layer=layers.Softmax()).run_test() + Conv2DPruningTest(self, use_bn=True, activation_layer=layers.PReLU()).run_test() + Conv2DPruningTest(self, simd=2).run_test() + Conv2DPruningTest(self, use_bn=True, simd=2).run_test() + Conv2DPruningTest(self, use_bn=True, activation_layer=layers.ReLU(), simd=2).run_test() + Conv2DPruningTest(self, use_bn=True, activation_layer=layers.Softmax(), simd=2).run_test() + Conv2DPruningTest(self, use_bn=True, activation_layer=layers.PReLU(), simd=2).run_test() + + # Use dummy LFH + Conv2DPruningTest(self, use_constant_importance_metric=False).run_test() + Conv2DPruningTest(self, simd=2, use_constant_importance_metric=False).run_test() + + def test_dense_pruning(self): + DensePruningTest(self).run_test() + DensePruningTest(self, use_bn=True).run_test() + DensePruningTest(self, use_bn=True, activation_layer=layers.ReLU()).run_test() 
+ DensePruningTest(self, use_bn=True, activation_layer=layers.Softmax()).run_test() + DensePruningTest(self, use_bn=True, activation_layer=layers.PReLU()).run_test() + DensePruningTest(self, simd=2).run_test() + DensePruningTest(self, use_bn=True, simd=2).run_test() + DensePruningTest(self, use_bn=True, activation_layer=layers.ReLU(), simd=2).run_test() + DensePruningTest(self, use_bn=True, activation_layer=layers.Softmax(), simd=2).run_test() + DensePruningTest(self, use_bn=True, activation_layer=layers.PReLU(), simd=2).run_test() + # Use dummy LFH + DensePruningTest(self, use_constant_importance_metric=False).run_test() + DensePruningTest(self, simd=2, use_constant_importance_metric=False).run_test() + + def test_conv2dtranspose_pruning(self): + Conv2DTransposePruningTest(self, ).run_test() + Conv2DTransposePruningTest(self, use_bn=True).run_test() + Conv2DTransposePruningTest(self, use_bn=True, activation_layer=layers.ReLU()).run_test() + Conv2DTransposePruningTest(self, use_bn=True, activation_layer=layers.Softmax()).run_test() + Conv2DTransposePruningTest(self, use_bn=True, activation_layer=layers.PReLU()).run_test() + Conv2DTransposePruningTest(self, simd=2).run_test() + Conv2DTransposePruningTest(self, use_bn=True, simd=2).run_test() + Conv2DTransposePruningTest(self, use_bn=True, activation_layer=layers.ReLU(), simd=2).run_test() + Conv2DTransposePruningTest(self, use_bn=True, activation_layer=layers.Softmax(), simd=2).run_test() + Conv2DTransposePruningTest(self, use_bn=True, activation_layer=layers.PReLU(), simd=2).run_test() + # Use dummy LFH + Conv2DTransposePruningTest(self, use_constant_importance_metric=False).run_test() + Conv2DTransposePruningTest(self, simd=2, use_constant_importance_metric=False).run_test() + + def test_conv2d_conv2dtranspose_pruning(self): + Conv2DtoConv2DTransposePruningTest(self).run_test() + Conv2DtoConv2DTransposePruningTest(self, use_bn=True).run_test() + Conv2DtoConv2DTransposePruningTest(self, use_bn=True, 
activation_layer=layers.ReLU()).run_test() + Conv2DtoConv2DTransposePruningTest(self, use_bn=True, activation_layer=layers.Softmax()).run_test() + Conv2DtoConv2DTransposePruningTest(self, use_bn=True, activation_layer=layers.PReLU()).run_test() + Conv2DtoConv2DTransposePruningTest(self, simd=2).run_test() + Conv2DtoConv2DTransposePruningTest(self, use_bn=True, simd=2).run_test() + Conv2DtoConv2DTransposePruningTest(self, use_bn=True, activation_layer=layers.ReLU(), simd=2).run_test() + Conv2DtoConv2DTransposePruningTest(self, use_bn=True, activation_layer=layers.Softmax(), simd=2).run_test() + Conv2DtoConv2DTransposePruningTest(self, use_bn=True, activation_layer=layers.PReLU(), simd=2).run_test() + # Use dummy LFH + Conv2DtoConv2DTransposePruningTest(self, use_constant_importance_metric=False).run_test() + Conv2DtoConv2DTransposePruningTest(self, simd=2, use_constant_importance_metric=False).run_test() + + def test_conv2dtranspose_conv2d_pruning(self): + Conv2DTransposetoConv2DPruningTest(self).run_test() + Conv2DTransposetoConv2DPruningTest(self, use_bn=True).run_test() + Conv2DTransposetoConv2DPruningTest(self, use_bn=True, activation_layer=layers.ReLU()).run_test() + Conv2DTransposetoConv2DPruningTest(self, use_bn=True, activation_layer=layers.Softmax()).run_test() + Conv2DTransposetoConv2DPruningTest(self, use_bn=True, activation_layer=layers.PReLU()).run_test() + Conv2DTransposetoConv2DPruningTest(self, simd=2).run_test() + Conv2DTransposetoConv2DPruningTest(self, use_bn=True, simd=2).run_test() + Conv2DTransposetoConv2DPruningTest(self, use_bn=True, activation_layer=layers.ReLU(), simd=2).run_test() + Conv2DTransposetoConv2DPruningTest(self, use_bn=True, activation_layer=layers.Softmax(), simd=2).run_test() + Conv2DTransposetoConv2DPruningTest(self, use_bn=True, activation_layer=layers.PReLU(), simd=2).run_test() + # Use dummy LFH + Conv2DTransposetoConv2DPruningTest(self, use_constant_importance_metric=False).run_test() + 
Conv2DTransposetoConv2DPruningTest(self, simd=2, use_constant_importance_metric=False).run_test() + +if __name__ == '__main__': + unittest.main() diff --git a/tests/keras_tests/pruning_tests/random_importance_metric.py b/tests/keras_tests/pruning_tests/random_importance_metric.py new file mode 100644 index 000000000..75da46c8c --- /dev/null +++ b/tests/keras_tests/pruning_tests/random_importance_metric.py @@ -0,0 +1,63 @@ +# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + + +from typing import Callable, List + +from model_compression_toolkit.core.common.framework_info import FrameworkInfo +from model_compression_toolkit.core.common import Graph, BaseNode +from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation +from model_compression_toolkit.core.common.pruning.channels_grouping import ChannelGrouping +from model_compression_toolkit.core.common.pruning.importance_metrics.base_importance_metric import BaseImportanceMetric +import numpy as np + +from model_compression_toolkit.core.common.pruning.pruning_config import PruningConfig +from model_compression_toolkit.core.keras.constants import KERNEL + + +class RandomImportanceMetric(BaseImportanceMetric): + + def __init__(self, + graph: Graph, + representative_data_gen: Callable, + fw_impl: FrameworkImplementation, + pruning_config: PruningConfig, + fw_info: FrameworkInfo): + self.float_graph = graph + self.representative_data_gen = representative_data_gen + self.fw_impl = fw_impl + self.pruning_config = pruning_config + self.fw_info = fw_info + + + def get_entry_node_to_simd_score(self, entry_nodes: List[BaseNode]): + entry_node_to_score = self._get_entry_node_to_score(entry_nodes) + self.channel_grouping = ChannelGrouping(prunable_nodes=entry_nodes, + fw_info=self.fw_info) + self.channel_grouping.group_scores_by_simd_groups(entry_node_to_score) + grouped_indices = self.channel_grouping.simd_groups_indices + entry_node_to_simd_score = {} + for node, trace in entry_node_to_score.items(): + trace_by_group = [np.sum(trace[g]) for g in grouped_indices[node]] + entry_node_to_simd_score[node]=np.asarray(trace_by_group) + return entry_node_to_simd_score, grouped_indices + + + def _get_entry_node_to_score(self, sections_input_nodes: List[BaseNode]): + random_scores = [np.random.random( + 
node.get_weights_by_keys(KERNEL).shape[self.fw_info.kernel_channels_mapping.get(node.type)[0]]) + for node in sections_input_nodes] + entry_node_to_score = {node: scores for node, scores in zip(sections_input_nodes, random_scores)} + return entry_node_to_score \ No newline at end of file diff --git a/tests/keras_tests/pruning_tests/test_memory_calculator.py b/tests/keras_tests/pruning_tests/test_memory_calculator.py new file mode 100644 index 000000000..6bf2ed57d --- /dev/null +++ b/tests/keras_tests/pruning_tests/test_memory_calculator.py @@ -0,0 +1,87 @@ +# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +from model_compression_toolkit.core.common.quantization.set_node_quantization_config import set_quantization_configuration_to_graph + +import unittest +import model_compression_toolkit as mct +from model_compression_toolkit.core.common.pruning.memory_calculator import MemoryCalculator +from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO +from model_compression_toolkit.core.keras.pruning.pruning_keras_implementation import PruningKerasImplementation +from model_compression_toolkit.core.graph_prep_runner import read_model_to_graph + +import keras + +layers = keras.layers +import numpy as np + + +class TestParameterCounter(unittest.TestCase): + # TODO: Extend it to more layers and scenarios + + def representative_dataset(self, in_shape=(1,8,8,3)): + for _ in range(1): + yield [np.random.randn(*in_shape)] + + def test_conv_layer(self): + # Define the layer + out_channels = 2 + in_channels = 1 + kernel_size = 3 + use_bias=True + + inputs = layers.Input(shape=(8, 8, in_channels)) + x = layers.Conv2D(filters=out_channels, kernel_size=kernel_size, use_bias=use_bias)(inputs) + model = keras.Model(inputs=inputs, outputs=x) + + fw_info = DEFAULT_KERAS_INFO + fw_impl = PruningKerasImplementation() + tpc = mct.get_target_platform_capabilities('tensorflow', 'imx500') + + # Convert the original Keras model to an internal graph representation. + float_graph = read_model_to_graph(model, + self.representative_dataset, + tpc, + DEFAULT_KERAS_INFO, + fw_impl) + + # Apply quantization configuration to the graph. This step is necessary even when not quantizing, + # as it prepares the graph for the pruning process. 
+ float_graph_with_compression_config = set_quantization_configuration_to_graph(float_graph, + quant_config=mct.DEFAULTCONFIG, + mixed_precision_enable=False) + + + self.memory_calculator = MemoryCalculator(graph=float_graph_with_compression_config, + fw_info=fw_info, + fw_impl=fw_impl) + + # masks = {list(float_graph_with_compression_config.nodes)[0]} + counted_params = self.memory_calculator.get_pruned_graph_num_params(masks=None, + include_padded_channels=tpc.is_simd_padding) + + # Calculate expected number of parameters + simd_groups = np.ceil(out_channels/32.) + expected_params = 32 * simd_groups * (in_channels * kernel_size * kernel_size + int(use_bias)) + self.assertEqual(counted_params, expected_params) + + counted_params = self.memory_calculator.get_pruned_graph_num_params(masks=None, + include_padded_channels=False) + + # Calculate expected number of parameters + expected_params = out_channels * (in_channels * kernel_size * kernel_size + int(use_bias)) + self.assertEqual(counted_params, expected_params) + + diff --git a/tests/keras_tests/pruning_tests/test_pretrained_models.py b/tests/keras_tests/pruning_tests/test_pretrained_models.py new file mode 100644 index 000000000..5a0696a88 --- /dev/null +++ b/tests/keras_tests/pruning_tests/test_pretrained_models.py @@ -0,0 +1,189 @@ +# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +import tempfile +from enum import Enum + +import unittest + +import tensorflow as tf + +import model_compression_toolkit as mct +import numpy as np +from packaging import version + +from model_compression_toolkit.constants import FP32_BYTES_PER_PARAMETER +from model_compression_toolkit.core.common.pruning.importance_metrics.importance_metric_factory import IMPORTANCE_METRIC_DICT +from tests.keras_tests.pruning_tests.random_importance_metric import RandomImportanceMetric + +keras = tf.keras +layers = keras.layers + +NUM_PRUNING_RATIOS = 1 + +class TestImportanceMetric(Enum): + RANDOM = 'random' + +IMPORTANCE_METRIC_DICT.update({TestImportanceMetric.RANDOM: RandomImportanceMetric}) + +class PruningPretrainedModelsTest(unittest.TestCase): + def representative_dataset(self, in_shape=(1,224,224,3)): + for _ in range(1): + yield [np.random.randn(*in_shape)] + + def test_rn50_pruning(self): + # Can not be found in tf2.12 + if version.parse(tf.__version__) >= version.parse("2.13"): + from keras.applications.resnet50 import ResNet50 + dense_model = ResNet50() + target_crs = np.linspace(0.5, 1, NUM_PRUNING_RATIOS) + for cr in target_crs: + self.run_test(cr, dense_model) + + def test_efficientnetb0_pruning(self): + from keras.applications.efficientnet import EfficientNetB0 + dense_model = EfficientNetB0() + target_crs = np.linspace(0.8, 1, NUM_PRUNING_RATIOS) + for cr in target_crs: + self.run_test(cr, dense_model) + + + def test_vgg16_pruning(self): + from keras.applications.vgg16 import VGG16 + dense_model = VGG16() + target_crs = np.linspace(0.5, 1, NUM_PRUNING_RATIOS) + for cr in target_crs: + self.run_test(cr, dense_model) + + + def test_mobilenet_pruning(self): + from keras.applications.mobilenet import MobileNet + dense_model = MobileNet() + target_crs = np.linspace(0.55, 1, NUM_PRUNING_RATIOS) + for cr in target_crs: + self.run_test(cr, dense_model) + + def 
test_mobilenetv2_pruning(self): + from keras.applications.mobilenet_v2 import MobileNetV2 + dense_model = MobileNetV2() + target_crs = np.linspace(0.5, 1, NUM_PRUNING_RATIOS) + for cr in target_crs: + self.run_test(cr, dense_model) + + def test_densenet_pruning(self): + from keras.applications.densenet import DenseNet121 + dense_model = DenseNet121() + target_crs = np.linspace(0.5, 1, NUM_PRUNING_RATIOS) + for cr in target_crs: + self.run_test(cr, dense_model) + + + def test_vgg19_pruning(self): + from keras.applications.vgg19 import VGG19 + dense_model = VGG19() + target_crs = np.linspace(0.5, 1, NUM_PRUNING_RATIOS) + for cr in target_crs: + self.run_test(cr, dense_model) + + + def _dummy_retrain(self, model, ds): + # Compile the model with a loss function, optimizer, and metric to monitor + model.compile(optimizer='adam', + loss='sparse_categorical_crossentropy', + metrics=['accuracy']) + + # Train the model for one epoch using the dummy dataset + model.fit(ds, epochs=1) + return model + + def run_test(self, cr, dense_model, test_retraining=False): + """ + Runs a pruning test on a pre-trained model with a specified compression rate (cr). + + Args: + cr (float): The target compression rate (ratio of remaining parameters). + dense_model (Model): The pre-trained Keras model to be pruned. + test_retraining (bool): If True, retrain the pruned model on dummy data to test stability. + + This function calculates the number of parameters in the dense model, performs pruning to achieve + the desired compression rate, and validates the actual compression rate achieved. It also tests + if the outputs of the pruned model are similar to the dense model, and ensures that pruned layers + respect the importance scores. If `test_retraining` is True, it further validates the model's + performance after retraining. + """ + # Calculate the number of parameters in the dense model. 
+ dense_nparams = sum([l.count_params() for l in dense_model.layers]) + + # Perform pruning on the dense model. + pruned_model, pruning_info = mct.pruning.keras_pruning_experimental( + model=dense_model, + target_kpi=mct.KPI(weights_memory=dense_nparams * FP32_BYTES_PER_PARAMETER * cr), + representative_data_gen=self.representative_dataset, + pruning_config=mct.pruning.PruningConfig( + num_score_approximations=1, + importance_metric=TestImportanceMetric.RANDOM) + ) + + # Calculate the actual compression rate achieved after pruning. + pruned_nparams = sum([l.count_params() for l in pruned_model.layers]) + actual_cr = pruned_nparams / dense_nparams + print(f"Target remaining cr: {cr * 100}, Actual remaining cr: {actual_cr * 100}") + + input_tensor = next(self.representative_dataset())[0] + pruned_outputs = pruned_model(input_tensor) + + # Optionally, retrain the pruned model (using dummy data for 1 epoch) and check it + # predicts differently than before retraining. + if test_retraining: + ds = create_dummy_dataset() + retrained_model = self._dummy_retrain(pruned_model, ds) + retrained_outputs = retrained_model(input_tensor) + self.assertTrue(np.sum(np.abs(pruned_outputs - retrained_outputs)) != 0, f"Expected after retraining to have different predictions but are the same") + + # Ensure pruned layers had lower importance scores than the channels + # that remained. 
+ for layer_name, layer_mask in pruning_info.pruning_masks.items(): + if 0 in layer_mask: + layer_scores = pruning_info.importance_scores[layer_name] + min_score_remained = min(layer_scores[layer_mask.astype("bool")]) + max_score_removed = max(layer_scores[(1 - layer_mask).astype("bool")]) + self.assertTrue(max_score_removed <= min_score_remained, + f"Expected remaining channels to have higher scores" + f"than pruned channels but found remained channel with score" + f"{min_score_remained} and found pruned channel with" + f"score {max_score_removed}") + + # Validate that the actual compression rate does not exceed the target compression rate. + self.assertTrue(actual_cr <= cr, + f"Expected the actual compression rate: {actual_cr} to not exceed the target compression " + f"rate: {cr}") + + +# Generator that yields a single dummy image and label (one sample, not an infinite stream) +def dummy_data_generator(): + image = np.random.random((224, 224, 3)).astype(np.float32) + label = np.random.randint(0, 2) + yield image, label + +# Create a Dataset object that returns the dummy data +def create_dummy_dataset(): + dummy_dataset = tf.data.Dataset.from_generator( + dummy_data_generator, + output_signature=( + tf.TensorSpec(shape=(224, 224, 3), dtype=tf.float32), + tf.TensorSpec(shape=(), dtype=tf.int32) + ) + ) + return dummy_dataset.batch(1) diff --git a/tests/keras_tests/pruning_tests/test_pruning_info.py b/tests/keras_tests/pruning_tests/test_pruning_info.py new file mode 100644 index 000000000..c2d70fe85 --- /dev/null +++ b/tests/keras_tests/pruning_tests/test_pruning_info.py @@ -0,0 +1,58 @@ +# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import unittest +import numpy as np +import model_compression_toolkit as mct +from model_compression_toolkit.core.common.pruning.pruning_info import unroll_simd_scores_to_per_channel_scores + + +class TestPruningInfo(unittest.TestCase): + + def setUp(self): + # Setup some mock pruning masks and importance scores + self.mock_pruning_masks = {"Layer1": np.array([1, 0, 1]), + "Layer2": np.array([0, 1])} + self.mock_importance_scores = {"Layer1": np.array([0.5, 0.3, 0.7]), + "Layer2": np.array([0.2, 0.8])} + self.pruning_info = mct.pruning.PruningInfo(self.mock_pruning_masks, + self.mock_importance_scores) + + def test_get_pruning_mask(self): + # Test to check if the correct pruning masks are returned + self.assertEqual(self.pruning_info.pruning_masks, self.mock_pruning_masks) + + def test_get_importance_score(self): + # Test to check if the correct importance scores are returned + self.assertEqual(self.pruning_info.importance_scores, self.mock_importance_scores) + + +class TestUnrollSIMDScores(unittest.TestCase): + + def test_unroll_simd_scores(self): + # Setup mock SIMD scores and group indices + simd_scores = {"Layer1": np.array([0.2, 0.4, 0.6])} + simd_groups_indices = {"Layer1": [np.array([4, 1]), np.array([2, 3]), np.array([0])]} + + # Expected output + expected_scores = {"Layer1": np.array([0.6, 0.2, 0.4, 0.4, 0.2])} + + # Test the unroll_simd_scores_to_per_channel_scores function + result = unroll_simd_scores_to_per_channel_scores(simd_scores, simd_groups_indices) + 
self.assertTrue(np.array_equal(result["Layer1"], expected_scores["Layer1"])) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_suite.py b/tests/test_suite.py index 7163da948..94c18abbb 100644 --- a/tests/test_suite.py +++ b/tests/test_suite.py @@ -21,7 +21,6 @@ from tests.common_tests.function_tests.test_collectors_manipulation import TestCollectorsManipulations from tests.common_tests.function_tests.test_folder_image_loader import TestFolderLoader # ---------------- Individual test suites -from model_compression_toolkit.constants import FOUND_ONNX from tests.common_tests.function_tests.test_histogram_collector import TestHistogramCollector from tests.common_tests.function_tests.test_kpi_object import TestKPIObject from tests.common_tests.function_tests.test_threshold_selection import TestThresholdSelection @@ -72,6 +71,9 @@ from tests.keras_tests.function_tests.test_gptq_soft_quantizer import TestGPTQSoftQuantizer as keras_gptq_soft_quantizer_test from tests.keras_tests.function_tests.test_activation_quantization_holder_gptq import TestGPTQModelBuilderWithActivationHolder from tests.data_generation_tests.keras.test_keras_data_generation_runner import KerasDataGenerationTestRunner + from tests.keras_tests.pruning_tests.test_memory_calculator import TestParameterCounter + from tests.keras_tests.pruning_tests.test_pretrained_models import PruningPretrainedModelsTest + from tests.keras_tests.pruning_tests.feature_networks.test_pruning_feature_networks import PruningFeatureNetworksTest if found_pytorch: @@ -105,6 +107,9 @@ # Add TF tests only if tensorflow is installed if found_tf: + suiteList.append(unittest.TestLoader().loadTestsFromTestCase(TestParameterCounter)) + suiteList.append(unittest.TestLoader().loadTestsFromTestCase(PruningPretrainedModelsTest)) + suiteList.append(unittest.TestLoader().loadTestsFromTestCase(PruningFeatureNetworksTest)) suiteList.append(unittest.TestLoader().loadTestsFromTestCase(TestHessianInfoCalculatorWeights)) 
suiteList.append(unittest.TestLoader().loadTestsFromTestCase(TestHessianInfoCalculatorActivation)) suiteList.append(unittest.TestLoader().loadTestsFromTestCase(TestHessianService)) diff --git a/tutorials/notebooks/example_keras_pruning.py b/tutorials/notebooks/example_keras_pruning.py new file mode 100644 index 000000000..733d92944 --- /dev/null +++ b/tutorials/notebooks/example_keras_pruning.py @@ -0,0 +1,98 @@ +# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import argparse +import keras.models + +from keras.applications.resnet50 import ResNet50 +import tensorflow as tf + +import model_compression_toolkit as mct +import tempfile +import numpy as np +import cv2 + + +RESIZE_SCALE = 256 / 224 +SIZE = 224 + +def resize(x): + resize_side = max(RESIZE_SCALE * SIZE / x.shape[0], RESIZE_SCALE * SIZE / x.shape[1]) + height_tag = int(np.round(resize_side * x.shape[0])) + width_tag = int(np.round(resize_side * x.shape[1])) + resized_img = cv2.resize(x, (width_tag, height_tag)) + offset_height = int((height_tag - SIZE) / 2) + offset_width = int((width_tag - SIZE) / 2) + cropped_img = resized_img[offset_height:offset_height + SIZE, offset_width:offset_width + SIZE] + return cropped_img + + +def count_model_params(model: keras.models.Model) -> int: + # Function to count the total number of parameters in a given Keras model. 
+ return sum([l.count_params() for l in model.layers]) + +def argument_handler(): + parser = argparse.ArgumentParser() + parser.add_argument('--representative_dataset_dir', type=str, help='Folder path for the representative dataset.') + parser.add_argument('--batch_size', type=int, default=50, help='Batch size for the representative data.') + parser.add_argument('--num_score_approximations', type=int, default=32, + help='Number of scores to estimate the importance of each channel.') + parser.add_argument('--compression_rate', type=float, help='Compression rate to remove from the dense model.') + + return parser.parse_args() + + +if __name__ == '__main__': + args = argument_handler() + + # Create a function to generate representative data used for channels importance approximation. + image_data_loader = mct.core.FolderImageLoader(args.representative_dataset_dir, + preprocessing=[resize, + tf.keras.applications.resnet50.preprocess_input], + batch_size=args.batch_size) + + def representative_data_gen() -> list: + yield [image_data_loader.sample()] + + + # Retrieve the target platform capabilities which include the SIMD size configuration for each layer. + target_platform_cap = mct.get_target_platform_capabilities('tensorflow', + 'default') + + # Load a dense ResNet50 model for pruning. Compute the number of params to + # initialize the KPI to constrain the memory footprint of the pruned model's weights. + dense_model = ResNet50() + dense_nparams = count_model_params(dense_model) + print(f"Model has {dense_nparams} parameters.") + kpi = mct.KPI(weights_memory=dense_nparams * 4 * args.compression_rate) + + # Create PruningConfig with the number of approximations MCT will compute as importance metric + # for each channel when using LFH metric to set scores for each output channel that can be removed. + pruning_config = mct.pruning.PruningConfig(num_score_approximations=args.num_score_approximations) + + # Prune the model. 
+ pruned_model, pruning_info = mct.pruning.keras_pruning_experimental(model=dense_model, + target_kpi=kpi, + representative_data_gen=representative_data_gen, + target_platform_capabilities=target_platform_cap, + pruning_config=pruning_config) + + # Count number of params in the pruned model and save it. + pruned_nparams = count_model_params(pruned_model) + print(f"Pruned model has {pruned_nparams} parameters.") + _, keras_file_path = tempfile.mkstemp('.keras') + print(f"Saving pruned model: {keras_file_path}") + keras.models.save_model(pruned_model, keras_file_path) + diff --git a/tutorials/notebooks/example_keras_pruning_mnist.ipynb b/tutorials/notebooks/example_keras_pruning_mnist.ipynb new file mode 100644 index 000000000..9abd82416 --- /dev/null +++ b/tutorials/notebooks/example_keras_pruning_mnist.ipynb @@ -0,0 +1,439 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# Structured Pruning of a Fully-Connected Keras Model\n", + "\n", + "[Run this tutorial in Google Colab](https://colab.research.google.com/github/sony/model_optimization/blob/main/tutorials/notebooks/example_keras_pruning_mnist.ipynb) \n", + "\n", + "Welcome to this tutorial, where we will guide you through training, pruning, and retraining a fully connected Keras model. We'll begin by constructing and training a simple neural network using the Keras framework. Following this, we will introduce and apply model pruning using MCT to reduce the size of our network. Finally, we'll retrain our pruned model to recover its degraded performance due to the pruning process.\n", + "\n", + "\n", + "## Installing TensorFlow and Model Compression Toolkit\n", + "\n", + "We start by setting up our environment by installing TensorFlow and Model Compression Toolkit and importing them." 
+ ], + "metadata": { + "id": "UJDzewEYfSN5" + } + }, + { + "cell_type": "code", + "source": [ + "!pip install model-compression-toolkit \n", + "!pip install tensorflow" + ], + "metadata": { + "id": "xTvVA__4NItc" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": { + "id": "Q2bAksKtM0ca" + }, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "import tensorflow_datasets as tfds\n", + "import model_compression_toolkit as mct" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Loading and Preprocessing MNIST\n", + "\n", + "Let's create a function to retrieve the train and test parts of the MNIST dataset, including preprocessing:" + ], + "metadata": { + "id": "tW1xcK_Kf4F_" + } + }, + { + "cell_type": "code", + "source": [ + "def load_and_preprocess_mnist():\n", + " (ds_train, ds_test), ds_info = tfds.load(\n", + " 'mnist',\n", + " split=['train', 'test'],\n", + " shuffle_files=True,\n", + " as_supervised=True,\n", + " with_info=True,\n", + " )\n", + "\n", + " def normalize_img(image, label):\n", + " return tf.cast(image, tf.float32) / 255., label\n", + "\n", + " ds_train = ds_train.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)\n", + " ds_train = ds_train.cache().shuffle(ds_info.splits['train'].num_examples).batch(128).prefetch(tf.data.AUTOTUNE)\n", + " ds_test = ds_test.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE).batch(128)\n", + "\n", + " return ds_train, ds_test\n" + ], + "metadata": { + "id": "fwtJHnflfv_f" + }, + "execution_count": 28, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Creating a Fully-Connected Model\n", + "\n", + "In this tutorial section, we create a simple toy example of a fully connected model to demonstrate the pruning process using MCT. 
It consists of three dense layers with 128, 64, and 10 neurons.\n", + "\n", + "Notably, MCT's structured pruning will target the first two dense layers for pruning, as these layers offer the opportunity to reduce output channels. This reduction can be effectively propagated by adjusting the input channels of subsequent layers.\n", + "\n", + "Once our model is created, we compile it to prepare the model for training and evaluation.\n" + ], + "metadata": { + "id": "m3vu7-uvgtfC" + } + }, + { + "cell_type": "code", + "source": [ + "def create_model():\n", + " model = tf.keras.models.Sequential([\n", + " tf.keras.layers.Flatten(input_shape=(28, 28)),\n", + " tf.keras.layers.Dense(128, activation='relu'),\n", + " tf.keras.layers.Dense(64, activation='relu'),\n", + " tf.keras.layers.Dense(10)\n", + " ])\n", + " model.compile(\n", + " optimizer=tf.keras.optimizers.Adam(0.001),\n", + " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", + " metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],\n", + " )\n", + " return model" + ], + "metadata": { + "id": "If3oj5jSjXen" + }, + "execution_count": 29, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Training Dense Model on MNIST\n", + "\n", + "Now, we can train our model using the dataset we load and evaluate it." 
+ ], + "metadata": { + "id": "Q_tK6Xknbtha" + } + }, + { + "cell_type": "code", + "source": [ + "# Load MNIST dataset\n", + "ds_train, ds_test = load_and_preprocess_mnist()\n", + "\n", + "# Train and evaluate the model\n", + "model = create_model()\n", + "model.fit(ds_train, epochs=6, validation_data=ds_test)\n", + "model.evaluate(ds_test)" + ], + "metadata": { + "id": "jQ3_9Z1WllVV" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Dense Model Properties\n", + "\n", + "The model.summary() function in Keras provides a snapshot of the model's architecture, including layers, their types, output shapes, and the number of parameters.\n" + ], + "metadata": { + "id": "ZQHxLrsvcLKH" + } + }, + { + "cell_type": "code", + "source": [ + "model.summary()" + ], + "metadata": { + "id": "oxdespw2eeBW" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Let's break down what we see in our model summary:\n", + "\n", + "- First Dense Layer: A fully connected layer with 128 output channels and 784 input channels.\n", + "\n", + "- Second Dense Layer: A fully connected layer with 64 output channels and 128 input channels.\n", + "\n", + "- Third Dense Layer: The final dense layer with 10 neurons (as per the number of MNIST classes) and 64 input channels.\n", + "\n", + "The total parameters amount to 109,386, which requires roughly 427.29 KB." 
+ ], + "metadata": { + "id": "GymibwxQehOL" + } + }, + { + "cell_type": "markdown", + "source": [ + "## MCT Structured Pruning\n", + "\n", + "### Create TPC\n", + "\n", + "Firstly, we'll set up the Target Platform Capabilities (TPC) to specify each layer's SIMD (Single Instruction, Multiple Data) size.\n", + "\n", + "In MCT, SIMD plays a crucial role in channel grouping, affecting the pruning decision process based on channel importance for each SIMD group of channels.\n", + "\n", + "We'll use the simplest structured pruning scenario for this demonstration with SIMD=1." + ], + "metadata": { + "id": "RKatTp55emtF" + } + }, + { + "cell_type": "code", + "source": [ + "simd_size = 1\n", + "\n", + "def get_tpc():\n", + " tp = mct.target_platform\n", + " default_config = tp.OpQuantizationConfig(\n", + " simd_size=simd_size,\n", + " activation_quantization_method=tp.QuantizationMethod.POWER_OF_TWO,\n", + " weights_quantization_method=tp.QuantizationMethod.POWER_OF_TWO,\n", + " activation_n_bits=None,\n", + " weights_n_bits=None,\n", + " weights_per_channel_threshold=None,\n", + " enable_weights_quantization=None,\n", + " enable_activation_quantization=None,\n", + " quantization_preserving=None,\n", + " fixed_scale=None,\n", + " fixed_zero_point=None,\n", + " weights_multiplier_nbits=None)\n", + "\n", + " default_configuration_options = tp.QuantizationConfigOptions([default_config])\n", + " tp_model = tp.TargetPlatformModel(default_configuration_options)\n", + " tpc = tp.TargetPlatformCapabilities(tp_model)\n", + " return tpc\n", + "\n" + ], + "metadata": { + "id": "wqZ71s70jXhH" + }, + "execution_count": 32, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Create a Representative Dataset\n", + "\n", + "We are creating a representative dataset to guide our model pruning process for computing importance score for each channel:" + ], + "metadata": { + "id": "SnKxedEgqdSm" + } + }, + { + "cell_type": "code", + "source": [ + "# Create a representative 
dataset\n", + "ds_train_as_iter = iter(ds_train)\n", + "\n", + "def representative_data_gen() -> list:\n", + " yield [next(ds_train_as_iter)[0]]" + ], + "metadata": { + "id": "SCiXV1s9jswp" + }, + "execution_count": 33, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Create KPI\n", + "\n", + "We're defining a Key Performance Indicator (KPI) to constrain the memory usage of our pruned model.\n", + "\n", + "By setting a target that limits the model's weight memory to half of its original size (around 427KB), we aim to achieve a compression ratio of 50%:" + ], + "metadata": { + "id": "nylQtALnr9gN" + } + }, + { + "cell_type": "code", + "source": [ + "# Create KPI to limit the pruned model weights memory to a certain KPI\n", + "dense_model_memory = 427*(2**10) # Original model weights require ~427KB\n", + "compression_ratio = 0.5\n", + "\n", + "kpi = mct.KPI(weights_memory=dense_model_memory*compression_ratio)" + ], + "metadata": { + "id": "doJgwbSxsCbr" + }, + "execution_count": 34, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Prune Model\n", + "\n", + "We're ready to execute the actual pruning using MCT's keras_pruning_experimental function. The model is pruned according to our defined KPI and using the representative dataset generated earlier.\n", + "\n", + "Each channel's importance is measured using LFH (Label-Free-Hessian)\n", + "which approximates the Hessian of the loss function w.r.t model's weights.\n", + "\n", + "In this example, we've used just one score approximation for efficiency. Although this is less time-consuming, it's worth noting that using multiple approximations would yield more precise importance scores in real-world applications. However, this precision comes with a trade-off in terms of longer processing times.\n", + "\n", + "The result is a pruned model and associated pruning information, which includes details about the pruning masks and scores for each layer." 
+ ], + "metadata": { + "id": "xSP6815rsCnc" + } + }, + { + "cell_type": "code", + "source": [ + "num_score_approximations = 1\n", + "\n", + "target_platform_cap = get_tpc()\n", + "pruned_model, pruning_info = mct.pruning.keras_pruning_experimental(\n", + " model=model,\n", + " target_kpi=kpi,\n", + " representative_data_gen=representative_data_gen,\n", + " target_platform_capabilities=target_platform_cap,\n", + " pruning_config=mct.pruning.PruningConfig(num_score_approximations=num_score_approximations)\n", + " )" + ], + "metadata": { + "id": "x4taG-5TxBrp" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Pruned Model Properties\n", + "\n", + "As before, we can use Keras model's API to observe the new architecture and details of the pruned model:" + ], + "metadata": { + "id": "iPd6ezZN2DNp" + } + }, + { + "cell_type": "code", + "source": [ + "pruned_model.summary()" + ], + "metadata": { + "id": "xZu4gPwz2Ptp" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Retraining Pruned Model\n", + "\n", + "After pruning models, it's common to observe a temporary drop in the model's accuracy. This decline directly results from reducing the model's complexity through pruning." + ], + "metadata": { + "id": "pAheQ9SGxB13" + } + }, + { + "cell_type": "code", + "source": [ + "pruned_model.compile(\n", + " optimizer=tf.keras.optimizers.Adam(0.001),\n", + " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", + " metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],\n", + ")\n", + "pruned_model.evaluate(ds_test)" + ], + "metadata": { + "id": "Vpihq5fpdeSA" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "However, to recover the performance, we retrain the pruned model, allowing it to adapt to its new, compressed architecture. 
The model can regain, and sometimes even surpass, its original accuracy through retraining." + ], + "metadata": { + "id": "IHORL34t17bA" + } + }, + { + "cell_type": "code", + "source": [ + "pruned_model.fit(ds_train, epochs=6, validation_data=ds_test)\n", + "pruned_model.evaluate(ds_test)" + ], + "metadata": { + "id": "q00zV9Jmjszo" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "id": "bb7e1572", + "metadata": { + "id": "bb7e1572" + }, + "source": [ + "Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "you may not use this file except in compliance with the License.\n", + "You may obtain a copy of the License at\n", + "\n", + " http://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "Unless required by applicable law or agreed to in writing, software\n", + "distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "See the License for the specific language governing permissions and\n", + "limitations under the License.\n" + ] + } + + ] +}