Skip to content

Commit

Permalink
Keras structured SIMD pruning (sony#871)
Browse files Browse the repository at this point in the history
This commit introduces structured and hardware-aware model pruning to MCT. It optimizes models for specific hardware architectures by considering their SIMD capabilities. This approach prunes groups of channels (SIMD groups) to efficiently reduce model size and complexity while aligning with the hardware's SIMD structure. Added functionalities include calculating pruning masks, handling SIMD group-based pruning, and updating model architecture according to pruning sections. Unit tests and example usage in a tutorial notebook are included for demonstration and validation.

---------

Co-authored-by: reuvenp <reuvenp@altair-semi.com>
  • Loading branch information
reuvenperetz and reuvenp authored Dec 28, 2023
1 parent 5f98016 commit 5306a8d
Show file tree
Hide file tree
Showing 53 changed files with 4,365 additions and 9 deletions.
1 change: 1 addition & 0 deletions .github/workflows/run_keras_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ jobs:
# CPU environment (https://github.com/tensorflow/tensorflow/issues/41718).
# For this reason, if we run them in such an environment, we need to run them first non-parallel separately.
run: |
python -m unittest discover tests/keras_tests/pruning_tests -v
python -m unittest discover tests/keras_tests/non_parallel_tests -v
for script in tests/keras_tests/exporter_tests tests/keras_tests/feature_networks_tests tests/keras_tests/graph_tests tests/keras_tests/layer_tests; do python -m unittest discover $script -v & pids+=($!); done; for pid in ${pids[@]}; do wait $pid || exit 1; done
25 changes: 24 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,29 @@ In the following table we present the ImageNet validation results for these mode

For more results, please refer to [quick start](https://github.com/sony/model_optimization/tree/main/tutorials/quick_start).

### Structured Pruning
MCT introduces a structured and hardware-aware model pruning.
This pruning technique is designed to compress models for specific hardware architectures,
taking into account the target platform's Single Instruction, Multiple Data (SIMD) capabilities.
By pruning groups of channels (SIMD groups), our approach not only reduces model size
and complexity, but ensures that better utilization of channels is in line with the SIMD architecture
for a target KPI of weights memory footprint.


<u>_Note: Currently, only Keras models pruning is supported._</u>

#### Results

Results for applying pruning to reduce the parameters of the following models by 50%:

| Model | Dense Model Accuracy | Pruned Model Accuracy |
|-----------------|----------------------|-----------------------|
| ResNet50 [2] | 75.1 | 72.4 |
| DenseNet121 [2] | 75.0 | 71.15 |




## Contributions
MCT aims at keeping a more up-to-date fork and welcomes contributions from anyone.

Expand All @@ -153,7 +176,7 @@ MCT aims at keeping a more up-to-date fork and welcomes contributions from anyon

[1] Habi, H.V., Peretz, R., Cohen, E., Dikstein, L., Dror, O., Diamant, I., Jennings, R.H. and Netzer, A., 2021. [HPTQ: Hardware-Friendly Post Training Quantization. arXiv preprint](https://arxiv.org/abs/2109.09113).

[2] [MobilNet](https://keras.io/api/applications/mobilenet/#mobilenet-function) from Keras applications.
[2] [Keras Applications](https://keras.io/api/applications/)

[3] [TORCHVISION.MODELS](https://pytorch.org/vision/stable/models.html)

Expand Down
1 change: 1 addition & 0 deletions docsrc/source/api/experimental_api_docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ Functions
- :ref:`get_tensorflow_data_generation_config<ug-get_tensorflow_data_generation_config>`: A function to generate a DataGenerationConfig for Tensorflow data generation(experimental).
- :ref:`pytorch_data_generation_experimental<ug-pytorch_data_generation_experimental>`: A function to generate data for a Pytorch model (experimental).
- :ref:`get_pytorch_data_generation_config<ug-get_pytorch_data_generation_config>`: A function to load a DataGenerationConfig for Pytorch data generation (experimental).
- :ref:`keras_pruning_experimental<ug-keras_pruning_experimental>`: A function to apply structured pruning for Keras models (experimental).


Modules
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
:orphan:

.. _ug-keras_pruning_experimental:


================================================
Keras Structured Pruning
================================================

.. autofunction:: model_compression_toolkit.pruning.keras_pruning_experimental

================================================
Pruning Configuration
================================================

.. autofunction:: model_compression_toolkit.pruning.PruningConfig



================================================
Pruning Information
================================================

.. autofunction:: model_compression_toolkit.pruning.PruningInfo

1 change: 1 addition & 0 deletions docsrc/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ Keras:
* :ref:`Mixed-precision post training quantization<ug-keras_post_training_quantization_mixed_precision>`
* :ref:`Init model for Quantization Aware Training<ug-keras_quantization_aware_training_init>` (Experimental)
* :ref:`Finalize model after Quantization Aware Training<ug-keras_quantization_aware_training_finalize>` (Experimental)
* :ref:`Structured Pruning<ug-keras_pruning_experimental>` (Experimental)

Pytorch:

Expand Down
1 change: 1 addition & 0 deletions model_compression_toolkit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from model_compression_toolkit import exporter
from model_compression_toolkit import gptq
from model_compression_toolkit import data_generation
from model_compression_toolkit import pruning
from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model


Expand Down
4 changes: 4 additions & 0 deletions model_compression_toolkit/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
MIN_THRESHOLD = (2 ** -16)
EPS = 1e-8
LUT_VALUES_BITWIDTH = 8
FP32_BYTES_PER_PARAMETER = 4.

# Quantization attributes:
OUTPUT_SCALE = 'output_scale'
Expand Down Expand Up @@ -127,3 +128,6 @@
HESSIAN_OUTPUT_ALPHA = 0.3
HESSIAN_NUM_ITERATIONS = 50
HESSIAN_EPS = 1e-6

# Pruning constants
PRUNING_NUM_SCORE_APPROXIMATIONS = 32
114 changes: 114 additions & 0 deletions model_compression_toolkit/core/common/graph/base_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from model_compression_toolkit.core.common.graph.base_node import BaseNode
from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
from model_compression_toolkit.core.common.collectors.statistics_collector import scale_statistics, shift_statistics
from model_compression_toolkit.core.common.pruning.pruning_section import PruningSection
from model_compression_toolkit.core.common.user_info import UserInformation
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import \
Expand Down Expand Up @@ -726,3 +727,116 @@ def replace_node(self, node_to_replace: BaseNode, new_node: BaseNode):
self.replace_output_node(node_to_replace, new_node)
self.replace_input_node(node_to_replace, new_node)
self.remove_node(node_to_replace)

def get_pruning_sections(self,
fw_impl: Any) -> List[PruningSection]:
"""
Constructs pruning sections for a given computational graph.
Each section is created starting from an entry node and includes intermediate and exit nodes.
Args:
fw_impl (PruningFrameworkImplementation): Implementation of specific framework methods required for pruning.
Returns: List of PruningSection in the graph.
"""
entry_nodes = self.get_pruning_sections_entry_nodes(fw_impl)
return [self._create_pruning_section(entry_node, fw_impl) for entry_node in entry_nodes]

def get_pruning_sections_entry_nodes(self, fw_impl: Any) -> List[BaseNode]:
"""
Identifies entry nodes for pruning sections within the graph.
Traverses the graph in a topological order, checking each node for prunability criteria.
Returns a list of nodes that mark the beginning of a prunable section in the graph.
Args:
fw_impl (PruningFrameworkImplementation): Implementation of specific framework methods required for pruning.
Returns: List of nodes that are entry nodes in the pruning sections of the graph.
"""
prunable_nodes = []
for n in list(topological_sort(self)):
if fw_impl.is_node_entry_node(n) and self._is_node_topology_prunable(n, fw_impl):
prunable_nodes.append(n)
return prunable_nodes

def _is_node_topology_prunable(self, entry_node: BaseNode, fw_impl: Any) -> bool:
"""
Determines if the topology starting from a given entry node is suitable for pruning.
Iteratively examines the graph structure, focusing on node connectivity and pruning criteria.
Returns True if the topology is prunable, False otherwise.
Args:
entry_node (BaseNode): The node to start the topology check from.
fw_impl (PruningFrameworkImplementation): Implementation of specific framework methods required for pruning.
Returns: Whether this node is a start of a pruning section according to the graph topology or not.
"""
next_node = entry_node

# Continue iterating until the conditions for prunability are no longer met
while len(self.out_edges(next_node)) == 1:
next_node = self.out_edges(next_node)[0].sink_node

# If next_node is an exit node and has only one incoming edge, the topology is prunable.
if fw_impl.is_node_exit_node(next_node, entry_node, self.fw_info) and len(self.in_edges(next_node)) == 1:
return True

# If the next node is not an intermediate node or has more than one incoming/outgoing edge,
# stop the check.
if not fw_impl.is_node_intermediate_pruning_section(next_node) or len(self.in_edges(next_node)) != 1 or len(self.out_edges(next_node)) != 1:
return False

# If the loop exits normally, it implies that the topology is not prunable
return False


def _create_pruning_section(self, entry_node: BaseNode, fw_impl: Any) -> PruningSection:
"""
Creates a PruningSection object starting from a given entry node.
Includes logic to find intermediate and exit nodes to complete the section.
Ensures the provided entry node is a valid starting point for pruning.
Args:
entry_node (BaseNode): The entry node to create the section it starts.
fw_impl (PruningFrameworkImplementation): Implementation of specific framework methods required for pruning.
Returns: The pruning section that starts with node entry_node.
"""
if not fw_impl.is_node_entry_node(entry_node):
Logger.error(f"Expected to find an entry node to create its pruning section,"
f"but node {entry_node} is not an entry node.")

intermediate_nodes, exit_node = self._find_intermediate_and_exit_nodes(entry_node, fw_impl)

if not fw_impl.is_node_exit_node(exit_node, entry_node, self.fw_info):
Logger.error(f"Expected to find exit node when creating a pruning section,"
f"but node {exit_node} is not an exit node.")

return PruningSection(entry_node=entry_node,
intermediate_nodes=intermediate_nodes,
exit_node=exit_node)

def _find_intermediate_and_exit_nodes(self, entry_node: BaseNode, fw_impl: Any) -> Tuple[List[BaseNode], BaseNode]:
"""
Identifies intermediate and exit nodes for a pruning section starting from an entry node.
Iterates through connected nodes to build the complete structure of the pruning section.
Args:
entry_node (BaseNode): An entry node to find the intermediate and exit nodes of its section.
fw_impl (PruningFrameworkImplementation): Implementation of specific framework methods required for pruning.
Returns: A tuple containing a list of intermediate nodes and the exit node.
"""
intermediate_nodes = []
next_node = self.out_edges(entry_node)[0].sink_node
while not fw_impl.is_node_exit_node(next_node, entry_node, self.fw_info):
intermediate_nodes.append(next_node)
next_node = self.out_edges(next_node)[0].sink_node

return intermediate_nodes, next_node


30 changes: 25 additions & 5 deletions model_compression_toolkit/core/common/graph/base_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import numpy as np

from model_compression_toolkit.constants import WEIGHTS_NBITS_ATTRIBUTE, CORRECTED_BIAS_ATTRIBUTE, \
ACTIVATION_NBITS_ATTRIBUTE
ACTIVATION_NBITS_ATTRIBUTE, FP32_BYTES_PER_PARAMETER
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationConfigOptions, \
TargetPlatformCapabilities, LayerFilterParams
Expand Down Expand Up @@ -222,9 +222,9 @@ def get_memory_bytes(self, fw_info) -> float:
"""
q_params, f_params = self.get_num_parameters(fw_info)
if self.final_weights_quantization_cfg is None: # float coefficients
memory = (f_params+q_params) * 4
memory = (f_params+q_params) * FP32_BYTES_PER_PARAMETER
else:
memory = (f_params*4)+ (q_params * self.final_weights_quantization_cfg.weights_n_bits / 8) # in bytes
memory = (f_params * FP32_BYTES_PER_PARAMETER) + (q_params * self.final_weights_quantization_cfg.weights_n_bits / 8) # in bytes

return memory

Expand All @@ -239,7 +239,7 @@ def get_float_memory_bytes(self, fw_info) -> float:
"""
q_params, f_params = self.get_num_parameters(fw_info)
return (f_params + q_params) * 32 / 8 # in bytes
return (f_params + q_params) * FP32_BYTES_PER_PARAMETER

def get_unified_weights_candidates_dict(self):
"""
Expand Down Expand Up @@ -499,4 +499,24 @@ def is_match_filter_params(self, layer_filter_params: LayerFilterParams) -> bool
if not c.match(layer_config):
return False

return True
return True

def get_simd(self) -> int:
"""
Retrieves the SIMD size used for this node. It collects the SIMD sizes from all candidate
configurations and returns the minimum SIMD size.
Returns:
int: The node's SIMD size.
"""
simd_list = [qc.weights_quantization_cfg.simd_size for qc in self.candidates_quantization_cfg]
if len(simd_list) > 1:
Logger.warning(f"More than one pruning SIMD option is available."
f" Min SIMD is used: {min(simd_list)}")
if len(simd_list) == 0:
Logger.error(f"No SIMD option is available for {self}")
_simd = min(simd_list)
if _simd <= 0 or int(_simd) != _simd:
Logger.error(f"SIMD is expected to be a non-positive integer but found: {_simd}")
return _simd
16 changes: 16 additions & 0 deletions model_compression_toolkit/core/common/pruning/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


Loading

0 comments on commit 5306a8d

Please sign in to comment.