Skip to content

Commit

Permalink
Add support for onnx export to models with multiple inputs/outputs (#…
Browse files Browse the repository at this point in the history
…1223)

* Add support for onnx export to models with multiple inputs/outputs

* Add onnxruntime extensions to pytorch workflows

---------

Co-authored-by: reuvenp <reuvenp@altair-semi.com>
  • Loading branch information
reuvenperetz and reuvenp authored Sep 22, 2024
1 parent bb35a81 commit ccab3ad
Show file tree
Hide file tree
Showing 7 changed files with 151 additions and 10 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/run_pytorch_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install torch==${{ inputs.torch-version }} torchvision onnx onnxruntime
pip install torch==${{ inputs.torch-version }} torchvision onnx onnxruntime onnxruntime-extensions
pip install pytest
- name: Run unittests
run: |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,12 @@ def export(self) -> None:
else:
Logger.info(f"Exporting fake-quant onnx model: {self.save_model_path}")

model_input = to_torch_tensor(next(self.repr_dataset())[0])
model_input = to_torch_tensor(next(self.repr_dataset()))

if hasattr(self.model, 'metadata'):
onnx_bytes = BytesIO()
torch.onnx.export(self.model,
model_input,
tuple(model_input) if isinstance(model_input, list) else model_input,
onnx_bytes,
opset_version=self._onnx_opset_version,
verbose=False,
Expand All @@ -107,7 +107,7 @@ def export(self) -> None:
onnx.save_model(onnx_model, self.save_model_path)
else:
torch.onnx.export(self.model,
model_input,
tuple(model_input) if isinstance(model_input, list) else model_input,
self.save_model_path,
opset_version=self._onnx_opset_version,
verbose=False,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from onnx import numpy_helper

import model_compression_toolkit as mct
from mct_quantizers import get_ort_session_options
from model_compression_toolkit.exporter.model_exporter.pytorch.pytorch_export_facade import DEFAULT_ONNX_OPSET_VERSION

from tests.pytorch_tests.exporter_tests.base_pytorch_export_test import BasePytorchExportTest
Expand All @@ -40,16 +41,29 @@ def load_exported_model(self, filepath):
onnx.checker.check_model(filepath)
return onnx.load(filepath)

def infer(self, model, images):
ort_session = onnxruntime.InferenceSession(model.SerializeToString())
def infer(self, model, inputs):
ort_session = onnxruntime.InferenceSession(
model.SerializeToString(),
get_ort_session_options(),
providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
)

def to_numpy(tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(images)}
onnx_output = ort_session.run(None, ort_inputs)
return onnx_output[0]
# Prepare inputs
if isinstance(inputs, list):
ort_inputs = {input.name: to_numpy(tensor) for input, tensor in zip(ort_session.get_inputs(), inputs)}
elif isinstance(inputs, dict):
ort_inputs = {name: to_numpy(tensor) for name, tensor in inputs.items()}
else:
raise ValueError("Inputs must be a list or a dictionary")

output_names = [output.name for output in ort_session.get_outputs()]
onnx_outputs = ort_session.run(output_names, ort_inputs)
output_dict = dict(zip(output_names, onnx_outputs))

return output_dict

def _get_onnx_node_by_type(self, onnx_model, op_type):
return [n for n in onnx_model.graph.node if n.op_type == op_type]
Expand Down
38 changes: 38 additions & 0 deletions tests/pytorch_tests/exporter_tests/test_onnx_multiple_inputs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from tests.pytorch_tests.exporter_tests.base_pytorch_onnx_export_test import BasePytorchONNXExportTest
from torch import nn

class TestExportONNXMultipleInputs(BasePytorchONNXExportTest):
def get_model(self):
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(3, 3, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(3, 3, kernel_size=3, padding=1)

def forward(self, input1, input2):
x1 = self.conv1(input1)
x2 = self.conv2(input2)
return x1 + x2

return Model()

def get_input_shapes(self):
return [(1, 3, 8, 8), (1, 3, 8, 8)]

def compare(self, loaded_model, quantized_model, quantization_info):
assert len(loaded_model.graph.input)==2, f"Model expected to have two inputs but has {len(loaded_model.graph.input)}"
self.infer(loaded_model, next(self.get_dataset()))
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from tests.pytorch_tests.exporter_tests.base_pytorch_onnx_export_test import BasePytorchONNXExportTest
from torch import nn

class TestExportONNXMultipleInputsAndOutputs(BasePytorchONNXExportTest):
def get_model(self):
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(3, 3, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(3, 3, kernel_size=3, padding=1)

def forward(self, input1, input2):
x1 = self.conv1(input1)
x2 = self.conv2(input2)
return x1, x2

return Model()

def get_input_shapes(self):
return [(1, 3, 8, 8), (1, 3, 8, 8)]

def compare(self, loaded_model, quantized_model, quantization_info):
assert len(loaded_model.graph.input) == 2, f"Model expected to have two inputs but has {len(loaded_model.graph.input)}"
assert len(loaded_model.graph.output) == 2, f"Model expected to have two outputs but has {len(loaded_model.graph.output)}"
self.infer(loaded_model, next(self.get_dataset()))
38 changes: 38 additions & 0 deletions tests/pytorch_tests/exporter_tests/test_onnx_multiple_outputs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from tests.pytorch_tests.exporter_tests.base_pytorch_onnx_export_test import BasePytorchONNXExportTest
from torch import nn

class TestExportONNXMultipleOutputs(BasePytorchONNXExportTest):
def get_model(self):
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(3, 3, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(3, 3, kernel_size=3, padding=1)

def forward(self, input1):
x1 = self.conv1(input1)
x2 = self.conv2(input1)
return x1, x2

return Model()

def get_input_shapes(self):
return [(1, 3, 8, 8)]

def compare(self, loaded_model, quantized_model, quantization_info):
assert len(loaded_model.graph.output)==2, f"Model expected to have two outputs but has {len(loaded_model.graph.output)}"
self.infer(loaded_model, next(self.get_dataset()))
12 changes: 12 additions & 0 deletions tests/pytorch_tests/exporter_tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
TestExportONNXWeightSymmetric2BitsQuantizers
from tests.pytorch_tests.exporter_tests.custom_ops_tests.test_export_uniform_onnx_quantizers import \
TestExportONNXWeightUniform2BitsQuantizers
from tests.pytorch_tests.exporter_tests.test_onnx_multiple_inputs import TestExportONNXMultipleInputs
from tests.pytorch_tests.exporter_tests.test_onnx_multiple_inputs_and_outputs import \
TestExportONNXMultipleInputsAndOutputs
from tests.pytorch_tests.exporter_tests.test_onnx_multiple_outputs import TestExportONNXMultipleOutputs


class PytorchExporterTestsRunner(unittest.TestCase):
Expand Down Expand Up @@ -55,4 +59,12 @@ def test_lut_sym2bits_custom_quantizer_onnx(self):
TestExportONNXWeightLUTSymmetric2BitsQuantizers().run_test()
TestExportONNXWeightLUTSymmetric2BitsQuantizers(onnx_opset_version=16).run_test()

def test_multiple_inputs_onnx(self):
TestExportONNXMultipleInputs().run_test()

def test_multiple_outputs_onnx(self):
TestExportONNXMultipleOutputs().run_test()

def test_multiple_inputs_and_outputs_onnx(self):
TestExportONNXMultipleInputsAndOutputs().run_test()

0 comments on commit ccab3ad

Please sign in to comment.