diff --git a/TrainingExtensions/onnx/src/python/aimet_onnx/adaround/activation_sampler.py b/TrainingExtensions/onnx/src/python/aimet_onnx/adaround/activation_sampler.py index e2b68c019b5..0d4084baea4 100644 --- a/TrainingExtensions/onnx/src/python/aimet_onnx/adaround/activation_sampler.py +++ b/TrainingExtensions/onnx/src/python/aimet_onnx/adaround/activation_sampler.py @@ -40,13 +40,20 @@ from typing import Tuple, List, Dict, Union import numpy as np -from onnx import onnx_pb import onnxruntime as ort +import onnx from aimet_common.utils import AimetLogger from aimet_onnx.quantsim import QuantizationSimModel from aimet_onnx.utils import add_hook_to_get_activation, remove_activation_hooks, create_input_dict +from packaging import version +# pylint: disable=no-name-in-module, ungrouped-imports +if version.parse(onnx.__version__) >= version.parse("1.14.0"): + from onnx import ModelProto +else: + from onnx.onnx_pb import ModelProto + logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Quant) @@ -56,7 +63,7 @@ class ActivationSampler: collect the module's output and input activation data respectively """ def __init__(self, orig_op: str, quant_op: str, - orig_model: onnx_pb.ModelProto, quant_model: QuantizationSimModel, use_cuda: bool, + orig_model: ModelProto, quant_model: QuantizationSimModel, use_cuda: bool, device: int = 0): """ :param orig_op: Single un quantized op from the original session @@ -132,7 +139,7 @@ class ModuleData: Collect input and output data to and from module """ - def __init__(self, model: onnx_pb.ModelProto, node_name: str, providers: List): + def __init__(self, model: ModelProto, node_name: str, providers: List): """ :param session: ONNX session :param node: Module reference diff --git a/TrainingExtensions/onnx/src/python/aimet_onnx/adaround/adaround_optimizer.py b/TrainingExtensions/onnx/src/python/aimet_onnx/adaround/adaround_optimizer.py index c41126ce3fc..0d1edcbb455 100644 --- a/TrainingExtensions/onnx/src/python/aimet_onnx/adaround/adaround_optimizer.py +++ b/TrainingExtensions/onnx/src/python/aimet_onnx/adaround/adaround_optimizer.py @@ -40,7 +40,7 @@ from typing import Union, Tuple, Dict import numpy as np import onnx -from onnx import onnx_pb, numpy_helper +from onnx import numpy_helper import torch import torch.nn.functional as functional from torch.utils.data import Dataset @@ -56,6 +56,13 @@ from aimet_torch.adaround.adaround_tensor_quantizer import AdaroundTensorQuantizer from aimet_torch.adaround.adaround_optimizer import AdaroundOptimizer as TorchAdaroundOptimizer +from packaging import version +# pylint: disable=no-name-in-module, ungrouped-imports +if version.parse(onnx.__version__) >= version.parse("1.14.0"): + from onnx import ModelProto +else: + from onnx.onnx_pb import ModelProto + logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Quant) BATCH_SIZE = 32 EMPIRICAL_THRESHOLD = 3 / 4 @@ -70,7 +77,7 @@ class AdaroundOptimizer: """ @classmethod def adaround_module(cls, module: ModuleInfo, quantized_input_name: str, - orig_model: onnx_pb.ModelProto, quant_model: QuantizationSimModel, + orig_model: ModelProto, quant_model: QuantizationSimModel, act_func: Union[torch.nn.Module, None], cached_dataset: Dataset, opt_params: AdaroundHyperParameters, param_to_adaround_tensor_quantizer: Dict, use_cuda: bool, device: int = 0): @@ -100,7 +107,7 @@ def adaround_module(cls, module: ModuleInfo, quantized_input_name: str, @classmethod def _optimize_rounding(cls, module: ModuleInfo, quantized_input_name, - orig_model: onnx_pb.ModelProto, quant_model: QuantizationSimModel, + orig_model: ModelProto, quant_model: QuantizationSimModel, act_func: Union[None, str], cached_dataset: Dataset, opt_params: AdaroundHyperParameters, param_to_adaround_tensor_quantizer: Dict, use_cuda: bool, device: int = 0): diff --git a/TrainingExtensions/onnx/src/python/aimet_onnx/adaround/utils.py b/TrainingExtensions/onnx/src/python/aimet_onnx/adaround/utils.py index 0cf1bb9fc35..e4365ad887f 100644 --- a/TrainingExtensions/onnx/src/python/aimet_onnx/adaround/utils.py +++ b/TrainingExtensions/onnx/src/python/aimet_onnx/adaround/utils.py @@ -37,10 +37,17 @@ """ Utilities for Adaround ONNX """ from typing import Dict from collections import defaultdict -from onnx import onnx_pb +import onnx from aimet_onnx.meta.connectedgraph import ConnectedGraph +from packaging import version +# pylint: disable=no-name-in-module, ungrouped-imports +if version.parse(onnx.__version__) >= version.parse("1.14.0"): + from onnx import ModelProto +else: + from onnx.onnx_pb import ModelProto + class ModuleInfo: """ Class object containing information about a module """ def __init__(self): @@ -55,7 +62,7 @@ class ModelData: """ Class to collect data for each module of a class """ - def __init__(self, model: onnx_pb.ModelProto): + def __init__(self, model: ModelProto): """ :param model: ONNX Model """ diff --git a/TrainingExtensions/onnx/src/python/aimet_onnx/batch_norm_fold.py b/TrainingExtensions/onnx/src/python/aimet_onnx/batch_norm_fold.py index 7410cb9fd61..7aac0b24339 100644 --- a/TrainingExtensions/onnx/src/python/aimet_onnx/batch_norm_fold.py +++ b/TrainingExtensions/onnx/src/python/aimet_onnx/batch_norm_fold.py @@ -38,8 +38,9 @@ from typing import Dict, List, Tuple import contextlib -from onnx import onnx_pb, numpy_helper import numpy as np +import onnx +from onnx import numpy_helper from aimet_common.bias_correction import ConvBnPatternHandler from aimet_common.graph_pattern_matcher import PatternType @@ -53,6 +54,13 @@ from aimet_onnx.meta.operations import Op from aimet_onnx.utils import get_node_attribute, remove_node, transpose_tensor, ParamUtils, retrieve_constant_input +from packaging import version +# pylint: disable=no-name-in-module, ungrouped-imports +if version.parse(onnx.__version__) >= version.parse("1.14.0"): + from onnx import NodeProto, TensorProto, ModelProto +else: + from onnx.onnx_pb import NodeProto, TensorProto, ModelProto + logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.BatchNormFolding) ConvType = ['Conv', 'ConvTranspose'] @@ -103,8 +111,8 @@ def _find_conv_bn_pairs(connected_graph: ConnectedGraph) -> Dict: def find_all_batch_norms_to_fold(connected_graph: ConnectedGraph, - ) -> Tuple[List[Tuple[onnx_pb.NodeProto, onnx_pb.NodeProto]], - List[Tuple[onnx_pb.NodeProto, onnx_pb.NodeProto]]]: + ) -> Tuple[List[Tuple[NodeProto, NodeProto]], + List[Tuple[NodeProto, NodeProto]]]: """ Find all possible batch norm layers that can be folded. Returns a list of pairs such that (bn, layer) means bn will be forward-folded into layer and (layer, bn) means bn will be backward-folded into layer @@ -164,7 +172,7 @@ def get_ordered_conv_linears(conn_graph: ConnectedGraph) -> List[Op]: return ordered_convs -def is_valid_bn_fold(conv_linear: onnx_pb.NodeProto, model: onnx_pb.ModelProto, fold_backward: bool) -> bool: +def is_valid_bn_fold(conv_linear: NodeProto, model: ModelProto, fold_backward: bool) -> bool: """ Determine if a given layer can successfully absorb a BatchNorm given the layer type and parameters :param conv_linear: The Conv/Linear layer to fold a BatchNorm into. @@ -193,7 +201,7 @@ def is_valid_bn_fold(conv_linear: onnx_pb.NodeProto, model: onnx_pb.ModelProto, return valid -def fold_all_batch_norms_to_weight(model: onnx_pb.ModelProto) -> [List]: +def fold_all_batch_norms_to_weight(model: ModelProto) -> [List]: """ Fold all possible batch_norm layers in a model into the weight of the corresponding conv layers @@ -218,9 +226,9 @@ def fold_all_batch_norms_to_weight(model: onnx_pb.ModelProto) -> [List]: return conv_bns, bn_convs -def _fold_to_weight(model: onnx_pb.ModelProto, - conv_linear: onnx_pb.NodeProto, - bn: onnx_pb.NodeProto, +def _fold_to_weight(model: ModelProto, + conv_linear: NodeProto, + bn: NodeProto, fold_backward: bool): """ Fold BatchNorm into the weight and bias of the given layer. @@ -273,7 +281,7 @@ def _fold_to_weight(model: onnx_pb.ModelProto, return bn_layer -def _matmul_to_gemm(node: onnx_pb.NodeProto, model: onnx_pb.ModelProto): +def _matmul_to_gemm(node: NodeProto, model: ModelProto): """ Convert MatMul node to Gemm and initialize bias to zeros @@ -298,8 +306,8 @@ def _matmul_to_gemm(node: onnx_pb.NodeProto, model: onnx_pb.ModelProto): node.input.append(bias_name) -def _call_mo_batch_norm_fold(weight: onnx_pb.TensorProto, - bias: onnx_pb.TensorProto, +def _call_mo_batch_norm_fold(weight: TensorProto, + bias: TensorProto, bn_params: libpymo.BNParams, fold_backward: bool): """ @@ -328,7 +336,7 @@ def _call_mo_batch_norm_fold(weight: onnx_pb.TensorProto, weight.raw_data = np.asarray(weight_tensor.data, dtype=np.float32).tobytes() -def get_bn_params(model: onnx_pb.ModelProto, bn: onnx_pb.NodeProto, channels: int) -> libpymo.BNParams: +def get_bn_params(model: ModelProto, bn: NodeProto, channels: int) -> libpymo.BNParams: """ Returns the populated libpymo.BNParams object for the given BatchNormalization layer with parameters repeated if necessary. @@ -355,7 +363,7 @@ def get_bn_params(model: onnx_pb.ModelProto, bn: onnx_pb.NodeProto, channels: in return bn_params -def copy_bn_params_to_bn_layer(bn: onnx_pb.NodeProto, bn_params: libpymo.BNParams) -> BNLayer: +def copy_bn_params_to_bn_layer(bn: NodeProto, bn_params: libpymo.BNParams) -> BNLayer: """ Copies bn params to a BN layer which can be used later by High bias absorption :param bn: BN layer @@ -390,7 +398,7 @@ def _expand_shape_to_4d(weight_tensor: libpymo.TensorParams): weight_tensor.shape = orig_shape -def get_input_output_channels(node: onnx_pb.NodeProto, model: onnx_pb.ModelProto) -> Tuple[int, int]: +def get_input_output_channels(node: NodeProto, model: ModelProto) -> Tuple[int, int]: """ Find the input and output channels of a given layer. :param node: The node to find the input/output channels of diff --git a/TrainingExtensions/onnx/src/python/aimet_onnx/cross_layer_equalization.py b/TrainingExtensions/onnx/src/python/aimet_onnx/cross_layer_equalization.py index 6074d11de1e..eecf7c1b7b3 100644 --- a/TrainingExtensions/onnx/src/python/aimet_onnx/cross_layer_equalization.py +++ b/TrainingExtensions/onnx/src/python/aimet_onnx/cross_layer_equalization.py @@ -44,7 +44,8 @@ from typing import Tuple, List, Union import numpy as np -from onnx import onnx_pb, numpy_helper +import onnx +from onnx import numpy_helper from onnxruntime.quantization.onnx_quantizer import ONNXModel from aimet_common.utils import AimetLogger @@ -58,6 +59,13 @@ from aimet_onnx.utils import transpose_tensor, ParamUtils, get_node_attribute, replace_relu6_with_relu from aimet_onnx.batch_norm_fold import BNLayer, fold_all_batch_norms_to_weight +from packaging import version +# pylint: disable=no-name-in-module, ungrouped-imports +if version.parse(onnx.__version__) >= version.parse("1.14.0"): + from onnx import NodeProto, ModelProto +else: + from onnx.onnx_pb import NodeProto, ModelProto + logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Quant) ClsSet = Union[Tuple['Conv', 'Conv'], @@ -82,7 +90,7 @@ class CrossLayerScaling(CLS): """ Scales a model's layers to equalize the weights between consecutive layers """ - def __init__(self, model: onnx_pb.ModelProto): + def __init__(self, model: ModelProto): """ :param model: ONNX model """ @@ -121,7 +129,7 @@ def scale_model(self) -> List[ClsSetInfo]: return cls_set_info_list - def _populate_libpymo_params(self, module: onnx_pb.NodeProto, + def _populate_libpymo_params(self, module: NodeProto, layer_param: libpymo.EqualizationParams): """ Populates libpymo weight parameter @@ -158,7 +166,7 @@ def _pack_params_for_conv(self, prev_layer_params.isBiasNone = True def _update_weight_for_layer_from_libpymo_obj(self, layer_param: libpymo.EqualizationParams, - module: onnx_pb.NodeProto): + module: NodeProto): """ Update weight parameter from libpymo object """ @@ -257,7 +265,7 @@ class HighBiasFold(HBF): """ Code to apply the high-bias-fold technique to a model """ - def __init__(self, model: onnx_pb.ModelProto): + def __init__(self, model: ModelProto): self._model = model def _check_if_bias_is_none(self, layer: Op) -> bool: @@ -307,7 +315,7 @@ def _pack_previous_and_current_layer_params(self, cls_pair_info: ClsSetInfo.ClsS curr_layer_params.weightShape = get_weight_dimensions(np.array(weight.dims)) def _update_bias_for_layer_from_libpymo_obj(self, layer_param: libpymo.LayerParams, - module: onnx_pb.NodeProto): + module: NodeProto): """ Update bias parameter from libpymo object """ @@ -340,7 +348,7 @@ def get_weight_dimensions(weight_shape: np.array) -> np.array: return np.append(weight_shape, [1 for _ in range(4 - dims)]).astype(int) -def equalize_model(model: onnx_pb.ModelProto): +def equalize_model(model: ModelProto): """ High-level API to perform Cross-Layer Equalization (CLE) on the given model. The model is equalized in place. diff --git a/TrainingExtensions/onnx/src/python/aimet_onnx/layer_output_utils.py b/TrainingExtensions/onnx/src/python/aimet_onnx/layer_output_utils.py index 8b1b6db6bfa..ed9a75087bc 100644 --- a/TrainingExtensions/onnx/src/python/aimet_onnx/layer_output_utils.py +++ b/TrainingExtensions/onnx/src/python/aimet_onnx/layer_output_utils.py @@ -41,7 +41,7 @@ from typing import List, Dict, Tuple, Union import numpy as np import onnxruntime as ort -from onnx import onnx_pb +import onnx from aimet_common.utils import AimetLogger from aimet_common.layer_output_utils import SaveInputOutput, save_layer_output_names @@ -49,13 +49,20 @@ from aimet_onnx.quantsim import QuantizationSimModel from aimet_onnx.utils import create_input_dict, add_hook_to_get_activation +from packaging import version +# pylint: disable=no-name-in-module, ungrouped-imports +if version.parse(onnx.__version__) >= version.parse("1.14.0"): + from onnx import ModelProto +else: + from onnx.onnx_pb import ModelProto + logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.LayerOutputs) class LayerOutputUtil: """ Implementation to capture and save outputs of intermediate layers of a model (fp32/quantsim) """ - def __init__(self, model: onnx_pb.ModelProto, dir_path: str, device: int = 0): + def __init__(self, model: ModelProto, dir_path: str, device: int = 0): """ Constructor - It initializes the utility classes that captures and saves layer-outputs @@ -97,7 +104,7 @@ class LayerOutput: """ This class creates a layer-output name to layer-output dictionary. """ - def __init__(self, model: onnx_pb.ModelProto, providers: List, dir_path: str): + def __init__(self, model: ModelProto, providers: List, dir_path: str): """ Constructor - It initializes few lists that are required for capturing and naming layer-outputs. @@ -131,7 +138,7 @@ def get_outputs(self, input_dict: Dict) -> Dict[str, np.ndarray]: return dict(zip(self.sanitized_activation_names, activation_values)) @staticmethod - def get_activation_names(model: onnx_pb.ModelProto) -> List[str]: + def get_activation_names(model: ModelProto) -> List[str]: """ This function fetches the activation names (layer-output names) of the given onnx model. @@ -153,7 +160,7 @@ def get_activation_names(model: onnx_pb.ModelProto) -> List[str]: return activation_names @staticmethod - def register_activations(model: onnx_pb.ModelProto, activation_names: List): + def register_activations(model: ModelProto, activation_names: List): """ This function adds the intermediate activations into the model's ValueInfoProto so that they can be fetched via running the session. diff --git a/TrainingExtensions/onnx/src/python/aimet_onnx/meta/connectedgraph.py b/TrainingExtensions/onnx/src/python/aimet_onnx/meta/connectedgraph.py index 8eb5e17d4cb..f8b17ad7f30 100644 --- a/TrainingExtensions/onnx/src/python/aimet_onnx/meta/connectedgraph.py +++ b/TrainingExtensions/onnx/src/python/aimet_onnx/meta/connectedgraph.py @@ -44,10 +44,9 @@ the tensors that are either input to the model (input, constant or parameter) or the result of an operation. Furthermore the graph representation is bi-directional.""" - from typing import List, Union, Dict -from onnx import onnx_pb from onnxruntime.quantization.onnx_quantizer import ONNXModel +import onnx from aimet_common.connected_graph.connectedgraph import ConnectedGraph as AimetCommonConnectedGraph, get_ordered_ops from aimet_common.utils import AimetLogger @@ -56,6 +55,13 @@ from aimet_onnx.meta.product import Product from aimet_onnx.utils import ParamUtils, retrieve_constant_input +from packaging import version +# pylint: disable=no-name-in-module, ungrouped-imports +if version.parse(onnx.__version__) >= version.parse("1.14.0"): + from onnx import ModelProto, NodeProto, TensorProto +else: + from onnx.onnx_pb import ModelProto, NodeProto, TensorProto + logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.ConnectedGraph) WEIGHT_INDEX = 1 @@ -72,7 +78,7 @@ class ConnectedGraph(AimetCommonConnectedGraph): either module or functional) as producers and consumers of tensors. Note that the graph has two kinds of nodes: operations and products.""" - def __init__(self, model: onnx_pb.ModelProto): + def __init__(self, model: ModelProto): """ :param: model: ONNX model to create connected graph from """ @@ -172,7 +178,7 @@ def check_if_node_has_predecessor(node): return input_ops @staticmethod - def _create_ir_op(node: onnx_pb.NodeProto) -> Op: + def _create_ir_op(node: NodeProto) -> Op: """ Creates connected graphs internal representation Op :param node: ONNX proto node for which Op needs to be created @@ -192,7 +198,7 @@ def _create_ir_op(node: onnx_pb.NodeProto) -> Op: return op - def _add_children_ops_to_op_queue(self, node: onnx_pb.NodeProto, op_queue: List): + def _add_children_ops_to_op_queue(self, node: NodeProto, op_queue: List): """ Utility function for adding all children of op to self._op_queue :param node: node whose children will be added to op_queue @@ -226,7 +232,7 @@ def _process_starting_ops(self, op_queue: List): self._create_and_link_product_for_inputs(node_name, input_tensor_name) @staticmethod - def check_if_param(node: onnx_pb.NodeProto, index: int) -> bool: + def check_if_param(node: NodeProto, index: int) -> bool: """ Checks if given tensor is a param @@ -278,7 +284,7 @@ def _create_and_link_product_for_inputs(self, consumer_node_name: str, input_ten consumer_op.add_input(product) product.add_consumer(consumer_op) - def _create_op_if_not_exists(self, node: onnx_pb.NodeProto): + def _create_op_if_not_exists(self, node: NodeProto): """ Creates a CG op for a node""" if node.name not in self._ops: op = self._create_ir_op(node) @@ -287,7 +293,7 @@ def _create_op_if_not_exists(self, node: onnx_pb.NodeProto): else: logger.debug("Op %s already exists", node.name) - def _create_and_link_product_if_not_exists(self, child_node: onnx_pb.NodeProto, parent_node: onnx_pb.NodeProto, + def _create_and_link_product_if_not_exists(self, child_node: NodeProto, parent_node: NodeProto, connecting_tensor_name: str): """ Create and link new product if it does not yet exist """ parent_module_name = parent_node.name @@ -508,7 +514,7 @@ def _create_param_products(self): """ Create products for parameters of select modules """ def create_and_connect_product(param_name: str, product_shape: List, my_op: Op, - param_tensor: onnx_pb.TensorProto, product_type: Union[str, None]): + param_tensor: TensorProto, product_type: Union[str, None]): """ Create product with given name, shape, and corresponding tensor. Connect product to my_op. """ product = Product(param_name, product_shape) @@ -577,7 +583,7 @@ def handle_default(my_op: Op): handler(op) -def get_op_attributes(node: onnx_pb.NodeProto, attribute_name: str): +def get_op_attributes(node: NodeProto, attribute_name: str): """ Gets attribute information for layer diff --git a/TrainingExtensions/onnx/src/python/aimet_onnx/meta/utils.py b/TrainingExtensions/onnx/src/python/aimet_onnx/meta/utils.py index e39bdf4b9db..4b3d2cd35b9 100644 --- a/TrainingExtensions/onnx/src/python/aimet_onnx/meta/utils.py +++ b/TrainingExtensions/onnx/src/python/aimet_onnx/meta/utils.py @@ -36,10 +36,17 @@ # ============================================================================= """ Utilities for ONNX Connected Graph """ from typing import Dict, List -from onnx import onnx_pb +import onnx from aimet_onnx.meta.connectedgraph import ConnectedGraph +from packaging import version +# pylint: disable=no-name-in-module, ungrouped-imports +if version.parse(onnx.__version__) >= version.parse("1.14.0"): + from onnx import ModelProto +else: + from onnx.onnx_pb import ModelProto + ActivationTypes = ['Relu', 'Clip', 'Sigmoid', 'Tanh', 'PRelu', 'Softmax'] @@ -75,7 +82,7 @@ def get_param_shape_using_connected_graph(connected_graph: ConnectedGraph, param return param.shape return None -def get_module_act_func_pair(model: onnx_pb.ModelProto) -> Dict[str, str]: +def get_module_act_func_pair(model: ModelProto) -> Dict[str, str]: """ For given model, returns dictionary of module to immediate following activation function else maps module to None. @@ -108,7 +115,7 @@ def get_module_act_func_pair(model: onnx_pb.ModelProto) -> Dict[str, str]: return module_act_func_pair -def get_ordered_ops(model: onnx_pb.ModelProto) -> List: +def get_ordered_ops(model: ModelProto) -> List: """ Gets list of ordered ops diff --git a/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py b/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py index 5f35b028c58..9cd1dc65222 100644 --- a/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py +++ b/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py @@ -40,7 +40,9 @@ from typing import Dict, List, Union import json import numpy as np -from onnx import helper, onnx_pb +import onnx + +from onnx import helper import onnxruntime as ort from onnxruntime import SessionOptions, GraphOptimizationLevel, InferenceSession from onnxruntime.quantization.onnx_quantizer import ONNXModel @@ -58,6 +60,13 @@ from aimet_onnx.quantsim_config.quantsim_config import QuantSimConfigurator from aimet_onnx.utils import make_dummy_input, add_hook_to_get_activation, remove_activation_hooks +from packaging import version +# pylint: disable=no-name-in-module, ungrouped-imports +if version.parse(onnx.__version__) >= version.parse("1.14.0"): + from onnx import ModelProto +else: + from onnx.onnx_pb import ModelProto + WORKING_DIR = '/tmp/quantsim/' op_types_to_ignore = ["branch", "Flatten", "Gather", "Reshape", "Shape", "Unsqueeze", "Squeeze", "Split", @@ -74,7 +83,7 @@ class QuantizationSimModel: # pylint: disable=too-many-arguments # pylint: disable=too-many-instance-attributes def __init__(self, - model: onnx_pb.ModelProto, + model: ModelProto, dummy_input: Dict[str, np.ndarray] = None, quant_scheme: QuantScheme = QuantScheme.post_training_tf_enhanced, rounding_mode: str = 'nearest', @@ -339,7 +348,7 @@ def _insert_activation_quantization_nodes(self): ) @staticmethod - def build_session(model: onnx_pb.ModelProto, providers: List): + def build_session(model: ModelProto, providers: List): """ Build and return onnxruntime inference session diff --git a/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim_config/quantsim_config.py b/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim_config/quantsim_config.py index 84765fcf2d9..2bb944ec3d3 100644 --- a/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim_config/quantsim_config.py +++ b/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim_config/quantsim_config.py @@ -38,7 +38,8 @@ from abc import abstractmethod from typing import List, Dict, Tuple -from onnx import onnx_pb +import onnx + from aimet_common.defs import QuantizationDataType from aimet_common.graph_searcher import GraphSearcher from aimet_common.connected_graph.connectedgraph_utils import get_all_input_ops, get_all_output_ops @@ -51,6 +52,13 @@ from aimet_onnx.utils import get_product_name_from_quantized_name from aimet_onnx.qc_quantize_op import OpMode, QcQuantizeOp +from packaging import version +# pylint: disable=no-name-in-module, ungrouped-imports +if version.parse(onnx.__version__) >= version.parse("1.14.0"): + from onnx import ModelProto, NodeProto +else: + from onnx.onnx_pb import ModelProto, NodeProto + logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Quant) @@ -68,7 +76,7 @@ def __init__(self): class SupergroupConfigCallback(AimetCommonSupergroupConfigCallback): """ Class acting as a callback for when supergroups are found """ - def __init__(self, model: onnx_pb.ModelProto, op_to_quantizers: Dict): + def __init__(self, model: ModelProto, op_to_quantizers: Dict): super().__init__() self._model = model self._op_to_quantizers = op_to_quantizers @@ -86,7 +94,7 @@ class QuantSimConfigurator(AimetCommonQuantSimConfigurator): """ Class for parsing and applying quantsim configurations from json config file """ - def __init__(self, model: onnx_pb.ModelProto, conn_graph: ConnectedGraph, config_file: str, quantsim_output_bw: int, + def __init__(self, model: ModelProto, conn_graph: ConnectedGraph, config_file: str, quantsim_output_bw: int, quantsim_param_bw: int, quantsim_data_type: QuantizationDataType = QuantizationDataType.int): super().__init__(config_file, quantsim_data_type, quantsim_output_bw, quantsim_param_bw) @@ -521,10 +529,10 @@ def __init__(self, op_type_supported_kernels: dict, op_type_pcq: dict): assert ConfigDictKeys.DEFAULTS in self.op_type_pcq @abstractmethod - def generate(self, op: onnx_pb.NodeProto, op_type: str) -> dict: + def generate(self, op: NodeProto, op_type: str) -> dict: """ generate the config for the given op """ - def _generate_pcq(self, op: onnx_pb.NodeProto) -> bool: + def _generate_pcq(self, op: NodeProto) -> bool: """ Helper function to generate the pcq field :param op: op instance to generate the pcq value for @@ -549,7 +557,7 @@ class DefaultOpInstanceConfigGenerator(OpInstanceConfigGenerator): Default implementation of OpInstanceConfigGenerator """ - def generate(self, op: onnx_pb.NodeProto, op_type: str) -> Tuple[dict, bool]: + def generate(self, op: NodeProto, op_type: str) -> Tuple[dict, bool]: """ :param op: op to generate the specialized config :param op_type: Type str retrieved from CG op diff --git a/TrainingExtensions/onnx/src/python/aimet_onnx/utils.py b/TrainingExtensions/onnx/src/python/aimet_onnx/utils.py index 181ff697943..30a55968516 100644 --- a/TrainingExtensions/onnx/src/python/aimet_onnx/utils.py +++ b/TrainingExtensions/onnx/src/python/aimet_onnx/utils.py @@ -42,16 +42,23 @@ import pickle import numpy as np import onnx -from onnx import onnx_pb, helper, numpy_helper, mapping +from onnx import helper, numpy_helper, mapping from aimet_common.utils import AimetLogger +from packaging import version +# pylint: disable=no-name-in-module, ungrouped-imports +if version.parse(onnx.__version__) >= version.parse("1.14.0"): + from onnx import NodeProto, TensorProto, ModelProto, GraphProto, ValueInfoProto +else: + from onnx.onnx_pb import NodeProto, TensorProto, ModelProto, GraphProto, ValueInfoProto + logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Utils) OP_TYPES_WITH_PARAMS = ['Conv', 'Gemm', 'ConvTranspose', 'BatchNormalization', 'MatMul', 'Transpose'] -def remove_nodes_with_type(node_type: str, onnx_graph: onnx.onnx_pb.GraphProto): +def remove_nodes_with_type(node_type: str, onnx_graph: onnx.GraphProto): """ Remove specific type of nodes from graph @@ -73,7 +80,7 @@ def remove_nodes_with_type(node_type: str, onnx_graph: onnx.onnx_pb.GraphProto): node.output[0] = outputs.name -def remove_node(node: onnx_pb.ModelProto, onnx_graph: onnx.onnx_pb.GraphProto): +def remove_node(node: ModelProto, onnx_graph: onnx.GraphProto): """ Remove a specific node from graph along with associated initializers @@ -101,7 +108,7 @@ def remove_node(node: onnx_pb.ModelProto, onnx_graph: onnx.onnx_pb.GraphProto): onnx_graph.initializer.remove(item) -def transpose_tensor(t: onnx.onnx_ml_pb2.TensorProto, axes: Union[List, Tuple]) -> onnx.onnx_ml_pb2.TensorProto: +def transpose_tensor(t: TensorProto, axes: Union[List, Tuple]) -> TensorProto: """ Permutes the axes of a given array using numpy.transpose @@ -116,7 +123,7 @@ def transpose_tensor(t: onnx.onnx_ml_pb2.TensorProto, axes: Union[List, Tuple]) return numpy_helper.from_array(np.transpose(t_np, axes), name=t.name) -def replace_node_with_op(node_type: str, new_type: str, onnx_graph: onnx.onnx_pb.GraphProto): +def replace_node_with_op(node_type: str, new_type: str, onnx_graph: onnx.GraphProto): """ Replace the given op type of nodes to new op type @@ -130,7 +137,7 @@ def replace_node_with_op(node_type: str, new_type: str, onnx_graph: onnx.onnx_pb node.op_type = new_type -def get_node_attribute(node: onnx_pb.NodeProto, name: str): +def get_node_attribute(node: NodeProto, name: str): """ Return the value of a node's attribute specified by its name @@ -144,7 +151,7 @@ def get_node_attribute(node: onnx_pb.NodeProto, name: str): return None -def get_weights(name: str, onnx_graph: onnx.onnx_pb.GraphProto) -> bytes: +def get_weights(name: str, onnx_graph: onnx.GraphProto) -> bytes: """ Return the weights by given name :param name, name of the weights to find @@ -158,7 +165,7 @@ def get_weights(name: str, onnx_graph: onnx.onnx_pb.GraphProto) -> bytes: return None -def get_ordered_dict_of_nodes(onnx_graph: onnx.onnx_pb.GraphProto) -> Dict: +def get_ordered_dict_of_nodes(onnx_graph: onnx.GraphProto) -> Dict: """ Return the ordered list of nodes @@ -172,7 +179,7 @@ def get_ordered_dict_of_nodes(onnx_graph: onnx.onnx_pb.GraphProto) -> Dict: return ordered_dict -def make_dummy_input(model: onnx_pb.ModelProto, dynamic_size: int = 1) -> Dict[str, np.ndarray]: +def make_dummy_input(model: ModelProto, dynamic_size: int = 1) -> Dict[str, np.ndarray]: """ Create a dummy input based on the model input types and shapes @@ -196,7 +203,7 @@ def make_dummy_input(model: onnx_pb.ModelProto, dynamic_size: int = 1) -> Dict[s return input_dict -def replace_relu6_with_relu(model: onnx_pb.ModelProto): +def replace_relu6_with_relu(model: ModelProto): """ Replace relu6 op with relu op @@ -228,7 +235,7 @@ def replace_relu6_with_relu(model: onnx_pb.ModelProto): model.add_node(relu_node) -def check_if_clip_node_minimum_is_zero(node: onnx_pb.NodeProto, model: onnx_pb.ModelProto): +def check_if_clip_node_minimum_is_zero(node: NodeProto, model: ModelProto): """ Check if the clip node's minimum is 0 @@ -247,7 +254,7 @@ def check_if_clip_node_minimum_is_zero(node: onnx_pb.NodeProto, model: onnx_pb.M return False -def add_hook_to_get_activation(model: onnx_pb.ModelProto, name: str) -> onnx_pb.ValueInfoProto: +def add_hook_to_get_activation(model: ModelProto, name: str) -> ValueInfoProto: """ Adds a given activation to the model output :param model: The model to add the hook to @@ -260,8 +267,8 @@ def add_hook_to_get_activation(model: onnx_pb.ModelProto, name: str) -> onnx_pb. return val_info -def remove_activation_hooks(model: onnx_pb.ModelProto, - hooks: Union[List[onnx_pb.ValueInfoProto], onnx_pb.ValueInfoProto]): +def remove_activation_hooks(model: ModelProto, + hooks: Union[List[ValueInfoProto], ValueInfoProto]): """ Removes activation hooks from the model output :param model: The model from which to remove the hooks @@ -273,7 +280,7 @@ def remove_activation_hooks(model: onnx_pb.ModelProto, model.graph.output.remove(hook) -def get_graph_intermediate_activations(graph: onnx_pb.GraphProto) -> List[str]: +def get_graph_intermediate_activations(graph: GraphProto) -> List[str]: """ Returns the names of all activations within a graph that are used as the input to another node :param graph: The graph for which to retrieve the activations @@ -294,7 +301,7 @@ def get_graph_intermediate_activations(graph: onnx_pb.GraphProto) -> List[str]: class ParamUtils: """ Param utilities """ @staticmethod - def get_shape(model: onnx_pb.ModelProto, node: onnx_pb.NodeProto, param_index: int) -> List: + def get_shape(model: ModelProto, node: NodeProto, param_index: int) -> List: """ Returns a list of shape for the param specifies :param model: ONNX model @@ -313,7 +320,7 @@ def get_shape(model: onnx_pb.ModelProto, node: onnx_pb.NodeProto, param_index: i return None @staticmethod - def get_param(model: onnx_pb.ModelProto, node: onnx_pb.NodeProto, param_index: int) -> onnx_pb.TensorProto: + def get_param(model: ModelProto, node: NodeProto, param_index: int) -> TensorProto: """ Returns the param tensor :param model: ONNX model @@ -342,8 +349,8 @@ def get_product_name_from_quantized_name(quantized_name: str): return None -def retrieve_constant_input(node: onnx_pb.NodeProto, model: onnx_pb.ModelProto, index: int - ) -> Tuple[onnx_pb.TensorProto, bool]: +def retrieve_constant_input(node: NodeProto, model: ModelProto, index: int + ) -> Tuple[TensorProto, bool]: """ Retrieves node input at the specified index if the input has a corresponding initializer in model.graph.initializer and is separated from node by no more than one Transpose operation. @@ -413,7 +420,7 @@ def _cache_model_inputs(self, data_loader): logger.info('Caching %d batches from data loader at path location: %s', self._num_batches, self._path) -def create_input_dict(model: onnx_pb.ModelProto, input_batch: Union[np.ndarray, List[np.ndarray], Tuple[np.ndarray]]) -> Dict: +def create_input_dict(model: ModelProto, input_batch: Union[np.ndarray, List[np.ndarray], Tuple[np.ndarray]]) -> Dict: """ Creates input dictionary (input name to input value map) for session.run