Add custom-op support in ONNX AdaRound #2754

Merged · 2 commits · Feb 16, 2024
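This change threads an optional user_onnx_libs argument through the ONNX AdaRound stack so that models containing custom ONNX operators can be optimized: every onnxruntime session AdaRound builds internally now registers the given compiled custom-op libraries before running. A minimal usage sketch follows (the model, data loader, and library path are placeholders, not part of this PR):

import numpy as np
from aimet_onnx.adaround.adaround_weight import Adaround, AdaroundParameters

def forward_fn(session, _):
    # Placeholder calibration pass; feed representative data in practice.
    session.run(None, {'input': np.random.rand(1, 3, 64, 64).astype(np.float32)})

params = AdaroundParameters(data_loader=data_loader, num_batches=1,
                            default_num_iterations=5, forward_fn=forward_fn,
                            forward_pass_callback_args=None)

# New in this PR: compiled custom-op libraries are registered with every
# onnxruntime session created during optimization.
ada_model = Adaround.apply_adaround(model, params, './', 'prefix',
                                    user_onnx_libs=['/path/to/libonnx_custom_add.so'])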
@@ -64,14 +64,15 @@ class ActivationSampler:
"""
def __init__(self, orig_op: str, quant_op: str,
orig_model: ModelProto, quant_model: QuantizationSimModel, use_cuda: bool,
- device: int = 0):
+ device: int = 0, user_onnx_libs: List[str] = None):
"""
:param orig_op: Single un quantized op from the original session
:param quant_op: Corresponding quant op from the Quant sim session
:param orig_model: Session with the original model
:param quant_model: Session with the model with quantization simulations ops
:param use_cuda: If we should use cuda
:param device: CUDA device ID
+ :param user_onnx_libs: List of paths to all compiled ONNX custom ops libraries
:return: Input data to quant op, Output data from original op
"""
self._org_model = orig_model
@@ -84,8 +85,8 @@ def __init__(self, orig_op: str, quant_op: str,
else:
self.providers = ['CPUExecutionProvider']

- self._orig_module_collector = ModuleData(orig_model, orig_op, self.providers)
- self._quant_module_collector = ModuleData(quant_model, quant_op, self.providers)
+ self._orig_module_collector = ModuleData(orig_model, orig_op, self.providers, user_onnx_libs)
+ self._quant_module_collector = ModuleData(quant_model, quant_op, self.providers, user_onnx_libs)

def sample_and_place_all_acts_on_cpu(self, dataset) -> Tuple:
"""
@@ -139,15 +140,17 @@ class ModuleData:
Collect input and output data to and from module
"""

- def __init__(self, model: ModelProto, node_name: str, providers: List):
+ def __init__(self, model: ModelProto, node_name: str, providers: List, user_onnx_libs: List[str] = None):
"""
:param session: ONNX session
:param node: Module reference
:param providers: CPU/GPU execution providers
+ :param user_onnx_libs: List of paths to all compiled ONNX custom ops libraries
"""
self._model = model
self._module_name = node_name
self._providers = providers
+ self._user_onnx_libs = user_onnx_libs

def collect_inp_out_data(self, model_input: Dict[str, List[np.ndarray]],
collect_input: bool, collect_output: bool) -> Union[Tuple[None, List], Tuple[List, None]]:
@@ -161,7 +164,7 @@ def collect_inp_out_data(self, model_input: Dict[str, List[np.ndarray]],
"""

handle = add_hook_to_get_activation(self._model.model, self._module_name)
- sess = QuantizationSimModel.build_session(self._model.model, self._providers)
+ sess = QuantizationSimModel.build_session(self._model.model, self._providers, self._user_onnx_libs)
outputs = sess.run([self._module_name], model_input)
remove_activation_hooks(self._model.model, handle)
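With the new parameter, code that samples activations for a custom-op model can construct the sampler as below (a sketch; the op names, model objects, and library path are hypothetical, the signature matches the diff above):

# Hypothetical names for illustration only.
sampler = ActivationSampler(orig_op='custom_add_out', quant_op='custom_add_out_updated',
                            orig_model=orig_model, quant_model=quant_sim,
                            use_cuda=False, device=0,
                            user_onnx_libs=['/path/to/libonnx_custom_add.so'])
inp_data, out_data = sampler.sample_acts(input_dict)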

@@ -37,7 +37,7 @@

""" Adaround optimizer """

- from typing import Union, Tuple, Dict
+ from typing import Union, Tuple, Dict, List
import numpy as np
import onnx
from onnx import numpy_helper
@@ -80,7 +80,7 @@ def adaround_module(cls, module: ModuleInfo, quantized_input_name: str,
orig_model: ModelProto, quant_model: QuantizationSimModel,
act_func: Union[torch.nn.Module, None], cached_dataset: Dataset,
opt_params: AdaroundHyperParameters, param_to_adaround_tensor_quantizer: Dict,
- use_cuda: bool, device: int = 0):
+ use_cuda: bool, device: int = 0, user_onnx_libs: List[str] = None):
"""
Adaround module

@@ -95,12 +95,13 @@ def adaround_module(cls, module: ModuleInfo, quantized_input_name: str,
:param param_to_adaround_tensor_quantizer: Param name to adaround tensor quantizer dictionary
:param use_cuda: If we should use cuda
:param device: CUDA device ID
+ :param user_onnx_libs: List of paths to all compiled ONNX custom ops libraries
"""
# pylint: disable=too-many-arguments

# Optimize weight rounding
cls._optimize_rounding(module, quantized_input_name, orig_model, quant_model, act_func, cached_dataset,
- opt_params, param_to_adaround_tensor_quantizer, use_cuda, device)
+ opt_params, param_to_adaround_tensor_quantizer, use_cuda, device, user_onnx_libs)

# After optimization, set the optimized layer's rounding mode to "Hard rounding"
param_to_adaround_tensor_quantizer[module.params['weight'].name].use_soft_rounding = False
@@ -111,7 +112,7 @@ def _optimize_rounding(cls, module: ModuleInfo, quantized_input_name,
orig_model: ModelProto, quant_model: QuantizationSimModel,
act_func: Union[None, str], cached_dataset: Dataset,
opt_params: AdaroundHyperParameters, param_to_adaround_tensor_quantizer: Dict,
- use_cuda: bool, device: int = 0):
+ use_cuda: bool, device: int = 0, user_onnx_libs: List[str] = None):
"""
Optimizes the weight rounding of quantized wrapper module
:param module: Original module
@@ -122,6 +123,7 @@ def _optimize_rounding(cls, module: ModuleInfo, quantized_input_name,
:param cached_dataset: Cached dataset
:param opt_params: Optimization parameters
:param param_to_adaround_tensor_quantizer: Param name to adaround tensor quantizer dictionary
+ :param user_onnx_libs: List of paths to all compiled ONNX custom ops libraries
"""
# pylint: disable=too-many-locals, too-many-arguments
adaround_quantizer = param_to_adaround_tensor_quantizer[module.params['weight'].name]
@@ -144,7 +146,7 @@ def _optimize_rounding(cls, module: ModuleInfo, quantized_input_name,
# Check if we can cache intermediate activation data.
model_inputs = cached_dataset[0]
act_sampler = ActivationSampler(module.outputs[0], quantized_input_name, orig_model, quant_model,
- use_cuda, device)
+ use_cuda, device, user_onnx_libs)
inp_data, out_data = act_sampler.sample_acts(create_input_dict(orig_model.model, model_inputs))
inp_data_torch, out_data_torch = torch.from_numpy(inp_data[0]), torch.from_numpy(out_data[0])
use_cache_acts_data = TorchAdaroundOptimizer._can_cache_acts_data(len(cached_dataset), inp_data_torch.shape,
@@ -159,11 +161,11 @@ def _optimize_rounding(cls, module: ModuleInfo, quantized_input_name,
if use_cache_acts_data and AdaroundOptimizer.enable_caching_acts_data():
logger.debug("Caching intermediate activations data for optimization.")
all_inp_data, all_orig_out_data = act_sampler.sample_and_place_all_acts_on_cpu(cached_dataset)
- all_inp_data, all_out_data = torch.from_numpy(all_inp_data[0]), \
+ all_inp_data, all_orig_out_data = torch.from_numpy(all_inp_data[0]), \
torch.from_numpy(all_orig_out_data[0])
# Try to put all cached activations data on GPU for faster optimization if possible.
if use_cuda:
- all_inp_data, all_orig_out_data = TorchAdaroundOptimizer._place_cached_acts_data(all_inp_data, all_out_data,
+ all_inp_data, all_orig_out_data = TorchAdaroundOptimizer._place_cached_acts_data(all_inp_data, all_orig_out_data,
torch_device)

for iteration in range(opt_params.num_iterations):
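Aside from forwarding user_onnx_libs into ActivationSampler, this hunk fixes a naming inconsistency in the activation-caching path: the torch-converted outputs were previously stored in all_out_data while the GPU-placement call and the optimization loop read all_orig_out_data; both assignments now use all_orig_out_data consistently.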
@@ -116,7 +116,8 @@ def apply_adaround(cls, model: onnx_pb.ModelProto, params: AdaroundParameters,
param_bw_override_list: List[Tuple[str, int]] = None,
ignore_quant_ops_list: List[str] = None,
default_quant_scheme: QuantScheme = QuantScheme.post_training_tf_enhanced,
- default_config_file: str = None, use_cuda: bool = True, device: int = 0) -> onnx_pb.ModelProto:
+ default_config_file: str = None, use_cuda: bool = True, device: int = 0,
+ user_onnx_libs: List[str] = None) -> onnx_pb.ModelProto:
"""
Returns model with optimized weight rounding of every module (Conv and Linear) and also saves the
corresponding quantization encodings to a separate JSON-formatted file that can then be imported by
@@ -136,6 +137,7 @@ def apply_adaround(cls, model: onnx_pb.ModelProto, params: AdaroundParameters,
:param default_config_file: Default configuration file for model quantizers
:param use_cuda: If we should use cuda
:param device: CUDA device ID
+ :param user_onnx_libs: List of paths to all compiled ONNX custom ops libraries
:return: Model with Adarounded weights and saves corresponding parameter encodings JSON file at provided path
"""
# pylint: disable=too-many-arguments
@@ -144,7 +146,8 @@ def apply_adaround(cls, model: onnx_pb.ModelProto, params: AdaroundParameters,
model = ONNXModel(model)
quant_sim = QuantizationSimModel(copy.deepcopy(model), quant_scheme=default_quant_scheme,
default_param_bw=default_param_bw,
- config_file=default_config_file)
+ config_file=default_config_file,
+ user_onnx_libs=user_onnx_libs)

# For the params in the param_bw_override_list, override the default parameter bitwidths in the QuantSim
if param_bw_override_list:
@@ -156,11 +159,12 @@ def apply_adaround(cls, model: onnx_pb.ModelProto, params: AdaroundParameters,
# Compute only param encodings
cls._compute_param_encodings(quant_sim, params)

- return cls._apply_adaround(quant_sim, model, params, path, filename_prefix, use_cuda, device)
+ return cls._apply_adaround(quant_sim, model, params, path, filename_prefix, use_cuda, device, user_onnx_libs)

@classmethod
def _apply_adaround(cls, quant_sim: QuantizationSimModel, model: onnx_pb.ModelProto, params: AdaroundParameters,
- path: str, filename_prefix: str, use_cuda: bool = True, device: int = 0) -> onnx_pb.ModelProto:
+ path: str, filename_prefix: str, use_cuda: bool = True, device: int = 0,
+ user_onnx_libs: List[str] = None) -> onnx_pb.ModelProto:
"""
Returns model with optimized weight rounding of every module (Conv and Linear) and also saves the
corresponding quantization encodings to a separate JSON-formatted file that can then be imported by
@@ -174,6 +178,7 @@ def _apply_adaround(cls, quant_sim: QuantizationSimModel, model: onnx_pb.ModelProto,
:param filename_prefix: Prefix to use for filename of the encodings file
:param use_cuda: If we should use cuda
:param device: CUDA device ID
+ :param user_onnx_libs: List of paths to all compiled ONNX custom ops libraries
:return: Model with Adarounded weights and saves corresponding parameter encodings JSON file at provided path
"""

@@ -184,7 +189,7 @@ def _apply_adaround(cls, quant_sim: QuantizationSimModel, model: onnx_pb.ModelProto,
# Get the module - activation function pair using ConnectedGraph
module_act_func_pair = get_module_act_func_pair(model)

- cls._adaround_model(model, quant_sim, module_act_func_pair, params, use_cuda, device)
+ cls._adaround_model(model, quant_sim, module_act_func_pair, params, use_cuda, device, user_onnx_libs)

# Export quantization encodings to JSON-formatted file
cls._export_encodings_to_json(path, filename_prefix, quant_sim)
@@ -195,7 +200,7 @@ def _apply_adaround(cls, quant_sim: QuantizationSimModel, model: onnx_pb.ModelProto,

@classmethod
def _adaround_model(cls, model: onnx_pb.ModelProto, quant_sim: QuantizationSimModel, module_act_func_pair: Dict,
- params: AdaroundParameters, use_cuda: bool = True, device: int = 0):
+ params: AdaroundParameters, use_cuda: bool = True, device: int = 0, user_onnx_libs: List[str] = None):
"""
Optimize weight rounding of every module (AdaroundSupportedModules) of model in sequential manner
based on occurrence
@@ -207,6 +212,7 @@ def _adaround_model(cls, model: onnx_pb.ModelProto, quant_sim: QuantizationSimModel,
:param params: Adaround parameters
:param use_cuda: If we should use cuda
:param device: CUDA device ID
+ :param user_onnx_libs: List of paths to all compiled ONNX custom ops libraries
"""
# pylint: disable=too-many-locals, protected-access

@@ -246,7 +252,7 @@ def _adaround_model(cls, model: onnx_pb.ModelProto, quant_sim: QuantizationSimModel,
AdaroundOptimizer.adaround_module(model_data.module_to_info[name], quantized_input_name,
model, quant_sim.model, act_func,
cached_dataset, opt_params, param_to_tensor_quantizer_dict,
- use_cuda, device)
+ use_cuda, device, user_onnx_libs)

finally:
if os.path.exists(WORKING_DIR):
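Taken together, these hunks thread user_onnx_libs from the public apply_adaround entry point through _apply_adaround and _adaround_model down to AdaroundOptimizer.adaround_module, and also pass it to the QuantizationSimModel constructor, so every onnxruntime session created during optimization can resolve the registered custom ops.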
TrainingExtensions/onnx/test/python/test_adaround_weight.py (67 additions, 28 deletions)
@@ -36,13 +36,16 @@
# =============================================================================

""" Unit tests for Adaround Weights """
+ import os
import json
from packaging import version
import numpy as np
import torch
from onnxruntime import SessionOptions, GraphOptimizationLevel, InferenceSession
import pytest

+ from aimet_common import libquant_info

from aimet_onnx.adaround.adaround_weight import Adaround, AdaroundParameters
import models.models_for_tests as test_models

@@ -53,64 +56,100 @@ class TestAdaround:

@pytest.mark.skipif(not torch.cuda.is_available(), reason="This unit-test is meant to be run on GPU")
def test_apply_adaround(self):
+ np.random.seed(0)
+ torch.manual_seed(0)
+ model = test_models.single_residual_model()
+ data_loader = dataloader(input_shape=(1, 3, 32, 32))
+ dummy_input = {'input': np.random.rand(1, 3, 32, 32).astype(np.float32)}
+ sess = build_session(model, None)
+ out_before_ada = sess.run(None, dummy_input)
+ def callback(session, args):
+ in_tensor = {'input': np.random.rand(1, 3, 32, 32).astype(np.float32)}
+ session.run(None, in_tensor)
+
+ params = AdaroundParameters(data_loader=data_loader, num_batches=1, default_num_iterations=5, forward_fn=callback,
+ forward_pass_callback_args=None)
+ ada_rounded_model = Adaround.apply_adaround(model, params, './', 'dummy')
+ sess = build_session(ada_rounded_model, None)
+ out_after_ada = sess.run(None, dummy_input)
+ assert not np.array_equal(out_before_ada[0], out_after_ada[0])
+
+ with open('./dummy.encodings') as json_file:
+ encoding_data = json.load(json_file)
+
+ param_keys = list(encoding_data.keys())
if version.parse(torch.__version__) >= version.parse("1.13"):
- np.random.seed(0)
- torch.manual_seed(0)
- model = test_models.single_residual_model()
- data_loader = dataloader()
- dummy_input = {'input': np.random.rand(1, 3, 32, 32).astype(np.float32)}
- sess = build_session(model)
- out_before_ada = sess.run(None, dummy_input)
- def callback(session, args):
- in_tensor = {'input': np.random.rand(1, 3, 32, 32).astype(np.float32)}
- session.run(None, in_tensor)
-
- params = AdaroundParameters(data_loader=data_loader, num_batches=1, default_num_iterations=5, forward_fn=callback,
- forward_pass_callback_args=None)
- ada_rounded_model = Adaround.apply_adaround(model, params, './', 'dummy')
- sess = build_session(ada_rounded_model)
- out_after_ada = sess.run(None, dummy_input)
- assert not np.array_equal(out_before_ada[0], out_after_ada[0])
-
- with open('./dummy.encodings') as json_file:
- encoding_data = json.load(json_file)
-
- param_keys = list(encoding_data.keys())
assert 'onnx::Conv_43' in param_keys

+ @pytest.mark.skipif(not torch.cuda.is_available(), reason="This unit-test is meant to be run on GPU")
+ def test_apply_adaround_for_custom_op(self):
+ custom_ops_path = os.path.dirname(libquant_info.__file__)
+ custom_ops_path = os.path.join(custom_ops_path, "customops")
+ onnx_library = os.path.join(custom_ops_path, "libonnx_custom_add.so")
+
+ np.random.seed(0)
+ torch.manual_seed(0)
+ model = test_models.custom_add_model()
+ data_loader = dataloader(input_shape=(1, 3, 64, 64))
+ dummy_input = {'input': np.random.rand(1, 3, 64, 64).astype(np.float32)}
+ sess = build_session(model, [onnx_library])
+ out_before_ada = sess.run(None, dummy_input)
+ def callback(session, args):
+ in_tensor = {'input': np.random.rand(1, 3, 64, 64).astype(np.float32)}
+ session.run(None, in_tensor)

- def dataloader():
+ params = AdaroundParameters(data_loader=data_loader, num_batches=1, default_num_iterations=5, forward_fn=callback,
+ forward_pass_callback_args=None)
+ ada_rounded_model = Adaround.apply_adaround(model, params, './', 'dummy', user_onnx_libs=[onnx_library])
+ sess = build_session(ada_rounded_model, [onnx_library])
+ out_after_ada = sess.run(None, dummy_input)
+ assert not np.array_equal(out_before_ada[0], out_after_ada[0])
+
+ with open('./dummy.encodings') as json_file:
+ encoding_data = json.load(json_file)
+
+ param_keys = list(encoding_data.keys())
+ if version.parse(torch.__version__) >= version.parse("1.13"):
+ assert 'conv.weight' in param_keys


+ def dataloader(input_shape: tuple):
class DataLoader:
"""
Example of a Dataloader which can be used for running AMPv2
"""
- def __init__(self, batch_size: int):
+ def __init__(self, batch_size: int, input_shape: tuple):
"""
:param batch_size: batch size for data loader
"""
self.batch_size = batch_size
+ self.input_shape = input_shape

def __iter__(self):
"""Iterates over dataset"""
- dummy_input = np.random.rand(1, 3, 32, 32).astype(np.float32)
+ dummy_input = np.random.rand(*self.input_shape).astype(np.float32)
yield dummy_input

def __len__(self):
return 4

- dummy_dataloader = DataLoader(batch_size=2)
+ dummy_dataloader = DataLoader(batch_size=2, input_shape=input_shape)
return dummy_dataloader

- def build_session(model):
+ def build_session(model, user_onnx_libs):
"""
Build and return onnxruntime inference session
:param providers: providers to execute onnxruntime
"""
sess_options = SessionOptions()
sess_options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
+ if user_onnx_libs is not None:
+ for lib in user_onnx_libs:
+ sess_options.register_custom_ops_library(lib)
session = InferenceSession(
path_or_bytes=model.model.SerializeToString(),
sess_options=sess_options,
providers=['CPUExecutionProvider'],
)
return session
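Note the ordering in build_session: each custom-op library has to be registered on the SessionOptions before the InferenceSession is constructed, since onnxruntime resolves custom operator kernels at session-creation time.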