Commit a346c04

Update ONNX imports to support version upgrade
Signed-off-by: Ashvin Kumar <quic_ashvkuma@quicinc.com>
quic-ashvkuma authored Oct 31, 2023
1 parent 729d89d commit a346c04
Showing 11 changed files with 157 additions and 76 deletions.
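Every touched module gets the same guard: import the protobuf classes (ModelProto, NodeProto, TensorProto) from the onnx package top level on onnx >= 1.14, and from onnx.onnx_pb on older releases. Extracted from the diff below, the pattern in isolation is:

import onnx
from packaging import version

# pylint: disable=no-name-in-module, ungrouped-imports
# Prefer the top-level re-export on onnx >= 1.14; fall back to onnx.onnx_pb.
if version.parse(onnx.__version__) >= version.parse("1.14.0"):
    from onnx import ModelProto, NodeProto, TensorProto
else:
    from onnx.onnx_pb import ModelProto, NodeProto, TensorProto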
@@ -40,13 +40,20 @@
 from typing import Tuple, List, Dict, Union
 
 import numpy as np
-from onnx import onnx_pb
 import onnxruntime as ort
+import onnx
 
 from aimet_common.utils import AimetLogger
 from aimet_onnx.quantsim import QuantizationSimModel
 from aimet_onnx.utils import add_hook_to_get_activation, remove_activation_hooks, create_input_dict
 
+from packaging import version
+# pylint: disable=no-name-in-module, ungrouped-imports
+if version.parse(onnx.__version__) >= version.parse("1.14.0"):

Inline comment from FelixSchwarz (Contributor), Nov 5, 2023:

    Just a quick note from some random dude on the internet: to me, a more
    pythonic way would be to use try: ... except ImportError: ... instead of
    checking versions.
+    from onnx import ModelProto
+else:
+    from onnx.onnx_pb import ModelProto
+
 logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Quant)
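The try/except alternative suggested in the review comment above would look roughly like this sketch (an illustration, not what was merged; it assumes the top-level import is the one available on onnx >= 1.14, mirroring the version check this commit uses):

try:
    # onnx >= 1.14: proto classes are importable from the package top level
    from onnx import ModelProto
except ImportError:
    # older onnx releases: fall back to the onnx.onnx_pb module
    from onnx.onnx_pb import ModelProto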


@@ -56,7 +63,7 @@ class ActivationSampler:
     collect the module's output and input activation data respectively
     """
     def __init__(self, orig_op: str, quant_op: str,
-                 orig_model: onnx_pb.ModelProto, quant_model: QuantizationSimModel, use_cuda: bool,
+                 orig_model: ModelProto, quant_model: QuantizationSimModel, use_cuda: bool,
                  device: int = 0):
         """
         :param orig_op: Single unquantized op from the original session
@@ -132,7 +139,7 @@ class ModuleData:
     Collect input and output data to and from module
     """
 
-    def __init__(self, model: onnx_pb.ModelProto, node_name: str, providers: List):
+    def __init__(self, model: ModelProto, node_name: str, providers: List):
         """
         :param session: ONNX session
         :param node: Module reference
@@ -40,7 +40,7 @@
 from typing import Union, Tuple, Dict
 import numpy as np
 import onnx
-from onnx import onnx_pb, numpy_helper
+from onnx import numpy_helper
 import torch
 import torch.nn.functional as functional
 from torch.utils.data import Dataset
@@ -56,6 +56,13 @@
 from aimet_torch.adaround.adaround_tensor_quantizer import AdaroundTensorQuantizer
 from aimet_torch.adaround.adaround_optimizer import AdaroundOptimizer as TorchAdaroundOptimizer
 
+from packaging import version
+# pylint: disable=no-name-in-module, ungrouped-imports
+if version.parse(onnx.__version__) >= version.parse("1.14.0"):
+    from onnx import ModelProto
+else:
+    from onnx.onnx_pb import ModelProto
+
 logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Quant)
 BATCH_SIZE = 32
 EMPIRICAL_THRESHOLD = 3 / 4
@@ -70,7 +77,7 @@ class AdaroundOptimizer:
     """
     @classmethod
     def adaround_module(cls, module: ModuleInfo, quantized_input_name: str,
-                        orig_model: onnx_pb.ModelProto, quant_model: QuantizationSimModel,
+                        orig_model: ModelProto, quant_model: QuantizationSimModel,
                         act_func: Union[torch.nn.Module, None], cached_dataset: Dataset,
                         opt_params: AdaroundHyperParameters, param_to_adaround_tensor_quantizer: Dict,
                         use_cuda: bool, device: int = 0):
@@ -100,7 +107,7 @@ def adaround_module(cls, module: ModuleInfo, quantized_input_name: str,
 
     @classmethod
     def _optimize_rounding(cls, module: ModuleInfo, quantized_input_name,
-                           orig_model: onnx_pb.ModelProto, quant_model: QuantizationSimModel,
+                           orig_model: ModelProto, quant_model: QuantizationSimModel,
                            act_func: Union[None, str], cached_dataset: Dataset,
                            opt_params: AdaroundHyperParameters, param_to_adaround_tensor_quantizer: Dict,
                            use_cuda: bool, device: int = 0):
11 changes: 9 additions & 2 deletions TrainingExtensions/onnx/src/python/aimet_onnx/adaround/utils.py
@@ -37,10 +37,17 @@
 """ Utilities for Adaround ONNX """
 from typing import Dict
 from collections import defaultdict
-from onnx import onnx_pb
+import onnx
 
 from aimet_onnx.meta.connectedgraph import ConnectedGraph
 
+from packaging import version
+# pylint: disable=no-name-in-module, ungrouped-imports
+if version.parse(onnx.__version__) >= version.parse("1.14.0"):
+    from onnx import ModelProto
+else:
+    from onnx.onnx_pb import ModelProto
+
 class ModuleInfo:
     """ Class object containing information about a module """
     def __init__(self):
@@ -55,7 +62,7 @@ class ModelData:
     """
     Class to collect data for each module of a class
     """
-    def __init__(self, model: onnx_pb.ModelProto):
+    def __init__(self, model: ModelProto):
         """
         :param model: ONNX Model
         """
36 changes: 22 additions & 14 deletions TrainingExtensions/onnx/src/python/aimet_onnx/batch_norm_fold.py
@@ -38,8 +38,9 @@
 
 from typing import Dict, List, Tuple
 import contextlib
-from onnx import onnx_pb, numpy_helper
 import numpy as np
+import onnx
+from onnx import numpy_helper
 
 from aimet_common.bias_correction import ConvBnPatternHandler
 from aimet_common.graph_pattern_matcher import PatternType
@@ -53,6 +54,13 @@
 from aimet_onnx.meta.operations import Op
 from aimet_onnx.utils import get_node_attribute, remove_node, transpose_tensor, ParamUtils, retrieve_constant_input
 
+from packaging import version
+# pylint: disable=no-name-in-module, ungrouped-imports
+if version.parse(onnx.__version__) >= version.parse("1.14.0"):
+    from onnx import NodeProto, TensorProto, ModelProto
+else:
+    from onnx.onnx_pb import NodeProto, TensorProto, ModelProto
+
 logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.BatchNormFolding)
 
 ConvType = ['Conv', 'ConvTranspose']
@@ -103,8 +111,8 @@ def _find_conv_bn_pairs(connected_graph: ConnectedGraph) -> Dict:
 
 
 def find_all_batch_norms_to_fold(connected_graph: ConnectedGraph,
-                                 ) -> Tuple[List[Tuple[onnx_pb.NodeProto, onnx_pb.NodeProto]],
-                                            List[Tuple[onnx_pb.NodeProto, onnx_pb.NodeProto]]]:
+                                 ) -> Tuple[List[Tuple[NodeProto, NodeProto]],
+                                            List[Tuple[NodeProto, NodeProto]]]:
     """
     Find all possible batch norm layers that can be folded. Returns a list of pairs such that (bn, layer)
     means bn will be forward-folded into layer and (layer, bn) means bn will be backward-folded into layer
@@ -164,7 +172,7 @@ def get_ordered_conv_linears(conn_graph: ConnectedGraph) -> List[Op]:
     return ordered_convs
 
 
-def is_valid_bn_fold(conv_linear: onnx_pb.NodeProto, model: onnx_pb.ModelProto, fold_backward: bool) -> bool:
+def is_valid_bn_fold(conv_linear: NodeProto, model: ModelProto, fold_backward: bool) -> bool:
     """
     Determine if a given layer can successfully absorb a BatchNorm given the layer type and parameters
     :param conv_linear: The Conv/Linear layer to fold a BatchNorm into.
@@ -193,7 +201,7 @@ def is_valid_bn_fold(conv_linear: onnx_pb.NodeProto, model: onnx_pb.ModelProto,
     return valid
 
 
-def fold_all_batch_norms_to_weight(model: onnx_pb.ModelProto) -> [List]:
+def fold_all_batch_norms_to_weight(model: ModelProto) -> [List]:
     """
     Fold all possible batch_norm layers in a model into the weight of the corresponding conv layers
@@ -218,9 +226,9 @@ def fold_all_batch_norms_to_weight(model: onnx_pb.ModelProto) -> [List]:
     return conv_bns, bn_convs
 
 
-def _fold_to_weight(model: onnx_pb.ModelProto,
-                    conv_linear: onnx_pb.NodeProto,
-                    bn: onnx_pb.NodeProto,
+def _fold_to_weight(model: ModelProto,
+                    conv_linear: NodeProto,
+                    bn: NodeProto,
                     fold_backward: bool):
     """
     Fold BatchNorm into the weight and bias of the given layer.
@@ -273,7 +281,7 @@ def _fold_to_weight(model: onnx_pb.ModelProto,
     return bn_layer
 
 
-def _matmul_to_gemm(node: onnx_pb.NodeProto, model: onnx_pb.ModelProto):
+def _matmul_to_gemm(node: NodeProto, model: ModelProto):
     """
     Convert MatMul node to Gemm and initialize bias to zeros
@@ -298,8 +306,8 @@ def _matmul_to_gemm(node: onnx_pb.NodeProto, model: onnx_pb.ModelProto):
     node.input.append(bias_name)
 
 
-def _call_mo_batch_norm_fold(weight: onnx_pb.TensorProto,
-                             bias: onnx_pb.TensorProto,
+def _call_mo_batch_norm_fold(weight: TensorProto,
+                             bias: TensorProto,
                              bn_params: libpymo.BNParams,
                              fold_backward: bool):
     """
@@ -328,7 +336,7 @@ def _call_mo_batch_norm_fold(weight: onnx_pb.TensorProto,
     weight.raw_data = np.asarray(weight_tensor.data, dtype=np.float32).tobytes()
 
 
-def get_bn_params(model: onnx_pb.ModelProto, bn: onnx_pb.NodeProto, channels: int) -> libpymo.BNParams:
+def get_bn_params(model: ModelProto, bn: NodeProto, channels: int) -> libpymo.BNParams:
     """
     Returns the populated libpymo.BNParams object for the given BatchNormalization layer with
     parameters repeated if necessary.
@@ -355,7 +363,7 @@ def get_bn_params(model: onnx_pb.ModelProto, bn: onnx_pb.NodeProto, channels: in
     return bn_params
 
 
-def copy_bn_params_to_bn_layer(bn: onnx_pb.NodeProto, bn_params: libpymo.BNParams) -> BNLayer:
+def copy_bn_params_to_bn_layer(bn: NodeProto, bn_params: libpymo.BNParams) -> BNLayer:
     """
     Copies bn params to a BN layer which can be used later by High bias absorption
     :param bn: BN layer
@@ -390,7 +398,7 @@ def _expand_shape_to_4d(weight_tensor: libpymo.TensorParams):
     weight_tensor.shape = orig_shape
 
 
-def get_input_output_channels(node: onnx_pb.NodeProto, model: onnx_pb.ModelProto) -> Tuple[int, int]:
+def get_input_output_channels(node: NodeProto, model: ModelProto) -> Tuple[int, int]:
     """
     Find the input and output channels of a given layer.
     :param node: The node to find the input/output channels of
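For context, typical usage of the public fold API above might look like this sketch (the module path comes from the file header; the model path and call site are illustrative, not from the diff):

import onnx
from aimet_onnx.batch_norm_fold import fold_all_batch_norms_to_weight

model = onnx.load("model.onnx")  # illustrative path
# Folds each foldable BatchNormalization into its neighboring Conv/Gemm in place;
# returns the (conv, bn) and (bn, conv) pairs that were folded.
conv_bns, bn_convs = fold_all_batch_norms_to_weight(model)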
@@ -44,7 +44,8 @@
 
 from typing import Tuple, List, Union
 import numpy as np
-from onnx import onnx_pb, numpy_helper
+import onnx
+from onnx import numpy_helper
 from onnxruntime.quantization.onnx_quantizer import ONNXModel
 
 from aimet_common.utils import AimetLogger
@@ -58,6 +59,13 @@
 from aimet_onnx.utils import transpose_tensor, ParamUtils, get_node_attribute, replace_relu6_with_relu
 from aimet_onnx.batch_norm_fold import BNLayer, fold_all_batch_norms_to_weight
 
+from packaging import version
+# pylint: disable=no-name-in-module, ungrouped-imports
+if version.parse(onnx.__version__) >= version.parse("1.14.0"):
+    from onnx import NodeProto, ModelProto
+else:
+    from onnx.onnx_pb import NodeProto, ModelProto
+
 logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Quant)
 
 ClsSet = Union[Tuple['Conv', 'Conv'],
@@ -82,7 +90,7 @@ class CrossLayerScaling(CLS):
     """
     Scales a model's layers to equalize the weights between consecutive layers
     """
-    def __init__(self, model: onnx_pb.ModelProto):
+    def __init__(self, model: ModelProto):
         """
         :param model: ONNX model
         """
@@ -121,7 +129,7 @@ def scale_model(self) -> List[ClsSetInfo]:
 
         return cls_set_info_list
 
-    def _populate_libpymo_params(self, module: onnx_pb.NodeProto,
+    def _populate_libpymo_params(self, module: NodeProto,
                                  layer_param: libpymo.EqualizationParams):
         """
         Populates libpymo weight parameter
@@ -158,7 +166,7 @@ def _pack_params_for_conv(self,
            prev_layer_params.isBiasNone = True
 
     def _update_weight_for_layer_from_libpymo_obj(self, layer_param: libpymo.EqualizationParams,
-                                                  module: onnx_pb.NodeProto):
+                                                  module: NodeProto):
         """
         Update weight parameter from libpymo object
         """
@@ -257,7 +265,7 @@ class HighBiasFold(HBF):
     """
     Code to apply the high-bias-fold technique to a model
     """
-    def __init__(self, model: onnx_pb.ModelProto):
+    def __init__(self, model: ModelProto):
         self._model = model
 
     def _check_if_bias_is_none(self, layer: Op) -> bool:
@@ -307,7 +315,7 @@ def _pack_previous_and_current_layer_params(self, cls_pair_info: ClsSetInfo.ClsS
        curr_layer_params.weightShape = get_weight_dimensions(np.array(weight.dims))
 
     def _update_bias_for_layer_from_libpymo_obj(self, layer_param: libpymo.LayerParams,
-                                                module: onnx_pb.NodeProto):
+                                                module: NodeProto):
         """
         Update bias parameter from libpymo object
         """
@@ -340,7 +348,7 @@ def get_weight_dimensions(weight_shape: np.array) -> np.array:
     return np.append(weight_shape, [1 for _ in range(4 - dims)]).astype(int)
 
 
-def equalize_model(model: onnx_pb.ModelProto):
+def equalize_model(model: ModelProto):
     """
     High-level API to perform Cross-Layer Equalization (CLE) on the given model. The model is equalized in place.
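Similarly, a hypothetical end-to-end call for equalize_model (the module path aimet_onnx.cross_layer_equalization is an assumption, since this file's name is not shown above; the in-place behavior comes from the docstring):

import onnx
from aimet_onnx.cross_layer_equalization import equalize_model  # assumed module path

model = onnx.load("model.onnx")  # illustrative path
equalize_model(model)            # equalizes the model in place
onnx.save(model, "model_equalized.onnx")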
@@ -41,21 +41,28 @@
 from typing import List, Dict, Tuple, Union
 import numpy as np
 import onnxruntime as ort
-from onnx import onnx_pb
+import onnx
 
 from aimet_common.utils import AimetLogger
 from aimet_common.layer_output_utils import SaveInputOutput, save_layer_output_names
 
 from aimet_onnx.quantsim import QuantizationSimModel
 from aimet_onnx.utils import create_input_dict, add_hook_to_get_activation
 
+from packaging import version
+# pylint: disable=no-name-in-module, ungrouped-imports
+if version.parse(onnx.__version__) >= version.parse("1.14.0"):
+    from onnx import ModelProto
+else:
+    from onnx.onnx_pb import ModelProto
+
 logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.LayerOutputs)
 
 
 class LayerOutputUtil:
     """ Implementation to capture and save outputs of intermediate layers of a model (fp32/quantsim) """
 
-    def __init__(self, model: onnx_pb.ModelProto, dir_path: str, device: int = 0):
+    def __init__(self, model: ModelProto, dir_path: str, device: int = 0):
         """
         Constructor - It initializes the utility classes that capture and save layer-outputs
@@ -97,7 +104,7 @@ class LayerOutput:
     """
     This class creates a layer-output name to layer-output dictionary.
    """
-    def __init__(self, model: onnx_pb.ModelProto, providers: List, dir_path: str):
+    def __init__(self, model: ModelProto, providers: List, dir_path: str):
         """
         Constructor - It initializes a few lists that are required for capturing and naming layer-outputs.
@@ -131,7 +138,7 @@ def get_outputs(self, input_dict: Dict) -> Dict[str, np.ndarray]:
         return dict(zip(self.sanitized_activation_names, activation_values))
 
     @staticmethod
-    def get_activation_names(model: onnx_pb.ModelProto) -> List[str]:
+    def get_activation_names(model: ModelProto) -> List[str]:
         """
         This function fetches the activation names (layer-output names) of the given onnx model.
@@ -153,7 +160,7 @@ def get_activation_names(model: onnx_pb.ModelProto) -> List[str]:
         return activation_names
 
     @staticmethod
-    def register_activations(model: onnx_pb.ModelProto, activation_names: List):
+    def register_activations(model: ModelProto, activation_names: List):
         """
         This function adds the intermediate activations into the model's ValueInfoProto so that they can be fetched via
         running the session.
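And a minimal sketch for the layer-output utility above (module path assumed; only the constructor signature appears in this diff, the capture API does not):

import onnx
from aimet_onnx.layer_output_utils import LayerOutputUtil  # assumed module path

model = onnx.load("model.onnx")  # illustrative path
layer_output_util = LayerOutputUtil(model=model, dir_path="./layer_outputs", device=0)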