diff --git a/NightlyTests/onnx/test_rnn_quantsim.py b/NightlyTests/onnx/test_rnn_quantsim.py
new file mode 100644
index 00000000000..31ac1755566
--- /dev/null
+++ b/NightlyTests/onnx/test_rnn_quantsim.py
@@ -0,0 +1,130 @@
+# -*- mode: python -*-
+# =============================================================================
+# @@-COPYRIGHT-START-@@
+#
+# Copyright (c) 2024, Qualcomm Innovation Center, Inc. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice,
+#    this list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its contributors
+#    may be used to endorse or promote products derived from this software
+#    without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# @@-COPYRIGHT-END-@@
+# =============================================================================
+import os
+
+import numpy as np
+import pytest
+import torch
+from onnx import load_model
+from torchaudio import models
+
+from aimet_onnx.utils import make_dummy_input
+from aimet_common.defs import QuantScheme, QuantizationDataType
+from aimet_onnx.quantsim import QuantizationSimModel
+from aimet_common.quantsim_config.utils import get_path_for_per_channel_config
+from torch_utils import get_librispeech_data_loaders, train_librispeech
+
+WORKING_DIR = '/tmp/quantsim'
+
+batch_size = 64
+n_feature = 128
+n_class = 29
+
+
+def model_eval_onnx(session, val_loader, max_batches):
+    """
+    Evaluates an ONNX Runtime session on validation data
+
+    :param session: ONNX Runtime inference session to evaluate
+    :param val_loader: dataloader for validation data
+    :param max_batches: maximum number of batches to evaluate
+    :return: CTC loss on validation data
+    """
+    test_loss = 0
+    for (i, batch) in enumerate(val_loader):
+        spectrograms, labels, input_lengths, label_lengths = batch
+        x = spectrograms.numpy()
+
+        in_tensor = {'input': x}
+        out = session.run(None, in_tensor)[0]
+
+        out = torch.Tensor(out).transpose(0, 1)
+        criterion = torch.nn.CTCLoss(blank=28)
+        loss = criterion(out, labels, input_lengths, label_lengths)
+        test_loss += loss.item() / len(val_loader)
+
+        if i+1 >= max_batches:
+            break
+
+    print(f'Test loss: {test_loss}')
+    return test_loss
+
+
+class TestQuantizeAcceptance:
+    """ Acceptance test for AIMET ONNX """
+    @pytest.mark.parametrize("config_file", [None, get_path_for_per_channel_config()])
+    @pytest.mark.cuda
+    def test_quantized_accuracy(self, config_file):
+        if not os.path.exists(WORKING_DIR):
+            os.makedirs(WORKING_DIR)
+        np.random.seed(0)
+        torch.manual_seed(0)
+        model = models.DeepSpeech(n_feature=n_feature, n_class=n_class)
+        if torch.cuda.is_available():
+            device = torch.device('cuda:0')
+            model.to(device)
+
+        train_librispeech(model, 1, max_batches=30)
+
+        train_loader, val_loader = get_librispeech_data_loaders(batch_size=batch_size, drop_last=False)
+
+        torch.onnx.export(model, torch.rand(1, 1, 1, 128).cuda(), os.path.join(WORKING_DIR, 'deepspeech.onnx'),
+                          training=torch.onnx.TrainingMode.PRESERVE,
+                          input_names=['input'], output_names=['output'],
+                          dynamic_axes={
+                              'input': {0: 'batch_size', 2: 'time'},
+                              'output': {0: 'batch_size', 1: 'time'},
+                          })
+
+        onnx_model = load_model(os.path.join(WORKING_DIR, 'deepspeech.onnx'))
+        dummy_input = make_dummy_input(onnx_model)
+        sim = QuantizationSimModel(onnx_model, dummy_input, quant_scheme=QuantScheme.post_training_tf,
+                                   default_param_bw=8, default_activation_bw=8, use_cuda=True,
+                                   config_file=config_file)
+
+        def onnx_callback(session, iters):
+            for i, batch in enumerate(train_loader):
+                x = batch[0].detach().cpu().numpy()
+                in_tensor = {'input': x}
+                session.run(None, in_tensor)
+                print(i, '/', iters)
+                if i+1 >= iters:
+                    break
+
+        sim.compute_encodings(onnx_callback, 1)
+
+        onnx_qs_test_loss = model_eval_onnx(sim.session, val_loader, max_batches=1)
+
+        assert onnx_qs_test_loss < 0.1
diff --git a/NightlyTests/onnx/torch_utils.py b/NightlyTests/onnx/torch_utils.py
index 5c000f2396a..e8eb8d2fb2c 100644
--- a/NightlyTests/onnx/torch_utils.py
+++ b/NightlyTests/onnx/torch_utils.py
@@ -2,7 +2,7 @@
 # =============================================================================
 # @@-COPYRIGHT-START-@@
 #
-# Copyright (c) 2023, Qualcomm Innovation Center, Inc. All rights reserved.
+# Copyright (c) 2023-2024, Qualcomm Innovation Center, Inc. All rights reserved.
 #
 # Redistribution and use in source and binary forms, with or without
 # modification, are permitted provided that the following conditions are met:
@@ -39,10 +39,93 @@
 import torch.nn as nn
 import torch.optim as optim
 import torchvision
+import torchaudio
 from torch.optim import lr_scheduler
 from torch.utils.data import DataLoader
 
 
+class TextTransform:
+    """Maps characters to integers and vice versa"""
+    def __init__(self):
+        char_map_str = """
+        ' 0
+        <SPACE> 1
+        a 2
+        b 3
+        c 4
+        d 5
+        e 6
+        f 7
+        g 8
+        h 9
+        i 10
+        j 11
+        k 12
+        l 13
+        m 14
+        n 15
+        o 16
+        p 17
+        q 18
+        r 19
+        s 20
+        t 21
+        u 22
+        v 23
+        w 24
+        x 25
+        y 26
+        z 27
+        """
+        self.char_map = {}
+        for line in char_map_str.strip().split('\n'):
+            ch, index = line.split()
+            self.char_map[ch] = int(index)
+
+    def text_to_int(self, text):
+        """ Use a character map and convert text to an integer array """
+        int_sequence = []
+        for c in text:
+            if c == ' ':
+                ch = self.char_map['<SPACE>']
+            else:
+                ch = self.char_map[c]
+            int_sequence.append(ch)
+        return int_sequence
+
+
+def librispeech_data_processing(data, data_type="train"):
+    assert data_type in ('train', 'valid'), "data_type needs to be train/valid"
+    spectrograms = []
+    labels = []
+    input_lengths = []
+    label_lengths = []
+    train_audio_transforms = nn.Sequential(
+        torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128),
+        torchaudio.transforms.FrequencyMasking(freq_mask_param=15),
+        torchaudio.transforms.TimeMasking(time_mask_param=35)
+    )
+
+    valid_audio_transforms = torchaudio.transforms.MelSpectrogram()
+
+    text_transform = TextTransform()
+    for (waveform, _, utterance, _, _, _) in data:
+        if data_type == 'train':
+            spec = train_audio_transforms(waveform).squeeze(0).transpose(0, 1)
+        else:
+            spec = valid_audio_transforms(waveform).squeeze(0).transpose(0, 1)
+        spectrograms.append(spec)
+        label = torch.Tensor(text_transform.text_to_int(utterance.lower()))
+        labels.append(label)
+        input_lengths.append(spec.shape[0]//2)
+        label_lengths.append(len(label))
+
+    spectrograms = nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).unsqueeze(1)
+    labels = nn.utils.rnn.pad_sequence(labels, batch_first=True)
+
+    return spectrograms, labels, input_lengths, label_lengths
+
+
 def get_cifar10_data_loaders(batch_size=64, num_workers=4, drop_last=True):
     train_set = torchvision.datasets.CIFAR10("./data/CIFAR10", train=True, download=True,
                                              transform=torchvision.transforms.ToTensor())
@@ -54,6 +137,17 @@ def get_cifar10_data_loaders(batch_size=64, num_workers=4, drop_last=True):
     return train_loader, val_loader
 
 
+def get_librispeech_data_loaders(batch_size=64, num_workers=4, drop_last=True):
+    train_set = torchaudio.datasets.LIBRISPEECH("./data/LIBRISPEECH", url='train-clean-100', download=True)
+    val_set = torchaudio.datasets.LIBRISPEECH("./data/LIBRISPEECH", url='dev-clean', download=True)
+    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers,
+                              drop_last=drop_last, collate_fn=lambda x: librispeech_data_processing(x, 'train'))
+    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers,
+                            drop_last=drop_last, collate_fn=lambda x: librispeech_data_processing(x, 'valid'))
+    return train_loader, val_loader
+
+
 def model_train(model: torch.nn.Module, train_loader: DataLoader, epochs: int, optimizer: optim.Optimizer, scheduler):
     """
     Trains the given torch model for the specified number of epochs
@@ -99,6 +193,41 @@ def train_cifar10(model: torch.nn.Module, epochs):
     model_train(model, train_loader, epochs, optimizer,
                 scheduler)
 
 
+def train_librispeech(model: torch.nn.Module, epochs, max_batches):
+    """
+    Trains a PyTorch model on LIBRISPEECH for the specified number of epochs
+
+    :param model: PyTorch model to train
+    :param epochs: Number of epochs to train
+    :param max_batches: Maximum number of batches to train on per epoch
+    """
+    use_cuda = next(model.parameters()).is_cuda
+    model.train()
+    if use_cuda:
+        device = torch.device('cuda:0')
+    else:
+        device = torch.device('cpu')
+    train_loader, _ = get_librispeech_data_loaders(batch_size=16)
+    lr = 0.01
+    steps = int(len(train_loader))
+    optimizer = optim.Adam(model.parameters(), lr=lr)
+    criterion = nn.CTCLoss(blank=28).to(device)
+    scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=lr, steps_per_epoch=steps, epochs=epochs)
+
+    for epoch in range(epochs):
+        for (i, batch) in enumerate(train_loader):
+            spectrograms, labels, input_lengths, label_lengths = batch
+            spectrograms, labels = spectrograms.to(device), labels.to(device)
+            optimizer.zero_grad()
+            output = model(spectrograms)
+            output = output.transpose(0, 1)
+            loss = criterion(output, labels, input_lengths, label_lengths)
+            loss.backward()
+            optimizer.step()
+            scheduler.step()
+            if i+1 >= max_batches:
+                break
+
+
 def model_eval_torch(model: torch.nn.Module, val_loader: DataLoader):
     """
     Measures the accuracy of a PyTorch model over a given validation dataset
diff --git a/TrainingExtensions/onnx/src/python/aimet_onnx/meta/connectedgraph.py b/TrainingExtensions/onnx/src/python/aimet_onnx/meta/connectedgraph.py
index b4a9e9377a7..d90768159c9 100644
--- a/TrainingExtensions/onnx/src/python/aimet_onnx/meta/connectedgraph.py
+++ b/TrainingExtensions/onnx/src/python/aimet_onnx/meta/connectedgraph.py
@@ -3,7 +3,7 @@
 #
 # @@-COPYRIGHT-START-@@
 #
-# Copyright (c) 2022, Qualcomm Innovation Center, Inc. All rights reserved.
+# Copyright (c) 2022-2024, Qualcomm Innovation Center, Inc. All rights reserved.
 #
 # Redistribution and use in source and binary forms, with or without
 # modification, are permitted provided that the following conditions are met:
@@ -64,11 +64,13 @@
 logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.ConnectedGraph)
 
+INPUT_INDEX = 0
 WEIGHT_INDEX = 1
 BIAS_INDEX = 2
+RECURRENT_WEIGHT_INDEX = 2
 RUNNING_MEAN_INDEX = 3
 RUNNING_VAR_INDEX = 4
-OPS_WITH_PARAMS = ["Conv", "Gemm", "ConvTranspose", "BatchNormalization", "MatMul"]
+OPS_WITH_PARAMS = ["Conv", "Gemm", "ConvTranspose", "BatchNormalization", "MatMul", "RNN", "LSTM", "GRU"]
 
 CONSTANT_TYPE = ['Constant', 'ConstantOfShape']
@@ -254,6 +256,8 @@ def check_if_param(node: NodeProto, index: int) -> bool:
         return True
     if node.op_type == 'BatchNormalization' and index in [WEIGHT_INDEX, BIAS_INDEX, RUNNING_VAR_INDEX, RUNNING_MEAN_INDEX]:
         return True
+    # For RNN, LSTM and GRU ops, every input other than the data input (X) is a parameter
+    if node.op_type in ['RNN', 'LSTM', 'GRU'] and index != INPUT_INDEX:
+        return True
     return False
@@ -560,6 +564,19 @@ def create_matmul_params(my_op: Op):
             if weight_tensor:
                 create_and_connect_product(weight_tensor.name, weight_tensor.dims, my_op, weight_tensor, 'weight')
 
+        def create_recurrent_type_params(my_op: Op):
+            """
+            Create products for RNN, LSTM and GRU layers
+
+            :param my_op: Connected Graph Op
+            """
+            op = my_op.get_module()
+            # Input weight tensor W (applied to the layer input)
+            weight_tensor = ParamUtils.get_param(self.model, op, WEIGHT_INDEX)
+            create_and_connect_product(weight_tensor.name, weight_tensor.dims, my_op, weight_tensor, 'weight_x')
+
+            # Recurrent weight tensor R (applied to the hidden state)
+            recurrent_weight_tensor = ParamUtils.get_param(self.model, op, RECURRENT_WEIGHT_INDEX)
+            create_and_connect_product(recurrent_weight_tensor.name, recurrent_weight_tensor.dims, my_op, recurrent_weight_tensor, 'weight_r')
+
         def create_batchnorm_params(my_op: Op):
             """ Create products for fusedbatchnorm """
             op = my_op.get_module()
@@ -590,6 +607,9 @@ def handle_default(my_op: Op):
             "Conv": create_conv2d_dense_type_params,
             "Gemm": create_conv2d_dense_type_params,
             "ConvTranspose": create_conv2d_dense_type_params,
+            "RNN": create_recurrent_type_params,
+            "LSTM": create_recurrent_type_params,
+            "GRU": create_recurrent_type_params,
             "BatchNormalization": create_batchnorm_params,
             "InstanceNormalization": create_batchnorm_params,
             "MatMul": create_matmul_params
diff --git a/TrainingExtensions/onnx/src/python/aimet_onnx/utils.py b/TrainingExtensions/onnx/src/python/aimet_onnx/utils.py
index a67c8ecd777..d9ecf13120d 100644
--- a/TrainingExtensions/onnx/src/python/aimet_onnx/utils.py
+++ b/TrainingExtensions/onnx/src/python/aimet_onnx/utils.py
@@ -2,7 +2,7 @@
 # =============================================================================
 # @@-COPYRIGHT-START-@@
 #
-# Copyright (c) 2022, Qualcomm Innovation Center, Inc. All rights reserved.
+# Copyright (c) 2022-2024, Qualcomm Innovation Center, Inc. All rights reserved.
 #
 # Redistribution and use in source and binary forms, with or without
 # modification, are permitted provided that the following conditions are met:
@@ -57,7 +57,7 @@
 
 OP_TYPES_WITH_PARAMS = ['Conv', 'Gemm', 'ConvTranspose', 'BatchNormalization', 'MatMul', 'Transpose',
-                        'InstanceNormalization']
+                        'InstanceNormalization', 'RNN', 'LSTM', 'GRU']
 
 
 def remove_nodes_with_type(node_type: str, onnx_graph: onnx.GraphProto):
@@ -203,6 +203,8 @@ def make_dummy_input(model: ModelProto, dynamic_size: int = 1) -> Dict[str, np.ndarray]:
                 shape.append(dim.dim_value)
         if shape:
             input_dict[name] = np.random.randn(*shape).astype(mapping.TENSOR_TYPE_TO_NP_TYPE[dtype])
+        else:
+            # Rank-0 (scalar) inputs: np.random.randn() returns a plain float here, so wrap it in an ndarray first
+            input_dict[name] = np.array(np.random.randn(*shape)).astype(mapping.TENSOR_TYPE_TO_NP_TYPE[dtype])
 
     return input_dict
@@ -297,7 +299,7 @@ def get_graph_intermediate_activations(graph: GraphProto) -> List[str]:
     activation_names = []
     for node in graph.node:
         for name in node.input:
-            if name not in activation_names and name not in param_names:
+            # Empty input names stand in for omitted optional inputs and are skipped
+            if name not in activation_names and name not in param_names and name:
                 activation_names.append(name)
     return activation_names
diff --git a/TrainingExtensions/onnx/test/python/models/models_for_tests.py b/TrainingExtensions/onnx/test/python/models/models_for_tests.py
index dbd201e254b..bf133fdddcb 100644
--- a/TrainingExtensions/onnx/test/python/models/models_for_tests.py
+++ b/TrainingExtensions/onnx/test/python/models/models_for_tests.py
@@ -2,7 +2,7 @@
 # =============================================================================
 # @@-COPYRIGHT-START-@@
 #
-# Copyright (c) 2022-2023, Qualcomm Innovation Center, Inc. All rights reserved.
+# Copyright (c) 2022-2024, Qualcomm Innovation Center, Inc. All rights reserved.
 #
 # Redistribution and use in source and binary forms, with or without
 # modification, are permitted provided that the following conditions are met:
@@ -1018,6 +1018,46 @@ def build_dummy_model():
     return model
 
 
+def build_lstm_gru_dummy_model():
+    """ Build a dummy ONNX model with an LSTM -> Squeeze -> GRU pattern (hidden size 16) """
+    op = OperatorSetIdProto()
+    op.version = 13
+
+    input_info = helper.make_tensor_value_info(name='input', elem_type=TensorProto.FLOAT,
+                                               shape=[1, 8, 64])
+    output_info = helper.make_tensor_value_info(name='output', elem_type=TensorProto.FLOAT,
+                                                shape=[1, 1, 8, 16])
+
+    lstm_node = helper.make_node('LSTM',
+                                 ['input', 'lstm_w', 'lstm_r_w'],
+                                 ['2'],
+                                 'lstm',
+                                 hidden_size=16)
+    squeeze_node = helper.make_node('Squeeze',
+                                    ['2', 'axis'],
+                                    ['3'],
+                                    'squeeze')
+    gru_node = helper.make_node('GRU',
+                                ['3', 'gru_w', 'gru_r_w'],
+                                ['output'],
+                                'gru',
+                                hidden_size=16)
+
+    # LSTM weights W: [num_directions, 4*hidden_size, input_size]; R: [num_directions, 4*hidden_size, hidden_size]
+    lstm_w_init = numpy_helper.from_array(np.random.rand(1, 64, 64).astype(np.float32), 'lstm_w')
+    lstm_r_w_init = numpy_helper.from_array(np.random.rand(1, 64, 16).astype(np.float32), 'lstm_r_w')
+    squeeze_axis_init = numpy_helper.from_array(np.array([1]).astype(np.int64), 'axis')
+    # GRU weights W: [num_directions, 3*hidden_size, input_size]; R: [num_directions, 3*hidden_size, hidden_size]
+    gru_w_init = numpy_helper.from_array(np.random.rand(1, 48, 16).astype(np.float32), 'gru_w')
+    gru_r_w_init = numpy_helper.from_array(np.random.rand(1, 48, 16).astype(np.float32), 'gru_r_w')
+
+    onnx_graph = helper.make_graph([lstm_node, squeeze_node, gru_node],
+                                   'dummy_graph', [input_info], [output_info],
+                                   [lstm_w_init, lstm_r_w_init, squeeze_axis_init, gru_w_init, gru_r_w_init])
+
+    model = helper.make_model(onnx_graph, opset_imports=[op])
+
+    return model
+
+
 def single_residual_model(training=torch.onnx.TrainingMode.EVAL):
     x = torch.randn(1, 3, 32, 32, requires_grad=True)
     model = SingleResidualWithAvgPool()
diff --git a/TrainingExtensions/onnx/test/python/test_quantsim.py b/TrainingExtensions/onnx/test/python/test_quantsim.py
index e5ec6bc8a9d..759eaa7e4d4 100644
--- a/TrainingExtensions/onnx/test/python/test_quantsim.py
+++ b/TrainingExtensions/onnx/test/python/test_quantsim.py
@@ -2,7 +2,7 @@
 # =============================================================================
 # @@-COPYRIGHT-START-@@
 #
-# Copyright (c) 2022-2023, Qualcomm Innovation Center, Inc. All rights reserved.
+# Copyright (c) 2022-2024, Qualcomm Innovation Center, Inc. All rights reserved.
 #
 # Redistribution and use in source and binary forms, with or without
 # modification, are permitted provided that the following conditions are met:
@@ -51,7 +51,7 @@
 from aimet_onnx.qc_quantize_op import OpMode
 from aimet_onnx.utils import make_dummy_input
 from models.models_for_tests import SingleResidual
-from models.models_for_tests import build_dummy_model, single_residual_model, BNAfterConv, multi_input_with_constant_model , multi_output_model, custom_add_model
+from models.models_for_tests import build_dummy_model, single_residual_model, BNAfterConv, multi_input_with_constant_model, multi_output_model, custom_add_model, build_lstm_gru_dummy_model
 
 
 class DummyModel(SingleResidual):
@@ -223,6 +223,43 @@ def dummy_callback(session, args):
             param_encodings_keys = list(encoding_data["param_encodings"][param][0].keys())
             assert param_encodings_keys == ['bitwidth', 'dtype', 'is_symmetric', 'max', 'min', 'offset', 'scale']
 
+    def test_lstm_gru(self):
+        """ Test for LSTM and GRU dummy model """
+        model = build_lstm_gru_dummy_model()
+        sim = QuantizationSimModel(model)
+
+        for qc_op in sim.qc_quantize_op_dict.values():
+            qc_op.enabled = True
+
+        def callback(session, args):
+            in_tensor = {'input': np.random.rand(1, 8, 64).astype(np.float32)}
+            session.run(None, in_tensor)
+
+        sim.compute_encodings(callback, None)
+
+        for qc_op in sim.get_qc_quantize_op().values():
+            assert qc_op.encodings[0].bw == 8
+            assert qc_op.quant_info.tensorQuantizerRef[0].isEncodingValid is True
+            assert qc_op.op_mode == OpMode.quantizeDequantize
+
+        sim.export('/tmp/', 'quant_sim_model')
+
+        with open('/tmp/quant_sim_model.encodings', 'rb') as json_file:
+            encoding_data = json.load(json_file)
+
+        activation_keys = list(encoding_data["activation_encodings"].keys())
+        assert activation_keys == ['2', 'input', 'output']
+        for act in activation_keys:
+            act_encodings_keys = list(encoding_data["activation_encodings"][act][0].keys())
+            assert act_encodings_keys == ['bitwidth', 'dtype', 'is_symmetric', 'max', 'min', 'offset', 'scale']
+
+        param_keys = list(encoding_data['param_encodings'].keys())
+        assert param_keys == ['gru_r_w', 'gru_w', 'lstm_r_w', 'lstm_w']
+        for param in param_keys:
+            param_encodings_keys = list(encoding_data["param_encodings"][param][0].keys())
+            assert param_encodings_keys == ['bitwidth', 'dtype', 'is_symmetric', 'max', 'min', 'offset', 'scale']
+
     def test_single_residual(self):
         if version.parse(torch.__version__) >= version.parse("1.13"):
             model = single_residual_model().model
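
For anyone trying the new recurrent-op support outside the test suite, below is a minimal sketch of the flow that test_lstm_gru exercises: build the dummy LSTM/GRU model, wrap it in QuantizationSimModel, calibrate with a random-input callback, and export the encodings. It relies only on the API already used in this patch; the output directory and file prefix are illustrative, not fixed by the change.

    import json
    import numpy as np

    from aimet_onnx.quantsim import QuantizationSimModel
    from models.models_for_tests import build_lstm_gru_dummy_model

    # Build the LSTM -> Squeeze -> GRU dummy graph added in models_for_tests.py
    model = build_lstm_gru_dummy_model()
    sim = QuantizationSimModel(model)

    def callback(session, _):
        # Calibrate with random data shaped like the model input: [seq_len=1, batch=8, input_size=64]
        session.run(None, {'input': np.random.rand(1, 8, 64).astype(np.float32)})

    sim.compute_encodings(callback, None)
    sim.export('/tmp/', 'lstm_gru_example')  # illustrative path and prefix

    # The RNN/LSTM/GRU input and recurrent weights now appear as quantized parameters
    with open('/tmp/lstm_gru_example.encodings') as f:
        print(sorted(json.load(f)['param_encodings'].keys()))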