Skip to content

Commit

Permalink
Add TPCInfo for quickstart (sony#950)
Browse files Browse the repository at this point in the history
Add TPCInfo with info needed in the quickstart and use it in quickstart.

Co-authored-by: reuvenp <reuvenp@altair-semi.com>
  • Loading branch information
reuvenperetz and reuvenp authored Feb 19, 2024
1 parent 9d2bd57 commit d643a4b
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 8 deletions.
11 changes: 6 additions & 5 deletions tutorials/quick_start/common/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from common.constants import MODEL_NAME, MODEL_LIBRARY, VALIDATION_DATASET_FOLDER

from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
from tutorials.quick_start.common.tpc_info import TPCInfo


class DatasetInfo:
Expand All @@ -44,7 +45,7 @@ class QuantInfo:
Holds information about the quantization process.
"""
def __init__(self, user_info: UserInformation,
tpc_info: dict,
tpc_info: TPCInfo,
quantization_workflow: str,
mp_weights_compression: float = None
):
Expand All @@ -53,7 +54,7 @@ def __init__(self, user_info: UserInformation,
Args:
user_info (UserInformation): Quantization information returned from MCT
tpc_info (dict): The target platform capabilities information which is provided to the MCT.
tpc_info (TPCInfo): The target platform capabilities information which is provided to the MCT.
quantization_workflow (str): String to describe the quantization workflow (PTQ, GPTQ etc.).
mp_weights_compression (float): Weights compression factor for mixed precision KPI
"""
Expand Down Expand Up @@ -110,8 +111,8 @@ def parse_results(params: dict, float_acc: float, quant_acc: float, quant_info:
A dictionary containing the parsed results.
"""
a_bits = quant_info.tpc_info['Target Platform Model']['Default quantization config']['activation_n_bits']
w_bits = quant_info.tpc_info['Target Platform Model']['Default quantization config']['weights_n_bits']
a_bits = quant_info.tpc_info.activation_nbits
w_bits = quant_info.tpc_info.weights_nbits
bit_config = f'W{w_bits}A{a_bits}'
if quant_info.mp_weights_compression:
bit_config = f'{bit_config},MP-x{quant_info.mp_weights_compression}'
Expand All @@ -126,7 +127,7 @@ def parse_results(params: dict, float_acc: float, quant_acc: float, quant_info:
res['Size[MB]'] = round(quant_info.user_info.final_kpi.weights_memory / 1e6, 2)
res['BitsConfig'] = bit_config
res['QuantWorkflow'] = quant_info.quantization_workflow
res['TPC'] = quant_info.tpc_info['Target Platform Capabilities'] + '-' + quant_info.tpc_info['Version']
res['TPC'] = quant_info.tpc_info.tp_model_name + '-' + quant_info.tpc_info.version

return res

Expand Down
53 changes: 53 additions & 0 deletions tutorials/quick_start/common/tpc_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


def get_tpc_info(tpc) -> 'TPCInfo':
    """
    Collect the TPC details that the quick-start needs into a TPCInfo object.

    Args:
        tpc: The target platform capabilities object. It must expose get_default_op_qc(),
            tp_model.name and version.

    Returns:
        TPCInfo: The default activation and weights bit-widths, the TP model's name
        and the TPC's version.
    """
    # Fetch the default operator quantization config once; both bit-widths are read from it.
    default_op_qc = tpc.get_default_op_qc()

    return TPCInfo(
        # Activation bit-width of the default operator quantization config.
        activation_nbits=default_op_qc.activation_n_bits,
        # Weights bit-width of the default weight-attribute config.
        weights_nbits=default_op_qc.default_weight_attr_config.weights_n_bits,
        # Name of the TP model attached to the TPC.
        tp_model_name=tpc.tp_model.name,
        # Version of the TPC.
        version=tpc.version)

class TPCInfo:
    """
    Holds the target platform capabilities (TPC) details used by the quick-start.
    """
    def __init__(self,
                 activation_nbits: int,
                 weights_nbits: int,
                 tp_model_name: str,
                 version: str):
        """
        Args:
            activation_nbits: Number of bits used for activations.
            weights_nbits: Number of bits used for weights.
            tp_model_name: TP model's name.
            version: TPC's version.
        """
        self.activation_nbits = activation_nbits
        self.weights_nbits = weights_nbits
        self.tp_model_name = tp_model_name
        self.version = version

    def __repr__(self) -> str:
        """
        Returns:
            A readable, constructor-like description of the stored fields (useful in logs and debugging).
        """
        return (f"{type(self).__name__}("
                f"activation_nbits={self.activation_nbits!r}, "
                f"weights_nbits={self.weights_nbits!r}, "
                f"tp_model_name={self.tp_model_name!r}, "
                f"version={self.version!r})")


6 changes: 5 additions & 1 deletion tutorials/quick_start/keras_fw/quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from model_compression_toolkit.core import MixedPrecisionQuantizationConfigV2, CoreConfig
from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
from tutorials.quick_start.common.results import QuantInfo
from tutorials.quick_start.common.tpc_info import get_tpc_info


def get_tpc(target_platform_name: str, target_platform_version: str) -> TargetPlatformCapabilities:
Expand Down Expand Up @@ -133,4 +134,7 @@ def quantize(model: tf.keras.Model,
core_config=core_conf,
target_platform_capabilities=tpc)

return quantized_model, QuantInfo(user_info=quantization_info, tpc_info=tpc.get_info(), quantization_workflow=workflow, mp_weights_compression=mp_wcr)
return quantized_model, QuantInfo(user_info=quantization_info,
tpc_info=get_tpc_info(tpc=tpc),
quantization_workflow=workflow,
mp_weights_compression=mp_wcr)
7 changes: 5 additions & 2 deletions tutorials/quick_start/pytorch_fw/quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
from tutorials.quick_start.common.constants import BYTES_TO_FP32, MP_WEIGHTS_COMPRESSION
from tutorials.quick_start.common.results import QuantInfo
from tutorials.quick_start.common.tpc_info import get_tpc_info


def get_tpc(target_platform_name: str, target_platform_version: str) -> TargetPlatformCapabilities:
Expand Down Expand Up @@ -143,5 +144,7 @@ def quantize(model: nn.Module,
save_model_path=onnx_file_path,
repr_dataset=representative_data_gen)


return quantized_model, QuantInfo(user_info=quantization_info, tpc_info=tpc.get_info(), quantization_workflow=workflow, mp_weights_compression=mp_wcr)
return quantized_model, QuantInfo(user_info=quantization_info,
tpc_info=get_tpc_info(tpc=tpc),
quantization_workflow=workflow,
mp_weights_compression=mp_wcr)

0 comments on commit d643a4b

Please sign in to comment.