Skip to content

Commit

Permalink
Debug standalone BN handling
Browse files Browse the repository at this point in the history
Signed-off-by: Michael Tuttle <quic_mtuttle@quicinc.com>
  • Loading branch information
quic-mtuttle committed Oct 28, 2024
1 parent bb18c5d commit 35fef14
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -456,9 +456,11 @@ def _update_standalone_batchnorm_ops(model: ModelProto):
init_b = model.graph.initializer[idx_b]
init_rm = model.graph.initializer[idx_rm]
init_rv = model.graph.initializer[idx_rv]
attr = get_node_attribute(node, "epsilon")
if attr is None:
attr = onnx.helper.make_attribute("epsilon", 1e-5) # Default epsilon value
node.attribute.append(attr)

attr = node.attribute[0]
assert attr.name == 'epsilon'
epsilon = attr.f
tensor_w = numpy_helper.to_array(init_w)
tensor_b = numpy_helper.to_array(init_b)
Expand Down
20 changes: 19 additions & 1 deletion TrainingExtensions/onnx/test/python/test_bn_fold.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,11 @@
import pytest
import torch

from aimet_onnx.batch_norm_fold import _find_conv_bn_pairs, find_all_batch_norms_to_fold, fold_all_batch_norms_to_weight
from aimet_onnx.batch_norm_fold import _find_conv_bn_pairs, find_all_batch_norms_to_fold, fold_all_batch_norms_to_weight, _update_standalone_batchnorm_ops
from aimet_onnx.meta.connectedgraph import ConnectedGraph
from aimet_onnx.utils import make_dummy_input

from models import models_for_tests
from models.models_for_tests import BNAfterConv, BNBeforeConv, BNAfterDynamicMatMul, BNAfterConvTranspose, BNAfterConv1d, \
BNAfterLinear, BNBeforeLinear, BNBeforeFlattenLinear, BNBeforeConv1d, BNBeforeConvTranspose, \
MyModel, _convert_to_onnx_no_fold, _convert_to_onnx, initialize_bn_params, \
Expand Down Expand Up @@ -516,3 +518,19 @@ def test_fold_bn_after_dynamic_matmul(self, bias):

assert len(model.graph().node) == layers_orig
assert np.allclose(baseline_output[0], folded_output[0], rtol=1e-2, atol=1e-6)

def test_single_batchnorm_layer(self):
    """A lone BatchNorm (no adjacent conv/linear to fold into) should have its
    running stats absorbed into weight/bias by _update_standalone_batchnorm_ops,
    leaving the model's numerical output unchanged.
    """
    np.random.seed(0)
    model = models_for_tests.batchnorm_model()
    inputs = make_dummy_input(model)

    def _infer(m):
        # Serialize and run through onnxruntime; take the first output tensor.
        session = rt.InferenceSession(m.SerializeToString(), providers=providers)
        return session.run(None, inputs)[0]

    baseline = _infer(model)
    _update_standalone_batchnorm_ops(model)
    updated = _infer(model)

    # The rewrite must be numerically transparent.
    assert np.allclose(baseline, updated, atol=1e-4)

    # After absorption the BN running stats should be the identity transform:
    # variance == 1 everywhere, mean == 0 everywhere.
    identity_stats = {
        "batchnorm.input_var": np.ones,
        "batchnorm.input_mean": np.zeros,
    }
    for initializer in model.graph.initializer:
        make_expected = identity_stats.get(initializer.name)
        if make_expected is not None:
            values = onnx.numpy_helper.to_array(initializer)
            assert np.array_equal(values, make_expected(values.shape))

0 comments on commit 35fef14

Please sign in to comment.