Fix aimet_onnx standalone bn handling #3448

Merged · 2 commits · Oct 31, 2024
TrainingExtensions/onnx/src/python/aimet_onnx/batch_norm_fold.py (15 additions, 17 deletions)
@@ -442,23 +442,21 @@ def _update_standalone_batchnorm_ops(model: ModelProto):
     Update weight and bias of standalone batchnorm ops in the model.
     :param model: onnx Model for which batchnorm parameters are to be updated.
     """
-    initalizer_dict = {initializer.name: idx for idx, initializer in enumerate(model.graph.initializer)}

     for node in model.graph.node:
         if node.op_type in BatchNormType:

             # get parameter names and indices
             weight_name, bias_name, running_mean_name, running_var_name = node.input[1:]
-            idx_w, idx_b = initalizer_dict[weight_name], initalizer_dict[bias_name]
-            idx_rm, idx_rv = initalizer_dict[running_mean_name], initalizer_dict[running_var_name]
+            init_w, init_b, init_rm, init_rv = [ParamUtils.get_param(model, node, idx) for idx in range(1, 5)]

-            init_w = model.graph.initializer[idx_w]
-            init_b = model.graph.initializer[idx_b]
-            init_rm = model.graph.initializer[idx_rm]
-            init_rv = model.graph.initializer[idx_rv]
+            attr = [item for item in node.attribute if item.name == "epsilon"]
+            if not attr:
+                attr = onnx.helper.make_attribute("epsilon", 1e-5)  # Default epsilon value
+                node.attribute.append(attr)
+            else:
+                attr = attr[0]

-            attr = node.attribute[0]
-            assert attr.name == 'epsilon'
             epsilon = attr.f
             tensor_w = numpy_helper.to_array(init_w)
             tensor_b = numpy_helper.to_array(init_b)
@@ -473,13 +471,13 @@ def _update_standalone_batchnorm_ops(model: ModelProto):
             tensor_rv = np.ones(tensor_w.shape, tensor_w.dtype)
             attr.f = 0.

-            init_w = numpy_helper.from_array(tensor_w, weight_name)
-            init_b = numpy_helper.from_array(tensor_b, bias_name)
-            init_rm = numpy_helper.from_array(tensor_rm, running_mean_name)
-            init_rv = numpy_helper.from_array(tensor_rv, running_var_name)
+            init_w_ = numpy_helper.from_array(tensor_w, weight_name)
+            init_b_ = numpy_helper.from_array(tensor_b, bias_name)
+            init_rm_ = numpy_helper.from_array(tensor_rm, running_mean_name)
+            init_rv_ = numpy_helper.from_array(tensor_rv, running_var_name)

             # update initializers
-            model.graph.initializer[idx_w].CopyFrom(init_w)
-            model.graph.initializer[idx_b].CopyFrom(init_b)
-            model.graph.initializer[idx_rm].CopyFrom(init_rm)
-            model.graph.initializer[idx_rv].CopyFrom(init_rv)
+            init_w.CopyFrom(init_w_)
+            init_b.CopyFrom(init_b_)
+            init_rm.CopyFrom(init_rm_)
+            init_rv.CopyFrom(init_rv_)
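
The rewrite changes two things. The four BatchNorm parameters are now fetched through ParamUtils.get_param instead of being looked up by name in model.graph.initializer, which (per the new test added below) also resolves parameters produced by Constant nodes rather than stored as initializers. And epsilon is located by attribute name, with the default 1e-5 appended when the attribute is missing, instead of assuming it is node.attribute[0]. The arithmetic itself is unchanged: the running statistics and epsilon are folded into the scale and bias, leaving the node with mean 0, variance 1 and epsilon 0. A minimal NumPy sketch of that identity, independent of AIMET and purely illustrative:

```python
import numpy as np

def fold_bn_stats(weight, bias, running_mean, running_var, eps):
    # Return (new_weight, new_bias) such that
    # new_weight * x + new_bias == weight * (x - mean) / sqrt(var + eps) + bias
    lambda_ = weight / np.sqrt(running_var + eps)
    return lambda_, bias - running_mean * lambda_

# Sanity-check the identity on random per-channel parameters
rng = np.random.default_rng(0)
x = rng.standard_normal((4, 10))
w, b = rng.standard_normal(10), rng.standard_normal(10)
rm, rv = rng.standard_normal(10), np.abs(rng.standard_normal(10))
eps = 1e-5

ref = w * (x - rm) / np.sqrt(rv + eps) + b   # original BN in inference form
w_new, b_new = fold_bn_stats(w, b, rm, rv, eps)
out = w_new * x + b_new                      # BN rewritten with mean=0, var=1, eps=0
assert np.allclose(ref, out, atol=1e-6)
```

After this rewrite the node is just a per-channel scale and shift, which is why the tests below can assert that ONNX Runtime outputs are unchanged while input_mean/input_var become zeros and ones.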
TrainingExtensions/onnx/test/python/models/models_for_tests.py (28 additions)
@@ -2515,3 +2515,31 @@ def batchnorm_model():
     onnx.checker.check_model(model, True)
     return model

+def batchnorm_model_constants():
+    model = helper.make_model(
+        graph=helper.make_graph(
+            name='BatchnormModel',
+            inputs=[helper.make_tensor_value_info('model_input', TensorProto.FLOAT, shape=[10, 10, 8, 8])],
+            outputs=[helper.make_tensor_value_info('model_output', TensorProto.FLOAT, shape=[10, 10, 8, 8])],
+            initializer=[],
+            nodes=[
+                helper.make_node('Constant', inputs=[], outputs=["batchnorm.weight"],
+                                 value=numpy_helper.from_array(np.abs(np.random.randn(10, )).astype('float32')), name='weight'),
+                helper.make_node('Constant', inputs=[], outputs=["batchnorm.bias"],
+                                 value=numpy_helper.from_array(np.random.randn(10, ).astype('float32')), name='bias'),
+                helper.make_node('Constant', inputs=[], outputs=["batchnorm.input_mean"],
+                                 value=numpy_helper.from_array(np.random.randn(10, ).astype('float32')), name="input_mean"),
+                helper.make_node('Constant', inputs=[], outputs=["batchnorm.input_var"],
+                                 value=numpy_helper.from_array(np.abs(np.random.randn(10, )).astype('float32')), name='input_var'),
+                helper.make_node(
+                    'BatchNormalization',
+                    inputs=['model_input', 'batchnorm.weight', 'batchnorm.bias', 'batchnorm.input_mean', 'batchnorm.input_var'],
+                    outputs=['model_output'],
+                    name='batchnorm'
+                ),
+            ]
+        )
+    )
+    onnx.checker.check_model(model, True)
+    return model
+
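What makes this second model different is that every BatchNormalization input comes from a Constant node and the graph has no initializers at all, which is exactly the layout the old name-to-initializer lookup in _update_standalone_batchnorm_ops could not handle. As a rough illustration of the kind of fallback such a lookup needs (the helper below is hypothetical and is not the actual ParamUtils.get_param implementation):

```python
from onnx import numpy_helper

def find_param_tensor(model, tensor_name):
    # Hypothetical helper: resolve an input name to a NumPy array whether it is
    # stored as a graph initializer or produced by a Constant node.
    for init in model.graph.initializer:
        if init.name == tensor_name:
            return numpy_helper.to_array(init)
    for node in model.graph.node:
        if node.op_type == 'Constant' and tensor_name in node.output:
            for attr in node.attribute:
                if attr.name == 'value':
                    return numpy_helper.to_array(attr.t)
    return None
```

The two tests added to test_bn_fold.py below exercise both layouts: batchnorm_model stores the parameters as initializers, while batchnorm_model_constants stores them as Constant nodes.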

TrainingExtensions/onnx/test/python/test_bn_fold.py (35 additions, 1 deletion)
@@ -45,9 +45,11 @@
 import pytest
 import torch

-from aimet_onnx.batch_norm_fold import _find_conv_bn_pairs, find_all_batch_norms_to_fold, fold_all_batch_norms_to_weight
+from aimet_onnx.batch_norm_fold import _find_conv_bn_pairs, find_all_batch_norms_to_fold, fold_all_batch_norms_to_weight, _update_standalone_batchnorm_ops
 from aimet_onnx.meta.connectedgraph import ConnectedGraph
 from aimet_onnx.utils import make_dummy_input
+
+from models import models_for_tests
 from models.models_for_tests import BNAfterConv, BNBeforeConv, BNAfterDynamicMatMul, BNAfterConvTranspose, BNAfterConv1d, \
     BNAfterLinear, BNBeforeLinear, BNBeforeFlattenLinear, BNBeforeConv1d, BNBeforeConvTranspose, \
     MyModel, _convert_to_onnx_no_fold, _convert_to_onnx, initialize_bn_params, \
@@ -516,3 +518,35 @@ def test_fold_bn_after_dynamic_matmul(self, bias):

         assert len(model.graph().node) == layers_orig
         assert np.allclose(baseline_output[0], folded_output[0], rtol=1e-2, atol=1e-6)
+
+    def test_single_batchnorm_layer(self):
+        np.random.seed(0)
+        model = models_for_tests.batchnorm_model()
+        dummy_input = make_dummy_input(model)
+        output = rt.InferenceSession(model.SerializeToString(), providers=providers).run(None, dummy_input)[0]
+        _update_standalone_batchnorm_ops(model)
+        output_after_update = rt.InferenceSession(model.SerializeToString(), providers=providers).run(None, dummy_input)[0]
+        assert np.allclose(output, output_after_update, atol=1e-4)
+        for tensor in model.graph.initializer:
+            if tensor.name == "batchnorm.input_var":
+                np_tensor = onnx.numpy_helper.to_array(tensor)
+                assert np.all(np_tensor == np.ones_like(np_tensor))
+            if tensor.name == "batchnorm.input_mean":
+                np_tensor = onnx.numpy_helper.to_array(tensor)
+                assert np.all(np_tensor == np.zeros_like(np_tensor))
+
+    def test_single_bn_layer_with_constants(self):
+        np.random.seed(0)
+        model = models_for_tests.batchnorm_model_constants()
+        dummy_input = make_dummy_input(model)
+        output = rt.InferenceSession(model.SerializeToString(), providers=providers).run(None, dummy_input)[0]
+        _update_standalone_batchnorm_ops(model)
+        output_after_update = rt.InferenceSession(model.SerializeToString(), providers=providers).run(None, dummy_input)[0]
+        assert np.allclose(output, output_after_update, atol=1e-4)
+        for node in model.graph.node:
+            if node.name == "input_var":
+                np_tensor = onnx.numpy_helper.to_array(node.attribute[0].t)
+                assert np.all(np_tensor == np.ones_like(np_tensor))
+            if node.name == "input_mean":
+                np_tensor = onnx.numpy_helper.to_array(node.attribute[0].t)
+                assert np.all(np_tensor == np.zeros_like(np_tensor))
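
Both tests capture the contract of the fix: after _update_standalone_batchnorm_ops runs, model outputs must be numerically unchanged while the stored statistics are normalized. The same check can be applied to another ONNX model with a sketch like the one below; it mirrors the tests rather than a public API (_update_standalone_batchnorm_ops is a private helper), and the CPU execution provider is an assumption:

```python
import numpy as np
import onnxruntime as ort

from aimet_onnx.batch_norm_fold import _update_standalone_batchnorm_ops
from aimet_onnx.utils import make_dummy_input

def check_standalone_bn_update(model, atol=1e-4):
    # Run the model before and after the in-place rewrite and compare outputs.
    providers = ['CPUExecutionProvider']  # assumption: CPU inference is sufficient
    dummy_input = make_dummy_input(model)
    before = ort.InferenceSession(model.SerializeToString(), providers=providers).run(None, dummy_input)
    _update_standalone_batchnorm_ops(model)
    after = ort.InferenceSession(model.SerializeToString(), providers=providers).run(None, dummy_input)
    assert all(np.allclose(b, a, atol=atol) for b, a in zip(before, after))
```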