Skip to content

Commit

Permalink
Add TPC.v4 with quantization preservation. (sony#1214)
Browse files Browse the repository at this point in the history
* Add TPC.v4 with quantization preservation.
Handle quantization preservation in quantization options filtering.
  • Loading branch information
elad-c authored Sep 15, 2024
1 parent 953fa41 commit 4dac6a3
Show file tree
Hide file tree
Showing 10 changed files with 605 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,72 @@ def set_quantization_configuration_to_graph(graph: Graph,
return graph


def filter_node_qco_by_graph(node: BaseNode,
tpc: TargetPlatformCapabilities,
graph: Graph,
node_qc_options: QuantizationConfigOptions
) -> Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]:
"""
Filter quantization config options that don't match the graph.
A node may have several quantization config options with 'activation_n_bits' values, and
the next nodes in the graph may support different bit-width as input activation. This function
filters out quantization config that don't comply to these attributes.
Args:
node: Node for filtering.
tpc: TPC to extract the QuantizationConfigOptions for the next nodes.
graph: Graph object.
node_qc_options: Node's QuantizationConfigOptions.
Returns:
A base config (OpQuantizationConfig) and a config options list (list of OpQuantizationConfig)
that are compatible with next nodes supported input bit-widths.
"""
# Filter quantization config options that don't match the graph.
_base_config = node_qc_options.base_config
_node_qc_options = node_qc_options.quantization_config_list

# Build next_nodes list by appending to the node's next nodes list all nodes that are quantization preserving.
_next_nodes = graph.get_next_nodes(node)
next_nodes = []
while len(_next_nodes):
n = _next_nodes.pop(0)
qco = n.get_qco(tpc)
qp = [qc.quantization_preserving for qc in qco.quantization_config_list]
if not all(qp) and any(qp):
Logger.error(f'Attribute "quantization_preserving" should be the same for all QuantizaionConfigOptions in {n}.')
if qp[0]:
_next_nodes.extend(graph.get_next_nodes(n))
next_nodes.append(n)

if len(next_nodes):
next_nodes_qc_options = [_node.get_qco(tpc) for _node in next_nodes]
next_nodes_supported_input_bitwidth = min([op_cfg.max_input_activation_n_bits
for qc_opts in next_nodes_qc_options
for op_cfg in qc_opts.quantization_config_list])

# Filter node's QC options that match next nodes input bit-width.
_node_qc_options = [_option for _option in _node_qc_options
if _option.activation_n_bits <= next_nodes_supported_input_bitwidth]
if len(_node_qc_options) == 0:
Logger.critical(f"Graph doesn't match TPC bit configurations: {node} -> {next_nodes}.")

# Verify base config match
if any([node_qc_options.base_config.activation_n_bits > qc_opt.base_config.max_input_activation_n_bits
for qc_opt in next_nodes_qc_options]):
# base_config activation bits doesn't match next node supported input bit-width -> replace with
# a qco from quantization_config_list with maximum activation bit-width.
if len(_node_qc_options) > 0:
output_act_bitwidth = {qco.activation_n_bits: i for i, qco in enumerate(_node_qc_options)}
_base_config = _node_qc_options[output_act_bitwidth[max(output_act_bitwidth)]]
Logger.warning(f"Node {node} base quantization config changed to match Graph and TPC configuration.\nCause: {node} -> {next_nodes}.")
else:
Logger.critical(f"Graph doesn't match TPC bit configurations: {node} -> {next_nodes}.") # pragma: no cover

return _base_config, _node_qc_options


def set_quantization_configs_to_node(node: BaseNode,
graph: Graph,
quant_config: QuantizationConfig,
Expand All @@ -99,7 +165,7 @@ def set_quantization_configs_to_node(node: BaseNode,
manual_bit_width_override (Optional[int]): Specifies a custom bit-width to override the node's activation bit-width. Defaults to None.
"""
node_qc_options = node.get_qco(tpc)
base_config, node_qc_options_list = node.filter_node_qco_by_graph(tpc, graph.get_next_nodes(node), node_qc_options)
base_config, node_qc_options_list = filter_node_qco_by_graph(node, tpc, graph, node_qc_options)

# If a manual_bit_width_override is given, filter node_qc_options_list to retain only the options with activation bits equal to manual_bit_width_override,
# and update base_config accordingly.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ def get_tpc_dict_by_fw(fw_name):
get_keras_tpc as get_keras_tpc_v3
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v3_lut.tpc_keras import \
get_keras_tpc as get_keras_tpc_v3_lut
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4.tpc_keras import \
get_keras_tpc as get_keras_tpc_v4

# Keras: TPC versioning
tpc_models_dict = {'v1': get_keras_tpc_v1,
Expand All @@ -51,6 +53,7 @@ def get_tpc_dict_by_fw(fw_name):
'v2_lut': get_keras_tpc_v2_lut,
'v3': get_keras_tpc_v3,
'v3_lut': get_keras_tpc_v3_lut,
'v4': get_keras_tpc_v4,
LATEST: get_keras_tpc_latest}
elif fw_name == PYTORCH:
###############################
Expand All @@ -73,6 +76,8 @@ def get_tpc_dict_by_fw(fw_name):
get_pytorch_tpc as get_pytorch_tpc_v3
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v3_lut.tpc_pytorch import \
get_pytorch_tpc as get_pytorch_tpc_v3_lut
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4.tpc_pytorch import \
get_pytorch_tpc as get_pytorch_tpc_v4

# Pytorch: TPC versioning
tpc_models_dict = {'v1': get_pytorch_tpc_v1,
Expand All @@ -82,6 +87,7 @@ def get_tpc_dict_by_fw(fw_name):
'v2_lut': get_pytorch_tpc_v2_lut,
'v3': get_pytorch_tpc_v3,
'v3_lut': get_pytorch_tpc_v3_lut,
'v4': get_pytorch_tpc_v4,
LATEST: get_pytorch_tpc_latest}
if tpc_models_dict is not None:
return tpc_models_dict
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

__version__ = 'v4'
Loading

0 comments on commit 4dac6a3

Please sign in to comment.