diff --git a/tutorials/quick_start/common/results.py b/tutorials/quick_start/common/results.py
index 25d576a93..142f58c46 100644
--- a/tutorials/quick_start/common/results.py
+++ b/tutorials/quick_start/common/results.py
@@ -21,6 +21,7 @@
 from common.constants import MODEL_NAME, MODEL_LIBRARY, VALIDATION_DATASET_FOLDER
 
 from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
+from tutorials.quick_start.common.tpc_info import TPCInfo
 
 
 class DatasetInfo:
@@ -44,7 +45,7 @@ class QuantInfo:
     Holds information about the quantization process.
     """
     def __init__(self, user_info: UserInformation,
-                 tpc_info: dict,
+                 tpc_info: TPCInfo,
                  quantization_workflow: str,
                  mp_weights_compression: float = None
                  ):
@@ -53,7 +54,7 @@ def __init__(self, user_info: UserInformation,
 
         Args:
             user_info (UserInformation): Quantization information returned from MCT
-            tpc_info (dict): The target platform capabilities information which is provided to the MCT.
+            tpc_info (TPCInfo): The target platform capabilities information which is provided to the MCT.
             quantization_workflow (str): String to describe the quantization workflow (PTQ, GPTQ etc.).
             mp_weights_compression (float): Weights compression factor for mixed precision KPI
         """
@@ -110,8 +111,8 @@ def parse_results(params: dict, float_acc: float, quant_acc: float, quant_info:
         A dictionary containing the parsed results.
 
     """
-    a_bits = quant_info.tpc_info['Target Platform Model']['Default quantization config']['activation_n_bits']
-    w_bits = quant_info.tpc_info['Target Platform Model']['Default quantization config']['weights_n_bits']
+    a_bits = quant_info.tpc_info.activation_nbits
+    w_bits = quant_info.tpc_info.weights_nbits
     bit_config = f'W{w_bits}A{a_bits}'
     if quant_info.mp_weights_compression:
         bit_config = f'{bit_config},MP-x{quant_info.mp_weights_compression}'
@@ -126,7 +127,7 @@ def parse_results(params: dict, float_acc: float, quant_acc: float, quant_info:
     res['Size[MB]'] = round(quant_info.user_info.final_kpi.weights_memory / 1e6, 2)
     res['BitsConfig'] = bit_config
     res['QuantWorkflow'] = quant_info.quantization_workflow
-    res['TPC'] = quant_info.tpc_info['Target Platform Capabilities'] + '-' + quant_info.tpc_info['Version']
+    res['TPC'] = quant_info.tpc_info.tp_model_name + '-' + quant_info.tpc_info.version
 
     return res
 
diff --git a/tutorials/quick_start/common/tpc_info.py b/tutorials/quick_start/common/tpc_info.py
new file mode 100644
index 000000000..4bdc9fa6f
--- /dev/null
+++ b/tutorials/quick_start/common/tpc_info.py
@@ -0,0 +1,53 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+def get_tpc_info(tpc):
+    # Retrieve the number of bits used for activations from the TPC's default operation quantization config.
+    activation_nbits = tpc.get_default_op_qc().activation_n_bits
+
+    # Retrieve the number of bits used for weights from the TPC's default weight attribute config.
+    weights_nbits = tpc.get_default_op_qc().default_weight_attr_config.weights_n_bits
+
+    # Extract the name of the TP model associated with the TPC.
+    tp_model_name = tpc.tp_model.name
+
+    # Get the version of the TPC.
+    version = tpc.version
+
+    return TPCInfo(activation_nbits=activation_nbits,
+                   weights_nbits=weights_nbits,
+                   tp_model_name=tp_model_name,
+                   version=version)
+
+class TPCInfo:
+    def __init__(self,
+                 activation_nbits: int,
+                 weights_nbits: int,
+                 tp_model_name: str,
+                 version: str):
+        """
+        Args:
+            activation_nbits: Number of bits used for activations.
+            weights_nbits: Number of bits used for weights.
+            tp_model_name: TP model's name.
+            version: TPC's version.
+        """
+        self.activation_nbits = activation_nbits
+        self.weights_nbits = weights_nbits
+        self.tp_model_name = tp_model_name
+        self.version = version
+
+
diff --git a/tutorials/quick_start/keras_fw/quant.py b/tutorials/quick_start/keras_fw/quant.py
index 7195f1300..317cd38cb 100644
--- a/tutorials/quick_start/keras_fw/quant.py
+++ b/tutorials/quick_start/keras_fw/quant.py
@@ -26,6 +26,7 @@
 from model_compression_toolkit.core import MixedPrecisionQuantizationConfigV2, CoreConfig
 from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
 from tutorials.quick_start.common.results import QuantInfo
+from tutorials.quick_start.common.tpc_info import get_tpc_info
 
 
 def get_tpc(target_platform_name: str, target_platform_version: str) -> TargetPlatformCapabilities:
@@ -133,4 +134,7 @@ def quantize(model: tf.keras.Model,
                                                                core_config=core_conf,
                                                                target_platform_capabilities=tpc)
 
-    return quantized_model, QuantInfo(user_info=quantization_info, tpc_info=tpc.get_info(), quantization_workflow=workflow, mp_weights_compression=mp_wcr)
\ No newline at end of file
+    return quantized_model, QuantInfo(user_info=quantization_info,
+                                      tpc_info=get_tpc_info(tpc=tpc),
+                                      quantization_workflow=workflow,
+                                      mp_weights_compression=mp_wcr)
\ No newline at end of file
diff --git a/tutorials/quick_start/pytorch_fw/quant.py b/tutorials/quick_start/pytorch_fw/quant.py
index 1dd66b6e2..6631762c2 100644
--- a/tutorials/quick_start/pytorch_fw/quant.py
+++ b/tutorials/quick_start/pytorch_fw/quant.py
@@ -28,6 +28,7 @@
 from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
 from tutorials.quick_start.common.constants import BYTES_TO_FP32, MP_WEIGHTS_COMPRESSION
 from tutorials.quick_start.common.results import QuantInfo
+from tutorials.quick_start.common.tpc_info import get_tpc_info
 
 
 def get_tpc(target_platform_name: str, target_platform_version: str) -> TargetPlatformCapabilities:
@@ -143,5 +144,7 @@ def quantize(model: nn.Module,
                                            save_model_path=onnx_file_path,
                                            repr_dataset=representative_data_gen)
 
-
-    return quantized_model, QuantInfo(user_info=quantization_info, tpc_info=tpc.get_info(), quantization_workflow=workflow, mp_weights_compression=mp_wcr)
\ No newline at end of file
+    return quantized_model, QuantInfo(user_info=quantization_info,
+                                      tpc_info=get_tpc_info(tpc=tpc),
+                                      quantization_workflow=workflow,
+                                      mp_weights_compression=mp_wcr)
\ No newline at end of file
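For reviewers who want to exercise the new helper in isolation, here is a minimal usage sketch. It is not part of the diff above: it assumes MCT is installed, and that mct.get_target_platform_capabilities('tensorflow', 'default', 'v1') is a valid framework/platform/version combination for obtaining a TargetPlatformCapabilities object (swap in whatever get_tpc() in keras_fw/quant.py or pytorch_fw/quant.py actually returns).

    # Minimal usage sketch; the platform/version strings below are illustrative assumptions.
    import model_compression_toolkit as mct

    from tutorials.quick_start.common.tpc_info import get_tpc_info

    # Obtain a TargetPlatformCapabilities object (stand-in for the tpc built inside quantize()).
    tpc = mct.get_target_platform_capabilities('tensorflow', 'default', 'v1')

    # Summarize it into the new TPCInfo container and format it the way parse_results() does.
    tpc_info = get_tpc_info(tpc=tpc)
    print(f'W{tpc_info.weights_nbits}A{tpc_info.activation_nbits}, '
          f'TPC: {tpc_info.tp_model_name}-{tpc_info.version}')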