diff --git a/model_compression_toolkit/core/graph_prep_runner.py b/model_compression_toolkit/core/graph_prep_runner.py index 79782d024..4d405b727 100644 --- a/model_compression_toolkit/core/graph_prep_runner.py +++ b/model_compression_toolkit/core/graph_prep_runner.py @@ -1,4 +1,4 @@ -# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved. +# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -41,23 +41,24 @@ def graph_preparation_runner(in_model: Any, tb_w: TensorboardWriter = None, mixed_precision_enable: bool = False) -> Graph: """ - Quantize a trained model using post-training quantization. - First, the model graph is optimized using several transformations (e.g. folding BatchNormalization to preceding - layers). - Second, statistics (e.g. min/max, histogram, etc.) are collected for each layer's output - (and input, depends on the quantization configuration) using a given representative dataset. - Next, quantization parameters are calculated using the collected statistics - (both coefficients and activations by default). + Runs all required preparations in order to build a quantization graph from the given model, + quantization configuration and target platform specifications. + This runner include the following steps: + - Reading and building a graph from the given model. + - Setting quantization config to each relevant node in the graph. + - Apply all necessary substitutions to finalize the graph for quantization. + Args: in_model: Model to quantize. representative_data_gen: Dataset used for calibration. - core_config: CoreConfig containing parameters of how the model should be quantized + quantization_config: QuantizationConfig containing parameters of how the model should be quantized. fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices, - groups of layers by how they should be quantized, etc.). + groups of layers by how they should be quantized, etc.). fw_impl: FrameworkImplementation object with a specific framework methods implementation. tpc: TargetPlatformCapabilities object that models the inference target platform and - the attached framework operator's information. + the attached framework operator's information. tb_w: TensorboardWriter object for logging + Returns: An internal graph representation of the input model. """ @@ -92,16 +93,18 @@ def get_finalized_graph(initial_graph: Graph, """ Applies all edit operation (edit, substitutions, etc.) on the model's graph, to prepare it for the quantization process. All future graph substitutions and operations that change the graph should be added to this method. + Args: initial_graph (Graph): Graph to apply the changes to. tpc (TargetPlatformCapabilities): TargetPlatformCapabilities object that describes the desired inference target platform (includes fusing patterns MCT should handle). quant_config (QuantizationConfig): QuantizationConfig containing parameters of how the model should be - quantized. + quantized. fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., - kernel channels indices, groups of layers by how they should be quantized, etc.) + kernel channels indices, groups of layers by how they should be quantized, etc.) tb_w (TensorboardWriter): TensorboardWriter object to use for logging events such as graphs, histograms, etc. fw_impl (FrameworkImplementation): FrameworkImplementation object with a specific framework methods implementation. - mixed_precision_enable: is mixed precision enabled. + mixed_precision_enable: is mixed precision enabled. + Returns: Graph object that represents the model, after applying all required modifications to it. """ @@ -173,6 +176,7 @@ def read_model_to_graph(in_model: Any, """ Read a model into a graph object. + Args: in_model: Model to optimize and prepare for quantization. representative_data_gen: Dataset used for calibration. @@ -181,6 +185,7 @@ def read_model_to_graph(in_model: Any, fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.) fw_impl: FrameworkImplementation object with a specific framework methods implementation. + Returns: Graph object that represents the model. """ diff --git a/model_compression_toolkit/core/quantization_prep_runner.py b/model_compression_toolkit/core/quantization_prep_runner.py new file mode 100644 index 000000000..1f7729227 --- /dev/null +++ b/model_compression_toolkit/core/quantization_prep_runner.py @@ -0,0 +1,134 @@ +# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + + +from typing import Callable + +from tqdm import tqdm + +from model_compression_toolkit.core.common import FrameworkInfo +from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation +from model_compression_toolkit.core.common.graph.base_graph import Graph +from model_compression_toolkit.core.common.model_collector import ModelCollector +from model_compression_toolkit.core.common.network_editors.edit_network import edit_network_graph +from model_compression_toolkit.core.common.quantization.core_config import CoreConfig +from model_compression_toolkit.core.common.quantization.quantization_analyzer import analyzer_graph +from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_computation import \ + calculate_quantization_params +from model_compression_toolkit.core.common.statistics_correction.statistics_correction import \ + statistics_correction_runner +from model_compression_toolkit.core.common.substitutions.apply_substitutions import substitute + +from model_compression_toolkit.core.common.visualization.tensorboard_writer import TensorboardWriter + + +def quantization_preparation_runner(graph: Graph, + representative_data_gen: Callable, + core_config: CoreConfig, + fw_info: FrameworkInfo, + fw_impl: FrameworkImplementation, + tb_w: TensorboardWriter = None) -> Graph: + """ + Prepares a trained model for post-training quantization. + First, the model graph is optimized using several transformations (e.g. folding BatchNormalization to preceding layers). + Second, statistics (e.g. min/max, histogram, etc.) are collected for each layer's output + (and input, depends on the quantization configuration) using a given representative dataset. + Next, quantization parameters are calculated using the collected statistics. + Finally, more transformations (based on the statistics) are applied to increase the model's performance. + + Args: + graph: A graph representation of the model to be quantized. + representative_data_gen: Dataset used for calibration. + core_config: CoreConfig containing parameters of how the model should be quantized + fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices, + groups of layers by how they should be quantized, etc.). + fw_impl: FrameworkImplementation object with a specific framework methods implementation. + tb_w: TensorboardWriter object for logging + + Returns: + Graph object that represents the model, contains thresholds, and ready for quantization. + """ + + ###################################### + # Graph analyzing (attaching statistics collectors) + ###################################### + analyzer_graph(fw_impl.attach_sc_to_node, + graph, + fw_info, + core_config.quantization_config) # Mark points for statistics collection + + if tb_w is not None: + tb_w.add_graph(graph, 'after_analyzer_graph') + + ###################################### + # Statistic collection + ###################################### + mi = ModelCollector(graph, + fw_impl, + fw_info) + + for _data in tqdm(representative_data_gen()): + mi.infer(_data) + + ###################################### + # Edit network according to user + # specific settings + ###################################### + # Notice that not all actions affect at this stage (for example, actions that edit the final configuration as + # there are no final configurations at this stage of the optimization). For this reason we edit the graph + # again at the end of the optimization process. + edit_network_graph(graph, fw_info, core_config.debug_config.network_editor) + + ###################################### + # Calculate quantization params + ###################################### + calculate_quantization_params(graph, + fw_info, + fw_impl=fw_impl) + + if tb_w is not None: + tb_w.add_graph(graph, 'thresholds_selection') + tb_w.add_all_statistics(graph, 'thresholds_selection') + + ###################################### + # Graph substitution (post statistics collection) + ###################################### + transformed_graph = substitute(graph, + fw_impl.get_substitutions_post_statistics_collection(core_config.quantization_config)) + + ###################################### + # Shift Negative Activations + ###################################### + if core_config.quantization_config.shift_negative_activation_correction: + transformed_graph = fw_impl.shift_negative_correction(transformed_graph, + core_config, + fw_info) + if tb_w is not None: + tb_w.add_graph(transformed_graph, 'after_shift_negative_correction') + tb_w.add_all_statistics(transformed_graph, 'after_shift_negative_correction') + + if tb_w is not None: + tb_w.add_graph(transformed_graph, 'post_statistics_collection_substitutions') + tb_w.add_all_statistics(transformed_graph, 'post_statistics_collection_substitutions') + + ###################################### + # Statistics Correction + ###################################### + tg_with_bias = statistics_correction_runner(transformed_graph, core_config, fw_info, fw_impl, tb_w) + + for n in tg_with_bias.nodes: + assert n.final_weights_quantization_cfg is None + + return tg_with_bias \ No newline at end of file diff --git a/model_compression_toolkit/core/runner.py b/model_compression_toolkit/core/runner.py index db707954d..be7305b20 100644 --- a/model_compression_toolkit/core/runner.py +++ b/model_compression_toolkit/core/runner.py @@ -23,6 +23,7 @@ from model_compression_toolkit.core.common import FrameworkInfo from model_compression_toolkit.core.common.hessian.hessian_info_service import HessianInfoService from model_compression_toolkit.core.graph_prep_runner import graph_preparation_runner +from model_compression_toolkit.core.quantization_prep_runner import quantization_preparation_runner from model_compression_toolkit.logger import Logger from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation from model_compression_toolkit.core.common.graph.base_graph import Graph @@ -47,6 +48,7 @@ ActivationFinalBitwidthConfigVisualizer from model_compression_toolkit.core.common.visualization.tensorboard_writer import TensorboardWriter + def core_runner(in_model: Any, representative_data_gen: Callable, core_config: CoreConfig, @@ -94,12 +96,12 @@ def core_runner(in_model: Any, representative_dataset=representative_data_gen, fw_impl=fw_impl) - tg = _prepare_model_for_quantization(graph, - representative_data_gen, - core_config, - fw_info, - tb_w, - fw_impl) + tg = quantization_preparation_runner(graph=graph, + representative_data_gen=representative_data_gen, + core_config=core_config, + fw_info=fw_info, + fw_impl=fw_impl, + tb_w=tb_w) ###################################### # Finalize bit widths @@ -179,131 +181,6 @@ def _init_tensorboard_writer(fw_info: FrameworkInfo) -> TensorboardWriter: return tb_w -def read_model_to_graph(in_model: Any, - representative_data_gen: Callable, - tpc: TargetPlatformCapabilities, - fw_info: FrameworkInfo = None, - fw_impl: FrameworkImplementation = None) -> Graph: - - """ - Read a model into a graph object. - Args: - in_model: Model to optimize and prepare for quantization. - representative_data_gen: Dataset used for calibration. - tpc: TargetPlatformCapabilities object that models the inference target platform and - the attached framework operator's information. - fw_info: Information needed for quantization about the specific framework (e.g., - kernel channels indices, groups of layers by how they should be quantized, etc.) - fw_impl: FrameworkImplementation object with a specific framework methods implementation. - Returns: - Graph object that represents the model. - """ - graph = fw_impl.model_reader(in_model, - representative_data_gen) - graph.set_fw_info(fw_info) - graph.set_tpc(tpc) - return graph - - -def _prepare_model_for_quantization(transformed_graph: Graph, - representative_data_gen: Callable, - core_config: CoreConfig = CoreConfig(), - fw_info: FrameworkInfo = None, - tb_w: TensorboardWriter = None, - fw_impl: FrameworkImplementation = None) -> Graph: - """ - Prepare a trained model for post-training quantization. - First, the model graph is optimized using several transformations (e.g. folding BatchNormalization to preceding layers). - Second, statistics (e.g. min/max, histogram, etc.) are collected for each layer's output - (and input, depends on the quantization configuration) using a given representative dataset. - Next, quantization parameters are calculated using the collected statistics. - Finally, more transformations (based on the statistics) are applied to increase the model's performance. - - Args: - representative_data_gen (Callable): Dataset used for calibration. - core_config (CoreConfig): CoreConfig containing parameters of how the model should be quantized. - fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., - kernel channels indices, groups of layers by how they should be quantized, etc.) - tb_w (TensorboardWriter): TensorboardWriter object to use for logging events such as graphs, histograms, etc. - fw_impl (FrameworkImplementation): FrameworkImplementation object with a specific framework methods implementation. - - Returns: - Graph object that represents the model, contains thresholds, and ready for quantization. - """ - - ###################################### - # Graph analyzing (attaching statistics collectors) - ###################################### - analyzer_graph(fw_impl.attach_sc_to_node, - transformed_graph, - fw_info, - core_config.quantization_config) # Mark points for statistics collection - - if tb_w is not None: - tb_w.add_graph(transformed_graph, 'after_analyzer_graph') - - ###################################### - # Statistic collection - ###################################### - mi = ModelCollector(transformed_graph, - fw_impl, - fw_info) - - for _data in tqdm(representative_data_gen()): - mi.infer(_data) - - ###################################### - # Edit network according to user - # specific settings - ###################################### - # Notice that not all actions affect at this stage (for example, actions that edit the final configuration as - # there are no final configurations at this stage of the optimization). For this reason we edit the graph - # again at the end of the optimization process. - edit_network_graph(transformed_graph, fw_info, core_config.debug_config.network_editor) - - ###################################### - # Calculate quantization params - ###################################### - calculate_quantization_params(transformed_graph, - fw_info, - fw_impl=fw_impl) - - if tb_w is not None: - tb_w.add_graph(transformed_graph, 'thresholds_selection') - tb_w.add_all_statistics(transformed_graph, 'thresholds_selection') - - ###################################### - # Graph substitution (post statistics collection) - ###################################### - transformed_graph = substitute(transformed_graph, - fw_impl.get_substitutions_post_statistics_collection(core_config.quantization_config)) - - ###################################### - # Shift Negative Activations - ###################################### - if core_config.quantization_config.shift_negative_activation_correction: - transformed_graph = fw_impl.shift_negative_correction(transformed_graph, - core_config, - fw_info) - if tb_w is not None: - tb_w.add_graph(transformed_graph, 'after_shift_negative_correction') - tb_w.add_all_statistics(transformed_graph, 'after_shift_negative_correction') - - if tb_w is not None: - tb_w.add_graph(transformed_graph, 'post_statistics_collection_substitutions') - tb_w.add_all_statistics(transformed_graph, 'post_statistics_collection_substitutions') - - ###################################### - # Statistics Correction - ###################################### - tg_with_bias = statistics_correction_runner(transformed_graph, core_config, fw_info, fw_impl, tb_w) - - for n in tg_with_bias.nodes: - assert n.final_weights_quantization_cfg is None - - return tg_with_bias - - def _set_final_kpi(graph: Graph, final_bit_widths_config: List[int], kpi_functions_dict: Dict[KPITarget, Tuple[MpKpiMetric, MpKpiAggregation]], diff --git a/tests/common_tests/helpers/prep_graph_for_func_test.py b/tests/common_tests/helpers/prep_graph_for_func_test.py index e41aea691..7ed3c8f68 100644 --- a/tests/common_tests/helpers/prep_graph_for_func_test.py +++ b/tests/common_tests/helpers/prep_graph_for_func_test.py @@ -22,8 +22,8 @@ from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_computation import \ calculate_quantization_params from model_compression_toolkit.core.graph_prep_runner import graph_preparation_runner -from model_compression_toolkit.core.runner import _init_tensorboard_writer, \ - _prepare_model_for_quantization +from model_compression_toolkit.core.quantization_prep_runner import quantization_preparation_runner +from model_compression_toolkit.core.runner import _init_tensorboard_writer from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import generate_tp_model, \ get_op_quantization_configs @@ -129,12 +129,12 @@ def _representative_data_gen(): tpc=tpc, mixed_precision_enable=core_config.mixed_precision_enable) - tg = _prepare_model_for_quantization(graph, + tg = quantization_preparation_runner(graph, _representative_data_gen, core_config, fw_info, - tb_w, - fw_impl) + fw_impl, + tb_w) ###################################### # Finalize bit widths