Add RNN/LSTM/GRU support for ONNX QuantSim (#2685)
Signed-off-by: Jokay Su <quic_chenzhen@quicinc.com>
quic-chenzhen authored Feb 13, 2024
1 parent 6263b42 commit 56b36f4
Showing 6 changed files with 367 additions and 9 deletions.
130 changes: 130 additions & 0 deletions NightlyTests/onnx/test_rnn_quantsim.py
@@ -0,0 +1,130 @@
# -*- mode: python -*-
# =============================================================================
# @@-COPYRIGHT-START-@@
#
# Copyright (c) 2024, Qualcomm Innovation Center, Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
# SPDX-License-Identifier: BSD-3-Clause
#
# @@-COPYRIGHT-END-@@
# =============================================================================
import os

import numpy as np
import pytest
import torch
from onnx import load_model
from torchaudio import models

from aimet_onnx.utils import make_dummy_input
from aimet_common.defs import QuantScheme, QuantizationDataType
from aimet_onnx.quantsim import QuantizationSimModel
from aimet_common.quantsim_config.utils import get_path_for_per_channel_config
from torch_utils import get_librispeech_data_loaders, train_librispeech

WORKING_DIR = '/tmp/quantsim'

batch_size = 64
n_feature = 128
n_class = 29


def model_eval_onnx(session, val_loader, max_batches):
    """
    :param session: ONNX runtime inference session to be evaluated
    :param val_loader: dataloader for validation data
    :param max_batches: maximum number of batches to evaluate
    :return: CTC Loss on validation data
    """

    test_loss = 0
    for (i, batch) in enumerate(val_loader):
        spectrograms, labels, input_lengths, label_lengths = batch
        x = spectrograms.numpy()

        in_tensor = {'input': x}
        out = session.run(None, in_tensor)[0]

        out = torch.Tensor(out).transpose(0, 1)
        criterion = torch.nn.CTCLoss(blank=28)
        loss = criterion(out, labels, input_lengths, label_lengths)
        test_loss += loss.item() / len(val_loader)

        if i + 1 >= max_batches:
            break

    print(f'Test loss: {test_loss}')
    return test_loss


class TestQuantizeAcceptance:
    """ Acceptance test for AIMET ONNX """
    @pytest.mark.parametrize("config_file", [None, get_path_for_per_channel_config()])
    @pytest.mark.cuda
    def test_quantized_accuracy(self, config_file):
        if not os.path.exists(WORKING_DIR):
            os.makedirs(WORKING_DIR)
        np.random.seed(0)
        torch.manual_seed(0)
        model = models.DeepSpeech(n_feature=n_feature, n_class=n_class)
        if torch.cuda.is_available():
            device = torch.device('cuda:0')
            model.to(device)

        train_librispeech(model, 1, max_batches=30)

        train_loader, val_loader = get_librispeech_data_loaders(batch_size=batch_size, drop_last=False)

        torch.onnx.export(model, torch.rand(1, 1, 1, 128).cuda(), os.path.join(WORKING_DIR, 'deepspeech.onnx'),
                          training=torch.onnx.TrainingMode.PRESERVE,
                          input_names=['input'], output_names=['output'],
                          dynamic_axes={
                              'input': {0: 'batch_size', 2: 'time'},
                              'output': {0: 'batch_size', 1: 'time'},
                          })

        onnx_model = load_model(os.path.join(WORKING_DIR, 'deepspeech.onnx'))
        dummy_input = make_dummy_input(onnx_model)
        sim = QuantizationSimModel(onnx_model, dummy_input, quant_scheme=QuantScheme.post_training_tf,
                                   default_param_bw=8, default_activation_bw=8, use_cuda=True,
                                   config_file=config_file)

        def onnx_callback(session, iters):
            for i, batch in enumerate(train_loader):
                x = batch[0].detach().cpu().numpy()
                in_tensor = {'input': x}
                session.run(None, in_tensor)
                print(i, '/', iters)
                if i + 1 >= iters:
                    break

        sim.compute_encodings(onnx_callback, 1)

        onnx_qs_test_loss = model_eval_onnx(sim.session, val_loader, max_batches=1)

        assert onnx_qs_test_loss < 0.1
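
For context, a minimal sketch of what this change enables, quantizing a standalone torch.nn.LSTM exported to ONNX; the shapes, file path, and random calibration batch below are illustrative assumptions, not part of this commit:

    # Hypothetical sketch: QuantSim over a single exported LSTM layer
    import numpy as np
    import torch
    from onnx import load_model
    from aimet_common.defs import QuantScheme
    from aimet_onnx.quantsim import QuantizationSimModel

    lstm = torch.nn.LSTM(input_size=64, hidden_size=128, batch_first=True)
    torch.onnx.export(lstm, torch.randn(1, 16, 64), '/tmp/lstm.onnx',
                      input_names=['input'], output_names=['output'])

    x = np.random.randn(1, 16, 64).astype(np.float32)  # stand-in calibration data
    sim = QuantizationSimModel(load_model('/tmp/lstm.onnx'), {'input': x},
                               quant_scheme=QuantScheme.post_training_tf,
                               default_param_bw=8, default_activation_bw=8, use_cuda=False)

    # One calibration pass to compute encodings, then quantized inference
    sim.compute_encodings(lambda session, _: session.run(None, {'input': x}), None)
    quantized_out = sim.session.run(None, {'input': x})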
131 changes: 130 additions & 1 deletion NightlyTests/onnx/torch_utils.py
@@ -2,7 +2,7 @@
# =============================================================================
# @@-COPYRIGHT-START-@@
#
# Copyright (c) 2023, Qualcomm Innovation Center, Inc. All rights reserved.
# Copyright (c) 2023-2024, Qualcomm Innovation Center, Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
@@ -39,10 +39,93 @@
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchaudio
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader


class TextTransform:
    """Maps characters to integers and vice versa"""
    def __init__(self):
        char_map_str = """
        ' 0
        <SPACE> 1
        a 2
        b 3
        c 4
        d 5
        e 6
        f 7
        g 8
        h 9
        i 10
        j 11
        k 12
        l 13
        m 14
        n 15
        o 16
        p 17
        q 18
        r 19
        s 20
        t 21
        u 22
        v 23
        w 24
        x 25
        y 26
        z 27
        """
        self.char_map = {}
        for line in char_map_str.strip().split('\n'):
            ch, index = line.split()
            self.char_map[ch] = int(index)

    def text_to_int(self, text):
        """ Use a character map and convert text to an integer array """
        int_sequence = []
        for c in text:
            if c == ' ':
                ch = self.char_map['<SPACE>']
            else:
                ch = self.char_map[c]
            int_sequence.append(ch)
        return int_sequence


def librispeech_data_processing(data, data_type="train"):
    assert data_type == 'train' or data_type == 'valid', "data_type needs to be train/valid"
    spectrograms = []
    labels = []
    input_lengths = []
    label_lengths = []
    train_audio_transforms = nn.Sequential(
        torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128),
        torchaudio.transforms.FrequencyMasking(freq_mask_param=15),
        torchaudio.transforms.TimeMasking(time_mask_param=35)
    )

    valid_audio_transforms = torchaudio.transforms.MelSpectrogram()

    text_transform = TextTransform()
    for (waveform, _, utterance, _, _, _) in data:
        if data_type == 'train':
            spec = train_audio_transforms(waveform).squeeze(0).transpose(0, 1)
        elif data_type == 'valid':
            spec = valid_audio_transforms(waveform).squeeze(0).transpose(0, 1)
        spectrograms.append(spec)
        label = torch.Tensor(text_transform.text_to_int(utterance.lower()))
        labels.append(label)
        input_lengths.append(spec.shape[0] // 2)
        label_lengths.append(len(label))

    spectrograms = nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).unsqueeze(1)
    labels = nn.utils.rnn.pad_sequence(labels, batch_first=True)

    return spectrograms, labels, input_lengths, label_lengths


def get_cifar10_data_loaders(batch_size=64, num_workers=4, drop_last=True):
    train_set = torchvision.datasets.CIFAR10("./data/CIFAR10", train=True, download=True,
                                             transform=torchvision.transforms.ToTensor())
@@ -54,6 +137,17 @@ def get_cifar10_data_loaders(batch_size=64, num_workers=4, drop_last=True):
    return train_loader, val_loader


def get_librispeech_data_loaders(batch_size=64, num_workers=4, drop_last=True):
    train_set = torchaudio.datasets.LIBRISPEECH("./data/LIBRISPEECH", url='train-clean-100', download=True)
    val_set = torchaudio.datasets.LIBRISPEECH("./data/LIBRISPEECH", url='dev-clean', download=True)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers,
                              drop_last=drop_last, collate_fn=lambda x: librispeech_data_processing(x, 'train'))
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers,
                            drop_last=drop_last, collate_fn=lambda x: librispeech_data_processing(x, 'valid'))
    return train_loader, val_loader



def model_train(model: torch.nn.Module, train_loader: DataLoader, epochs: int, optimizer: optim.Optimizer, scheduler):
    """
    Trains the given torch model for the specified number of epochs
@@ -99,6 +193,41 @@ def train_cifar10(model: torch.nn.Module, epochs):
    model_train(model, train_loader, epochs, optimizer, scheduler)


def train_librispeech(model: torch.nn.Module, epochs, max_batches):
    """
    Trains a PyTorch model on LIBRISPEECH for the specified number of epochs
    :param model: PyTorch model to train
    :param epochs: Number of epochs to train
    :param max_batches: Maximum number of batches to run per epoch
    """
    use_cuda = next(model.parameters()).is_cuda
    model.train()
    if use_cuda:
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')
    train_loader, _ = get_librispeech_data_loaders(batch_size=16)
    lr = 0.01
    steps = int(len(train_loader))
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CTCLoss(blank=28).to(device)
    scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=lr, steps_per_epoch=steps, epochs=epochs)

    for epoch in range(epochs):
        for (i, batch) in enumerate(train_loader):
            spectrograms, labels, input_lengths, label_lengths = batch
            spectrograms, labels = spectrograms.to(device), labels.to(device)
            optimizer.zero_grad()
            output = model(spectrograms)
            output = output.transpose(0, 1)
            loss = criterion(output, labels, input_lengths, label_lengths)
            loss.backward()
            optimizer.step()
            scheduler.step()
            if i + 1 >= max_batches:
                break


def model_eval_torch(model: torch.nn.Module, val_loader: DataLoader):
    """
    Measures the accuracy of a PyTorch model over a given validation dataset
TrainingExtensions/onnx/src/python/aimet_onnx/meta/connectedgraph.py
@@ -3,7 +3,7 @@
#
# @@-COPYRIGHT-START-@@
#
# Copyright (c) 2022, Qualcomm Innovation Center, Inc. All rights reserved.
# Copyright (c) 2022-2024, Qualcomm Innovation Center, Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
@@ -64,11 +64,13 @@

logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.ConnectedGraph)

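# Positional indices of parameter inputs on ONNX ops (per the ONNX operator
# schemas): Conv/Gemm carry a bias at input 2, BatchNormalization carries its
# running statistics at inputs 3 and 4, and RNN/LSTM/GRU carry the input
# weight W at input 1 and the recurrent weight R at input 2.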
INPUT_INDEX = 0
WEIGHT_INDEX = 1
BIAS_INDEX = 2
RECURRENT_WEIGHT_INDEX = 2
RUNNING_MEAN_INDEX = 3
RUNNING_VAR_INDEX = 4
OPS_WITH_PARAMS = ["Conv", "Gemm", "ConvTranspose", "BatchNormalization", "MatMul"]
OPS_WITH_PARAMS = ["Conv", "Gemm", "ConvTranspose", "BatchNormalization", "MatMul", "RNN", "LSTM", "GRU"]
CONSTANT_TYPE = ['Constant', 'ConstantOfShape']


@@ -254,6 +256,8 @@ def check_if_param(node: NodeProto, index: int) -> bool:
        return True
    if node.op_type == 'BatchNormalization' and index in [WEIGHT_INDEX, BIAS_INDEX, RUNNING_VAR_INDEX, RUNNING_MEAN_INDEX]:
        return True
    # For recurrent ops, every input other than X (index 0) is treated as a parameter
    if node.op_type in ['RNN', 'LSTM', 'GRU'] and index != INPUT_INDEX:
        return True

    return False

@@ -560,6 +564,19 @@ def create_matmul_params(my_op: Op):
            if weight_tensor:
                create_and_connect_product(weight_tensor.name, weight_tensor.dims, my_op, weight_tensor, 'weight')

        def create_recurrent_type_params(my_op: Op):
            """
            Create products for RNN, LSTM and GRU layers
            :param my_op: Connected Graph Op
            """
            op = my_op.get_module()
            # W (ONNX input index 1): input-to-hidden weights
            weight_tensor = ParamUtils.get_param(self.model, op, WEIGHT_INDEX)
            create_and_connect_product(weight_tensor.name, weight_tensor.dims, my_op, weight_tensor, 'weight_x')

            # R (ONNX input index 2): hidden-to-hidden (recurrent) weights
            recurrent_weight_tensor = ParamUtils.get_param(self.model, op, RECURRENT_WEIGHT_INDEX)
            create_and_connect_product(recurrent_weight_tensor.name, recurrent_weight_tensor.dims, my_op,
                                       recurrent_weight_tensor, 'weight_r')

        def create_batchnorm_params(my_op: Op):
            """ Create products for fusedbatchnorm """
            op = my_op.get_module()
@@ -590,6 +607,9 @@ def handle_default(my_op: Op):
"Conv": create_conv2d_dense_type_params,
"Gemm": create_conv2d_dense_type_params,
"ConvTranspose": create_conv2d_dense_type_params,
"RNN": create_recurrent_type_params,
"LSTM": create_recurrent_type_params,
"GRU": create_recurrent_type_params,
"BatchNormalization": create_batchnorm_params,
"InstanceNormalization": create_batchnorm_params,
"MatMul": create_matmul_params
8 changes: 5 additions & 3 deletions TrainingExtensions/onnx/src/python/aimet_onnx/utils.py
@@ -2,7 +2,7 @@
# =============================================================================
# @@-COPYRIGHT-START-@@
#
# Copyright (c) 2022, Qualcomm Innovation Center, Inc. All rights reserved.
# Copyright (c) 2022-2024, Qualcomm Innovation Center, Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
@@ -57,7 +57,7 @@


OP_TYPES_WITH_PARAMS = ['Conv', 'Gemm', 'ConvTranspose', 'BatchNormalization', 'MatMul', 'Transpose',
                        'InstanceNormalization']
                        'InstanceNormalization', 'RNN', 'LSTM', 'GRU']


def remove_nodes_with_type(node_type: str, onnx_graph: onnx.GraphProto):
@@ -203,6 +203,8 @@ def make_dummy_input(model: ModelProto, dynamic_size: int = 1) -> Dict[str, np.ndarray]:
                shape.append(dim.dim_value)
        if shape:
            input_dict[name] = np.random.randn(*shape).astype(mapping.TENSOR_TYPE_TO_NP_TYPE[dtype])
        else:
            # Rank-0 (scalar) inputs: np.random.randn() with an empty shape returns a
            # Python float, so wrap it in an ndarray before casting to the tensor dtype
            input_dict[name] = np.array(np.random.randn(*shape)).astype(mapping.TENSOR_TYPE_TO_NP_TYPE[dtype])
    return input_dict


@@ -297,7 +299,7 @@ def get_graph_intermediate_activations(graph: GraphProto) -> List[str]:
    activation_names = []
    for node in graph.node:
        for name in node.input:
            if name not in activation_names and name not in param_names:
            if name not in activation_names and name not in param_names and name:  # skip '' used for omitted optional inputs
                activation_names.append(name)
    return activation_names
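
A quick illustration of why the added `and name` guard matters: ONNX encodes omitted optional inputs as empty strings, which recurrent ops commonly have. A hypothetical hand-built node (names are illustrative):

    # An LSTM node with its bias and sequence_lens inputs omitted lists '' in
    # node.input; without the guard, '' would be collected as an activation name.
    import onnx
    node = onnx.helper.make_node('LSTM', inputs=['X', 'W', 'R', '', '', 'initial_h'],
                                 outputs=['Y'], hidden_size=128)
    print([name for name in node.input if name])  # ['X', 'W', 'R', 'initial_h']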
