Skip to content

Commit

Permalink
Extract quantization preparation steps from core runner (#872)
Browse files Browse the repository at this point in the history
* Extract quantization preparation steps from the main runner to an external "quantization prep runner": graph analyzing, statistic collection, network editor, quantization parameters calculation, snc and stat correction.
In addition, minor documentation fixes to the graph preparation runner and removed duplicated functions.

---------

Co-authored-by: Ofir Gordon <Ofir.Gordon@altair-semi.com>
  • Loading branch information
ofirgo and Ofir Gordon authored Nov 30, 2023
1 parent e76110a commit a505a88
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 150 deletions.
33 changes: 19 additions & 14 deletions model_compression_toolkit/core/graph_prep_runner.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -41,23 +41,24 @@ def graph_preparation_runner(in_model: Any,
tb_w: TensorboardWriter = None,
mixed_precision_enable: bool = False) -> Graph:
"""
Quantize a trained model using post-training quantization.
First, the model graph is optimized using several transformations (e.g. folding BatchNormalization to preceding
layers).
Second, statistics (e.g. min/max, histogram, etc.) are collected for each layer's output
(and input, depends on the quantization configuration) using a given representative dataset.
Next, quantization parameters are calculated using the collected statistics
(both coefficients and activations by default).
Runs all required preparations in order to build a quantization graph from the given model,
quantization configuration and target platform specifications.
This runner include the following steps:
- Reading and building a graph from the given model.
- Setting quantization config to each relevant node in the graph.
- Apply all necessary substitutions to finalize the graph for quantization.
Args:
in_model: Model to quantize.
representative_data_gen: Dataset used for calibration.
core_config: CoreConfig containing parameters of how the model should be quantized
quantization_config: QuantizationConfig containing parameters of how the model should be quantized.
fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
groups of layers by how they should be quantized, etc.).
groups of layers by how they should be quantized, etc.).
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
tpc: TargetPlatformCapabilities object that models the inference target platform and
the attached framework operator's information.
the attached framework operator's information.
tb_w: TensorboardWriter object for logging
Returns:
An internal graph representation of the input model.
"""
Expand Down Expand Up @@ -92,16 +93,18 @@ def get_finalized_graph(initial_graph: Graph,
"""
Applies all edit operation (edit, substitutions, etc.) on the model's graph, to prepare it for the quantization
process. All future graph substitutions and operations that change the graph should be added to this method.
Args:
initial_graph (Graph): Graph to apply the changes to.
tpc (TargetPlatformCapabilities): TargetPlatformCapabilities object that describes the desired inference target platform (includes fusing patterns MCT should handle).
quant_config (QuantizationConfig): QuantizationConfig containing parameters of how the model should be
quantized.
quantized.
fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g.,
kernel channels indices, groups of layers by how they should be quantized, etc.)
kernel channels indices, groups of layers by how they should be quantized, etc.)
tb_w (TensorboardWriter): TensorboardWriter object to use for logging events such as graphs, histograms, etc.
fw_impl (FrameworkImplementation): FrameworkImplementation object with a specific framework methods implementation.
mixed_precision_enable: is mixed precision enabled.
mixed_precision_enable: is mixed precision enabled.
Returns: Graph object that represents the model, after applying all required modifications to it.
"""

Expand Down Expand Up @@ -173,6 +176,7 @@ def read_model_to_graph(in_model: Any,

"""
Read a model into a graph object.
Args:
in_model: Model to optimize and prepare for quantization.
representative_data_gen: Dataset used for calibration.
Expand All @@ -181,6 +185,7 @@ def read_model_to_graph(in_model: Any,
fw_info: Information needed for quantization about the specific framework (e.g.,
kernel channels indices, groups of layers by how they should be quantized, etc.)
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
Returns:
Graph object that represents the model.
"""
Expand Down
134 changes: 134 additions & 0 deletions model_compression_toolkit/core/quantization_prep_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


from typing import Callable

from tqdm import tqdm

from model_compression_toolkit.core.common import FrameworkInfo
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
from model_compression_toolkit.core.common.graph.base_graph import Graph
from model_compression_toolkit.core.common.model_collector import ModelCollector
from model_compression_toolkit.core.common.network_editors.edit_network import edit_network_graph
from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
from model_compression_toolkit.core.common.quantization.quantization_analyzer import analyzer_graph
from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_computation import \
calculate_quantization_params
from model_compression_toolkit.core.common.statistics_correction.statistics_correction import \
statistics_correction_runner
from model_compression_toolkit.core.common.substitutions.apply_substitutions import substitute

from model_compression_toolkit.core.common.visualization.tensorboard_writer import TensorboardWriter


def quantization_preparation_runner(graph: Graph,
representative_data_gen: Callable,
core_config: CoreConfig,
fw_info: FrameworkInfo,
fw_impl: FrameworkImplementation,
tb_w: TensorboardWriter = None) -> Graph:
"""
Prepares a trained model for post-training quantization.
First, the model graph is optimized using several transformations (e.g. folding BatchNormalization to preceding layers).
Second, statistics (e.g. min/max, histogram, etc.) are collected for each layer's output
(and input, depends on the quantization configuration) using a given representative dataset.
Next, quantization parameters are calculated using the collected statistics.
Finally, more transformations (based on the statistics) are applied to increase the model's performance.
Args:
graph: A graph representation of the model to be quantized.
representative_data_gen: Dataset used for calibration.
core_config: CoreConfig containing parameters of how the model should be quantized
fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
groups of layers by how they should be quantized, etc.).
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
tb_w: TensorboardWriter object for logging
Returns:
Graph object that represents the model, contains thresholds, and ready for quantization.
"""

######################################
# Graph analyzing (attaching statistics collectors)
######################################
analyzer_graph(fw_impl.attach_sc_to_node,
graph,
fw_info,
core_config.quantization_config) # Mark points for statistics collection

if tb_w is not None:
tb_w.add_graph(graph, 'after_analyzer_graph')

######################################
# Statistic collection
######################################
mi = ModelCollector(graph,
fw_impl,
fw_info)

for _data in tqdm(representative_data_gen()):
mi.infer(_data)

######################################
# Edit network according to user
# specific settings
######################################
# Notice that not all actions affect at this stage (for example, actions that edit the final configuration as
# there are no final configurations at this stage of the optimization). For this reason we edit the graph
# again at the end of the optimization process.
edit_network_graph(graph, fw_info, core_config.debug_config.network_editor)

######################################
# Calculate quantization params
######################################
calculate_quantization_params(graph,
fw_info,
fw_impl=fw_impl)

if tb_w is not None:
tb_w.add_graph(graph, 'thresholds_selection')
tb_w.add_all_statistics(graph, 'thresholds_selection')

######################################
# Graph substitution (post statistics collection)
######################################
transformed_graph = substitute(graph,
fw_impl.get_substitutions_post_statistics_collection(core_config.quantization_config))

######################################
# Shift Negative Activations
######################################
if core_config.quantization_config.shift_negative_activation_correction:
transformed_graph = fw_impl.shift_negative_correction(transformed_graph,
core_config,
fw_info)
if tb_w is not None:
tb_w.add_graph(transformed_graph, 'after_shift_negative_correction')
tb_w.add_all_statistics(transformed_graph, 'after_shift_negative_correction')

if tb_w is not None:
tb_w.add_graph(transformed_graph, 'post_statistics_collection_substitutions')
tb_w.add_all_statistics(transformed_graph, 'post_statistics_collection_substitutions')

######################################
# Statistics Correction
######################################
tg_with_bias = statistics_correction_runner(transformed_graph, core_config, fw_info, fw_impl, tb_w)

for n in tg_with_bias.nodes:
assert n.final_weights_quantization_cfg is None

return tg_with_bias
139 changes: 8 additions & 131 deletions model_compression_toolkit/core/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from model_compression_toolkit.core.common import FrameworkInfo
from model_compression_toolkit.core.common.hessian.hessian_info_service import HessianInfoService
from model_compression_toolkit.core.graph_prep_runner import graph_preparation_runner
from model_compression_toolkit.core.quantization_prep_runner import quantization_preparation_runner
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
from model_compression_toolkit.core.common.graph.base_graph import Graph
Expand All @@ -47,6 +48,7 @@
ActivationFinalBitwidthConfigVisualizer
from model_compression_toolkit.core.common.visualization.tensorboard_writer import TensorboardWriter


def core_runner(in_model: Any,
representative_data_gen: Callable,
core_config: CoreConfig,
Expand Down Expand Up @@ -94,12 +96,12 @@ def core_runner(in_model: Any,
representative_dataset=representative_data_gen,
fw_impl=fw_impl)

tg = _prepare_model_for_quantization(graph,
representative_data_gen,
core_config,
fw_info,
tb_w,
fw_impl)
tg = quantization_preparation_runner(graph=graph,
representative_data_gen=representative_data_gen,
core_config=core_config,
fw_info=fw_info,
fw_impl=fw_impl,
tb_w=tb_w)

######################################
# Finalize bit widths
Expand Down Expand Up @@ -179,131 +181,6 @@ def _init_tensorboard_writer(fw_info: FrameworkInfo) -> TensorboardWriter:
return tb_w


def read_model_to_graph(in_model: Any,
representative_data_gen: Callable,
tpc: TargetPlatformCapabilities,
fw_info: FrameworkInfo = None,
fw_impl: FrameworkImplementation = None) -> Graph:

"""
Read a model into a graph object.
Args:
in_model: Model to optimize and prepare for quantization.
representative_data_gen: Dataset used for calibration.
tpc: TargetPlatformCapabilities object that models the inference target platform and
the attached framework operator's information.
fw_info: Information needed for quantization about the specific framework (e.g.,
kernel channels indices, groups of layers by how they should be quantized, etc.)
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
Returns:
Graph object that represents the model.
"""
graph = fw_impl.model_reader(in_model,
representative_data_gen)
graph.set_fw_info(fw_info)
graph.set_tpc(tpc)
return graph


def _prepare_model_for_quantization(transformed_graph: Graph,
representative_data_gen: Callable,
core_config: CoreConfig = CoreConfig(),
fw_info: FrameworkInfo = None,
tb_w: TensorboardWriter = None,
fw_impl: FrameworkImplementation = None) -> Graph:
"""
Prepare a trained model for post-training quantization.
First, the model graph is optimized using several transformations (e.g. folding BatchNormalization to preceding layers).
Second, statistics (e.g. min/max, histogram, etc.) are collected for each layer's output
(and input, depends on the quantization configuration) using a given representative dataset.
Next, quantization parameters are calculated using the collected statistics.
Finally, more transformations (based on the statistics) are applied to increase the model's performance.
Args:
representative_data_gen (Callable): Dataset used for calibration.
core_config (CoreConfig): CoreConfig containing parameters of how the model should be quantized.
fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g.,
kernel channels indices, groups of layers by how they should be quantized, etc.)
tb_w (TensorboardWriter): TensorboardWriter object to use for logging events such as graphs, histograms, etc.
fw_impl (FrameworkImplementation): FrameworkImplementation object with a specific framework methods implementation.
Returns:
Graph object that represents the model, contains thresholds, and ready for quantization.
"""

######################################
# Graph analyzing (attaching statistics collectors)
######################################
analyzer_graph(fw_impl.attach_sc_to_node,
transformed_graph,
fw_info,
core_config.quantization_config) # Mark points for statistics collection

if tb_w is not None:
tb_w.add_graph(transformed_graph, 'after_analyzer_graph')

######################################
# Statistic collection
######################################
mi = ModelCollector(transformed_graph,
fw_impl,
fw_info)

for _data in tqdm(representative_data_gen()):
mi.infer(_data)

######################################
# Edit network according to user
# specific settings
######################################
# Notice that not all actions affect at this stage (for example, actions that edit the final configuration as
# there are no final configurations at this stage of the optimization). For this reason we edit the graph
# again at the end of the optimization process.
edit_network_graph(transformed_graph, fw_info, core_config.debug_config.network_editor)

######################################
# Calculate quantization params
######################################
calculate_quantization_params(transformed_graph,
fw_info,
fw_impl=fw_impl)

if tb_w is not None:
tb_w.add_graph(transformed_graph, 'thresholds_selection')
tb_w.add_all_statistics(transformed_graph, 'thresholds_selection')

######################################
# Graph substitution (post statistics collection)
######################################
transformed_graph = substitute(transformed_graph,
fw_impl.get_substitutions_post_statistics_collection(core_config.quantization_config))

######################################
# Shift Negative Activations
######################################
if core_config.quantization_config.shift_negative_activation_correction:
transformed_graph = fw_impl.shift_negative_correction(transformed_graph,
core_config,
fw_info)
if tb_w is not None:
tb_w.add_graph(transformed_graph, 'after_shift_negative_correction')
tb_w.add_all_statistics(transformed_graph, 'after_shift_negative_correction')

if tb_w is not None:
tb_w.add_graph(transformed_graph, 'post_statistics_collection_substitutions')
tb_w.add_all_statistics(transformed_graph, 'post_statistics_collection_substitutions')

######################################
# Statistics Correction
######################################
tg_with_bias = statistics_correction_runner(transformed_graph, core_config, fw_info, fw_impl, tb_w)

for n in tg_with_bias.nodes:
assert n.final_weights_quantization_cfg is None

return tg_with_bias


def _set_final_kpi(graph: Graph,
final_bit_widths_config: List[int],
kpi_functions_dict: Dict[KPITarget, Tuple[MpKpiMetric, MpKpiAggregation]],
Expand Down
Loading

0 comments on commit a505a88

Please sign in to comment.