From 4d6d8ef9d4c211454a64d720fcd3c1abee75b138 Mon Sep 17 00:00:00 2001 From: Raj Gite Date: Thu, 23 Nov 2023 13:33:24 +0530 Subject: [PATCH 1/2] Add custom-op support in ONNX AdaRound Signed-off-by: Raj Gite --- .../aimet_onnx/adaround/activation_sampler.py | 13 +++++++----- .../aimet_onnx/adaround/adaround_optimizer.py | 16 ++++++++------- .../aimet_onnx/adaround/adaround_weight.py | 20 ++++++++++++------- .../onnx/test/python/test_quantsim.py | 16 +++++++++------ 4 files changed, 40 insertions(+), 25 deletions(-) diff --git a/TrainingExtensions/onnx/src/python/aimet_onnx/adaround/activation_sampler.py b/TrainingExtensions/onnx/src/python/aimet_onnx/adaround/activation_sampler.py index 86df57e60af..e3c0cc94431 100644 --- a/TrainingExtensions/onnx/src/python/aimet_onnx/adaround/activation_sampler.py +++ b/TrainingExtensions/onnx/src/python/aimet_onnx/adaround/activation_sampler.py @@ -64,7 +64,7 @@ class ActivationSampler: """ def __init__(self, orig_op: str, quant_op: str, orig_model: ModelProto, quant_model: QuantizationSimModel, use_cuda: bool, - device: int = 0): + device: int = 0, user_onnx_libs: List[str] = None): """ :param orig_op: Single un quantized op from the original session :param quant_op: Corresponding quant op from the Quant sim session @@ -72,6 +72,7 @@ def __init__(self, orig_op: str, quant_op: str, :param quant_model: Session with the model with quantization simulations ops :param use_cuda: If we should use cuda :param device: CUDA device ID + :param user_onnx_libs: List of paths to all compiled ONNX custom ops libraries :return: Input data to quant op, Output data from original op """ self._org_model = orig_model @@ -84,8 +85,8 @@ def __init__(self, orig_op: str, quant_op: str, else: self.providers = ['CPUExecutionProvider'] - self._orig_module_collector = ModuleData(orig_model, orig_op, self.providers) - self._quant_module_collector = ModuleData(quant_model, quant_op, self.providers) + self._orig_module_collector = ModuleData(orig_model, orig_op, self.providers, user_onnx_libs) + self._quant_module_collector = ModuleData(quant_model, quant_op, self.providers, user_onnx_libs) def sample_and_place_all_acts_on_cpu(self, dataset) -> Tuple: """ @@ -139,15 +140,17 @@ class ModuleData: Collect input and output data to and from module """ - def __init__(self, model: ModelProto, node_name: str, providers: List): + def __init__(self, model: ModelProto, node_name: str, providers: List, user_onnx_libs: List[str] = None): """ :param session: ONNX session :param node: Module reference :param providers: CPU/GPU execution providers + :param user_onnx_libs: List of paths to all compiled ONNX custom ops libraries """ self._model = model self._module_name = node_name self._providers = providers + self._user_onnx_libs = user_onnx_libs def collect_inp_out_data(self, model_input: Dict[str, List[np.ndarray]], collect_input: bool, collect_output: bool) -> Union[Tuple[None, List], Tuple[List, None]]: @@ -161,7 +164,7 @@ def collect_inp_out_data(self, model_input: Dict[str, List[np.ndarray]], """ handle = add_hook_to_get_activation(self._model.model, self._module_name) - sess = QuantizationSimModel.build_session(self._model.model, self._providers) + sess = QuantizationSimModel.build_session(self._model.model, self._providers, self._user_onnx_libs) outputs = sess.run([self._module_name], model_input) remove_activation_hooks(self._model.model, handle) diff --git a/TrainingExtensions/onnx/src/python/aimet_onnx/adaround/adaround_optimizer.py b/TrainingExtensions/onnx/src/python/aimet_onnx/adaround/adaround_optimizer.py index 601c73bf922..5f63c8fb464 100644 --- a/TrainingExtensions/onnx/src/python/aimet_onnx/adaround/adaround_optimizer.py +++ b/TrainingExtensions/onnx/src/python/aimet_onnx/adaround/adaround_optimizer.py @@ -37,7 +37,7 @@ """ Adaround optimizer """ -from typing import Union, Tuple, Dict +from typing import Union, Tuple, Dict, List import numpy as np import onnx from onnx import numpy_helper @@ -80,7 +80,7 @@ def adaround_module(cls, module: ModuleInfo, quantized_input_name: str, orig_model: ModelProto, quant_model: QuantizationSimModel, act_func: Union[torch.nn.Module, None], cached_dataset: Dataset, opt_params: AdaroundHyperParameters, param_to_adaround_tensor_quantizer: Dict, - use_cuda: bool, device: int = 0): + use_cuda: bool, device: int = 0, user_onnx_libs: List[str] = None): """ Adaround module @@ -95,12 +95,13 @@ def adaround_module(cls, module: ModuleInfo, quantized_input_name: str, :param param_to_adaround_tensor_quantizer: Param name to adaround tensor quantizer dictionary :param use_cuda: If we should use cuda :param device: CUDA device ID + :param user_onnx_libs: List of paths to all compiled ONNX custom ops libraries """ # pylint: disable=too-many-arguments # Optimize weight rounding cls._optimize_rounding(module, quantized_input_name, orig_model, quant_model, act_func, cached_dataset, - opt_params, param_to_adaround_tensor_quantizer, use_cuda, device) + opt_params, param_to_adaround_tensor_quantizer, use_cuda, device, user_onnx_libs) # After optimization, set the optimized layer's rounding mode to "Hard rounding" param_to_adaround_tensor_quantizer[module.params['weight'].name].use_soft_rounding = False @@ -111,7 +112,7 @@ def _optimize_rounding(cls, module: ModuleInfo, quantized_input_name, orig_model: ModelProto, quant_model: QuantizationSimModel, act_func: Union[None, str], cached_dataset: Dataset, opt_params: AdaroundHyperParameters, param_to_adaround_tensor_quantizer: Dict, - use_cuda: bool, device: int = 0): + use_cuda: bool, device: int = 0, user_onnx_libs: List[str] = None): """ Optimizes the weight rounding of quantized wrapper module :param module: Original module @@ -122,6 +123,7 @@ def _optimize_rounding(cls, module: ModuleInfo, quantized_input_name, :param cached_dataset: Cached dataset :param opt_params: Optimization parameters :param param_to_adaround_tensor_quantizer: Param name to adaround tensor quantizer dictionary + :param user_onnx_libs: List of paths to all compiled ONNX custom ops libraries """ # pylint: disable=too-many-locals, too-many-arguments adaround_quantizer = param_to_adaround_tensor_quantizer[module.params['weight'].name] @@ -144,7 +146,7 @@ def _optimize_rounding(cls, module: ModuleInfo, quantized_input_name, # Check if we can cache intermediate activation data. model_inputs = cached_dataset[0] act_sampler = ActivationSampler(module.outputs[0], quantized_input_name, orig_model, quant_model, - use_cuda, device) + use_cuda, device, user_onnx_libs) inp_data, out_data = act_sampler.sample_acts(create_input_dict(orig_model.model, model_inputs)) inp_data_torch, out_data_torch = torch.from_numpy(inp_data[0]), torch.from_numpy(out_data[0]) use_cache_acts_data = TorchAdaroundOptimizer._can_cache_acts_data(len(cached_dataset), inp_data_torch.shape, @@ -159,11 +161,11 @@ def _optimize_rounding(cls, module: ModuleInfo, quantized_input_name, if use_cache_acts_data and AdaroundOptimizer.enable_caching_acts_data(): logger.debug("Caching intermediate activations data for optimization.") all_inp_data, all_orig_out_data = act_sampler.sample_and_place_all_acts_on_cpu(cached_dataset) - all_inp_data, all_out_data = torch.from_numpy(all_inp_data[0]), \ + all_inp_data, all_orig_out_data = torch.from_numpy(all_inp_data[0]), \ torch.from_numpy(all_orig_out_data[0]) # Try to put all cached activations data on GPU for faster optimization if possible. if use_cuda: - all_inp_data, all_orig_out_data = TorchAdaroundOptimizer._place_cached_acts_data(all_inp_data, all_out_data, + all_inp_data, all_orig_out_data = TorchAdaroundOptimizer._place_cached_acts_data(all_inp_data, all_orig_out_data, torch_device) for iteration in range(opt_params.num_iterations): diff --git a/TrainingExtensions/onnx/src/python/aimet_onnx/adaround/adaround_weight.py b/TrainingExtensions/onnx/src/python/aimet_onnx/adaround/adaround_weight.py index 32f1c6c4b3f..e3fe476762a 100644 --- a/TrainingExtensions/onnx/src/python/aimet_onnx/adaround/adaround_weight.py +++ b/TrainingExtensions/onnx/src/python/aimet_onnx/adaround/adaround_weight.py @@ -116,7 +116,8 @@ def apply_adaround(cls, model: onnx_pb.ModelProto, params: AdaroundParameters, param_bw_override_list: List[Tuple[str, int]] = None, ignore_quant_ops_list: List[str] = None, default_quant_scheme: QuantScheme = QuantScheme.post_training_tf_enhanced, - default_config_file: str = None, use_cuda: bool = True, device: int = 0) -> onnx_pb.ModelProto: + default_config_file: str = None, use_cuda: bool = True, device: int = 0, + user_onnx_libs: List[str] = None) -> onnx_pb.ModelProto: """ Returns model with optimized weight rounding of every module (Conv and Linear) and also saves the corresponding quantization encodings to a separate JSON-formatted file that can then be imported by @@ -136,6 +137,7 @@ def apply_adaround(cls, model: onnx_pb.ModelProto, params: AdaroundParameters, :param default_config_file: Default configuration file for model quantizers :param use_cuda: If we should use cuda :param device: CUDA device ID + :param user_onnx_libs: List of paths to all compiled ONNX custom ops libraries :return: Model with Adarounded weights and saves corresponding parameter encodings JSON file at provided path """ # pylint: disable=too-many-arguments @@ -144,7 +146,8 @@ def apply_adaround(cls, model: onnx_pb.ModelProto, params: AdaroundParameters, model = ONNXModel(model) quant_sim = QuantizationSimModel(copy.deepcopy(model), quant_scheme=default_quant_scheme, default_param_bw=default_param_bw, - config_file=default_config_file) + config_file=default_config_file, + user_onnx_libs=user_onnx_libs) # For the params in the param_bw_override_list, override the default parameter bitwidths in the QuantSim if param_bw_override_list: @@ -156,11 +159,12 @@ def apply_adaround(cls, model: onnx_pb.ModelProto, params: AdaroundParameters, # Compute only param encodings cls._compute_param_encodings(quant_sim, params) - return cls._apply_adaround(quant_sim, model, params, path, filename_prefix, use_cuda, device) + return cls._apply_adaround(quant_sim, model, params, path, filename_prefix, use_cuda, device, user_onnx_libs) @classmethod def _apply_adaround(cls, quant_sim: QuantizationSimModel, model: onnx_pb.ModelProto, params: AdaroundParameters, - path: str, filename_prefix: str, use_cuda: bool = True, device: int = 0) -> onnx_pb.ModelProto: + path: str, filename_prefix: str, use_cuda: bool = True, device: int = 0, + user_onnx_libs: List[str] = None) -> onnx_pb.ModelProto: """ Returns model with optimized weight rounding of every module (Conv and Linear) and also saves the corresponding quantization encodings to a separate JSON-formatted file that can then be imported by @@ -174,6 +178,7 @@ def _apply_adaround(cls, quant_sim: QuantizationSimModel, model: onnx_pb.ModelPr :param filename_prefix: Prefix to use for filename of the encodings file :param use_cuda: If we should use cuda :param device: CUDA device ID + :param user_onnx_libs: List of paths to all compiled ONNX custom ops libraries :return: Model with Adarounded weights and saves corresponding parameter encodings JSON file at provided path """ @@ -184,7 +189,7 @@ def _apply_adaround(cls, quant_sim: QuantizationSimModel, model: onnx_pb.ModelPr # Get the module - activation function pair using ConnectedGraph module_act_func_pair = get_module_act_func_pair(model) - cls._adaround_model(model, quant_sim, module_act_func_pair, params, use_cuda, device) + cls._adaround_model(model, quant_sim, module_act_func_pair, params, use_cuda, device, user_onnx_libs) # Export quantization encodings to JSON-formatted file cls._export_encodings_to_json(path, filename_prefix, quant_sim) @@ -195,7 +200,7 @@ def _apply_adaround(cls, quant_sim: QuantizationSimModel, model: onnx_pb.ModelPr @classmethod def _adaround_model(cls, model: onnx_pb.ModelProto, quant_sim: QuantizationSimModel, module_act_func_pair: Dict, - params: AdaroundParameters, use_cuda: bool = True, device: int = 0): + params: AdaroundParameters, use_cuda: bool = True, device: int = 0, user_onnx_libs: List[str] = None): """ Optimize weight rounding of every module (AdaroundSupportedModules) of model in sequential manner based on occurrence @@ -207,6 +212,7 @@ def _adaround_model(cls, model: onnx_pb.ModelProto, quant_sim: QuantizationSimMo :param params: Adaround parameters :param use_cuda: If we should use cuda :param device: CUDA device ID + :param user_onnx_libs: List of paths to all compiled ONNX custom ops libraries """ # pylint: disable=too-many-locals, protected-access @@ -246,7 +252,7 @@ def _adaround_model(cls, model: onnx_pb.ModelProto, quant_sim: QuantizationSimMo AdaroundOptimizer.adaround_module(model_data.module_to_info[name], quantized_input_name, model, quant_sim.model, act_func, cached_dataset, opt_params, param_to_tensor_quantizer_dict, - use_cuda, device) + use_cuda, device, user_onnx_libs) finally: if os.path.exists(WORKING_DIR): diff --git a/TrainingExtensions/onnx/test/python/test_quantsim.py b/TrainingExtensions/onnx/test/python/test_quantsim.py index 759eaa7e4d4..c789ae58296 100644 --- a/TrainingExtensions/onnx/test/python/test_quantsim.py +++ b/TrainingExtensions/onnx/test/python/test_quantsim.py @@ -36,6 +36,8 @@ # ============================================================================= import json import os +import shutil + import onnx.numpy_helper import torch import numpy as np @@ -577,14 +579,14 @@ def test_multiple_output_quantsim(self): default_param_bw=8) sim.session.run(None, {'input': sample_input}) - def test_model_with_custom_ops(self): custom_ops_path = os.path.dirname(libquant_info.__file__) custom_ops_path = os.path.join(custom_ops_path, "customops") onnx_library = os.path.join(custom_ops_path, "libonnx_custom_add.so") - def callback(session, args): - pass + def dummy_callback(session, args): + calib_data = {'input': np.random.rand(1, 3, 64, 64).astype(np.float32)} + _ = session.run(None, calib_data) model = custom_add_model() sim = QuantizationSimModel(model=model, @@ -594,8 +596,10 @@ def callback(session, args): user_onnx_libs=[onnx_library]) sim.save_model_graph("./quantized_custom_model") - def dummy_callback(session, args): - pass - sim.compute_encodings(dummy_callback, None) + + os.makedirs('./tmp', exist_ok=True) sim.export('./tmp/', 'custom_op_model') + + if os.path.exists('./tmp'): + shutil.rmtree('./tmp') From c9b954f363720c3c754b52f7edd9f77a6a36dcce Mon Sep 17 00:00:00 2001 From: Raj Gite Date: Mon, 12 Feb 2024 16:57:24 +0530 Subject: [PATCH 2/2] Unit test Signed-off-by: Raj Gite --- .../onnx/test/python/test_adaround_weight.py | 95 +++++++++++++------ 1 file changed, 67 insertions(+), 28 deletions(-) diff --git a/TrainingExtensions/onnx/test/python/test_adaround_weight.py b/TrainingExtensions/onnx/test/python/test_adaround_weight.py index 414c7cc5c56..f159f3a96f4 100644 --- a/TrainingExtensions/onnx/test/python/test_adaround_weight.py +++ b/TrainingExtensions/onnx/test/python/test_adaround_weight.py @@ -36,6 +36,7 @@ # ============================================================================= """ Unit tests for Adaround Weights """ +import os import json from packaging import version import numpy as np @@ -43,6 +44,8 @@ from onnxruntime import SessionOptions, GraphOptimizationLevel, InferenceSession import pytest +from aimet_common import libquant_info + from aimet_onnx.adaround.adaround_weight import Adaround, AdaroundParameters import models.models_for_tests as test_models @@ -53,64 +56,100 @@ class TestAdaround: @pytest.mark.skipif(not torch.cuda.is_available(), reason="This unit-test is meant to be run on GPU") def test_apply_adaround(self): + np.random.seed(0) + torch.manual_seed(0) + model = test_models.single_residual_model() + data_loader = dataloader(input_shape=(1, 3, 32, 32)) + dummy_input = {'input': np.random.rand(1, 3, 32, 32).astype(np.float32)} + sess = build_session(model, None) + out_before_ada = sess.run(None, dummy_input) + def callback(session, args): + in_tensor = {'input': np.random.rand(1, 3, 32, 32).astype(np.float32)} + session.run(None, in_tensor) + + params = AdaroundParameters(data_loader=data_loader, num_batches=1, default_num_iterations=5, forward_fn=callback, + forward_pass_callback_args=None) + ada_rounded_model = Adaround.apply_adaround(model, params, './', 'dummy') + sess = build_session(ada_rounded_model, None) + out_after_ada = sess.run(None, dummy_input) + assert not np.array_equal(out_before_ada[0], out_after_ada[0]) + + with open('./dummy.encodings') as json_file: + encoding_data = json.load(json_file) + + param_keys = list(encoding_data.keys()) if version.parse(torch.__version__) >= version.parse("1.13"): - np.random.seed(0) - torch.manual_seed(0) - model = test_models.single_residual_model() - data_loader = dataloader() - dummy_input = {'input': np.random.rand(1, 3, 32, 32).astype(np.float32)} - sess = build_session(model) - out_before_ada = sess.run(None, dummy_input) - def callback(session, args): - in_tensor = {'input': np.random.rand(1, 3, 32, 32).astype(np.float32)} - session.run(None, in_tensor) - - params = AdaroundParameters(data_loader=data_loader, num_batches=1, default_num_iterations=5, forward_fn=callback, - forward_pass_callback_args=None) - ada_rounded_model = Adaround.apply_adaround(model, params, './', 'dummy') - sess = build_session(ada_rounded_model) - out_after_ada = sess.run(None, dummy_input) - assert not np.array_equal(out_before_ada[0], out_after_ada[0]) - - with open('./dummy.encodings') as json_file: - encoding_data = json.load(json_file) - - param_keys = list(encoding_data.keys()) assert 'onnx::Conv_43' in param_keys + @pytest.mark.skipif(not torch.cuda.is_available(), reason="This unit-test is meant to be run on GPU") + def test_apply_adaround_for_custom_op(self): + custom_ops_path = os.path.dirname(libquant_info.__file__) + custom_ops_path = os.path.join(custom_ops_path, "customops") + onnx_library = os.path.join(custom_ops_path, "libonnx_custom_add.so") + + np.random.seed(0) + torch.manual_seed(0) + model = test_models.custom_add_model() + data_loader = dataloader(input_shape=(1, 3, 64, 64)) + dummy_input = {'input': np.random.rand(1, 3, 64, 64).astype(np.float32)} + sess = build_session(model, [onnx_library]) + out_before_ada = sess.run(None, dummy_input) + def callback(session, args): + in_tensor = {'input': np.random.rand(1, 3, 64, 64).astype(np.float32)} + session.run(None, in_tensor) -def dataloader(): + params = AdaroundParameters(data_loader=data_loader, num_batches=1, default_num_iterations=5, forward_fn=callback, + forward_pass_callback_args=None) + ada_rounded_model = Adaround.apply_adaround(model, params, './', 'dummy', user_onnx_libs=[onnx_library]) + sess = build_session(ada_rounded_model, [onnx_library]) + out_after_ada = sess.run(None, dummy_input) + assert not np.array_equal(out_before_ada[0], out_after_ada[0]) + + with open('./dummy.encodings') as json_file: + encoding_data = json.load(json_file) + + param_keys = list(encoding_data.keys()) + if version.parse(torch.__version__) >= version.parse("1.13"): + assert 'conv.weight' in param_keys + + +def dataloader(input_shape: tuple): class DataLoader: """ Example of a Dataloader which can be used for running AMPv2 """ - def __init__(self, batch_size: int): + def __init__(self, batch_size: int, input_shape: tuple): """ :param batch_size: batch size for data loader """ self.batch_size = batch_size + self.input_shape = input_shape def __iter__(self): """Iterates over dataset""" - dummy_input = np.random.rand(1, 3, 32, 32).astype(np.float32) + dummy_input = np.random.rand(*self.input_shape).astype(np.float32) yield dummy_input def __len__(self): return 4 - dummy_dataloader = DataLoader(batch_size=2) + dummy_dataloader = DataLoader(batch_size=2, input_shape=input_shape) return dummy_dataloader -def build_session(model): + +def build_session(model, user_onnx_libs): """ Build and return onnxruntime inference session :param providers: providers to execute onnxruntime """ sess_options = SessionOptions() sess_options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL + if user_onnx_libs is not None: + for lib in user_onnx_libs: + sess_options.register_custom_ops_library(lib) session = InferenceSession( path_or_bytes=model.model.SerializeToString(), sess_options=sess_options, providers=['CPUExecutionProvider'], ) - return session \ No newline at end of file + return session