Fix BN layers with parameters as Constant nodes
Signed-off-by: Michael Tuttle <quic_mtuttle@quicinc.com>
quic-mtuttle committed Oct 28, 2024
1 parent 35fef14 commit 553646d
Showing 3 changed files with 58 additions and 18 deletions.
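
Background sketch (illustrative, not part of the commit): _update_standalone_batchnorm_ops used to locate the BN weight, bias, running_mean and running_var by name in model.graph.initializer. When those tensors are produced by Constant nodes instead of being stored as initializers, that name index is empty and the lookup fails. A minimal reproduction of that storage pattern, assuming only onnx and numpy (bias/mean/var producers omitted for brevity; the full fixture is added in models_for_tests.py below):

import numpy as np
from onnx import TensorProto, helper, numpy_helper

# BN scale supplied by a Constant node rather than a graph initializer
scale = helper.make_node("Constant", inputs=[], outputs=["bn.scale"],
                         value=numpy_helper.from_array(np.ones(10, dtype=np.float32)))
bn = helper.make_node("BatchNormalization",
                      inputs=["x", "bn.scale", "bn.bias", "bn.mean", "bn.var"],
                      outputs=["y"])
graph = helper.make_graph([scale, bn], "g",
                          inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 10, 8, 8])],
                          outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 10, 8, 8])])
model = helper.make_model(graph)

# An index built from model.graph.initializer cannot see "bn.scale"
print({init.name for init in model.graph.initializer})   # set()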
32 changes: 14 additions & 18 deletions TrainingExtensions/onnx/src/python/aimet_onnx/batch_norm_fold.py
@@ -442,24 +442,20 @@ def _update_standalone_batchnorm_ops(model: ModelProto):
     Update weight and bias of standalone batchnorm ops in the model.
     :param model: onnx Model for which batchnorm parameters are to be updated.
     """
-    initalizer_dict = {initializer.name: idx for idx, initializer in enumerate(model.graph.initializer)}

     for node in model.graph.node:
         if node.op_type in BatchNormType:

-            # get parameter names and indices
             weight_name, bias_name, running_mean_name, running_var_name = node.input[1:]
-            idx_w, idx_b = initalizer_dict[weight_name], initalizer_dict[bias_name]
-            idx_rm, idx_rv = initalizer_dict[running_mean_name], initalizer_dict[running_var_name]

-            init_w = model.graph.initializer[idx_w]
-            init_b = model.graph.initializer[idx_b]
-            init_rm = model.graph.initializer[idx_rm]
-            init_rv = model.graph.initializer[idx_rv]
-            attr = get_node_attribute(node, "epsilon")
-            if attr is None:
+            init_w, init_b, init_rm, init_rv = [ParamUtils.get_param(model, node, idx) for idx in range(1, 5)]
+
+            attr = [item for item in node.attribute if item.name == "epsilon"]
+            if not attr:
                 attr = onnx.helper.make_attribute("epsilon", 1e-5) # Default epsilon value
                 node.attribute.append(attr)
+            else:
+                attr = attr[0]

             epsilon = attr.f
             tensor_w = numpy_helper.to_array(init_w)
@@ -475,13 +471,13 @@ def _update_standalone_batchnorm_ops(model: ModelProto):
             tensor_rv = np.ones(tensor_w.shape, tensor_w.dtype)
             attr.f = 0.

-            init_w = numpy_helper.from_array(tensor_w, weight_name)
-            init_b = numpy_helper.from_array(tensor_b, bias_name)
-            init_rm = numpy_helper.from_array(tensor_rm, running_mean_name)
-            init_rv = numpy_helper.from_array(tensor_rv, running_var_name)
+            init_w_ = numpy_helper.from_array(tensor_w, weight_name)
+            init_b_ = numpy_helper.from_array(tensor_b, bias_name)
+            init_rm_ = numpy_helper.from_array(tensor_rm, running_mean_name)
+            init_rv_ = numpy_helper.from_array(tensor_rv, running_var_name)

             # update initializers
-            model.graph.initializer[idx_w].CopyFrom(init_w)
-            model.graph.initializer[idx_b].CopyFrom(init_b)
-            model.graph.initializer[idx_rm].CopyFrom(init_rm)
-            model.graph.initializer[idx_rv].CopyFrom(init_rv)
+            init_w.CopyFrom(init_w_)
+            init_b.CopyFrom(init_b_)
+            init_rm.CopyFrom(init_rm_)
+            init_rv.CopyFrom(init_rv_)
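
ParamUtils.get_param is what lets the new code fetch each BN parameter by input index whether it is stored as an initializer or produced by a Constant node; its implementation is not shown in this diff. A rough sketch of such a Constant-aware lookup, using a hypothetical helper name rather than AIMET's actual code:

from onnx import ModelProto, NodeProto, TensorProto

def find_param_tensor(model: ModelProto, node: NodeProto, input_idx: int) -> TensorProto:
    """Illustrative only: return the TensorProto feeding node.input[input_idx]."""
    name = node.input[input_idx]
    # Case 1: parameter stored as a graph initializer
    for init in model.graph.initializer:
        if init.name == name:
            return init
    # Case 2: parameter produced by a Constant node's 'value' attribute
    for producer in model.graph.node:
        if producer.op_type == "Constant" and name in producer.output:
            for attr in producer.attribute:
                if attr.name == "value":
                    return attr.t
    raise KeyError(f"No initializer or Constant output named '{name}'")

In this sketch the returned object is the TensorProto held by the graph itself, so the later init_w.CopyFrom(init_w_) calls update the parameter in place wherever it lives; the new test below relies on the same behavior when it reads node.attribute[0].t of the Constant nodes.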
28 changes: 28 additions & 0 deletions TrainingExtensions/onnx/test/python/models/models_for_tests.py
@@ -2515,3 +2515,31 @@ def batchnorm_model():
     onnx.checker.check_model(model, True)
     return model

+def batchnorm_model_constants():
+    model = helper.make_model(
+        graph=helper.make_graph(
+            name='BatchnormModel',
+            inputs=[helper.make_tensor_value_info('model_input', TensorProto.FLOAT, shape=[10, 10, 8, 8])],
+            outputs=[helper.make_tensor_value_info('model_output', TensorProto.FLOAT, shape=[10, 10, 8, 8])],
+            initializer=[],
+            nodes=[
+                helper.make_node('Constant', inputs=[], outputs=["batchnorm.weight"],
+                                 value=numpy_helper.from_array(np.abs(np.random.randn(10, )).astype('float32')), name='weight'),
+                helper.make_node('Constant', inputs=[], outputs=["batchnorm.bias"],
+                                 value=numpy_helper.from_array(np.random.randn(10, ).astype('float32')), name='bias'),
+                helper.make_node('Constant', inputs=[], outputs=["batchnorm.input_mean"],
+                                 value=numpy_helper.from_array(np.random.randn(10, ).astype('float32')), name="input_mean"),
+                helper.make_node('Constant', inputs=[], outputs=["batchnorm.input_var"],
+                                 value=numpy_helper.from_array(np.abs(np.random.randn(10, )).astype('float32')), name='input_var'),
+                helper.make_node(
+                    'BatchNormalization',
+                    inputs=['model_input', 'batchnorm.weight', 'batchnorm.bias', 'batchnorm.input_mean', 'batchnorm.input_var'],
+                    outputs=['model_output'],
+                    name='batchnorm'
+                ),
+            ]
+        )
+    )
+    onnx.checker.check_model(model, True)
+    return model
+
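The fixture above deliberately leaves graph.initializer empty, so every BatchNormalization parameter is reachable only through a Constant node. A quick standalone check of that property (illustrative, not part of the test suite; run next to the definition above):

model = batchnorm_model_constants()
assert len(model.graph.initializer) == 0      # no parameters stored as initializers
const_outputs = {out for n in model.graph.node if n.op_type == "Constant" for out in n.output}
bn = next(n for n in model.graph.node if n.op_type == "BatchNormalization")
assert set(bn.input[1:]) <= const_outputs     # scale, bias, mean and var all come from Constant nodes
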
16 changes: 16 additions & 0 deletions TrainingExtensions/onnx/test/python/test_bn_fold.py
@@ -534,3 +534,19 @@ def test_single_batchnorm_layer(self):
             if tensor.name == "batchnorm.input_mean":
                 np_tensor = onnx.numpy_helper.to_array(tensor)
                 assert np.all(np_tensor == np.zeros_like(np_tensor))
+
+    def test_single_bn_layer_with_constants(self):
+        np.random.seed(0)
+        model = models_for_tests.batchnorm_model_constants()
+        dummy_input = make_dummy_input(model)
+        output = rt.InferenceSession(model.SerializeToString(), providers=providers).run(None, dummy_input)[0]
+        _update_standalone_batchnorm_ops(model)
+        output_after_update = rt.InferenceSession(model.SerializeToString(), providers=providers).run(None, dummy_input)[0]
+        assert np.allclose(output, output_after_update, atol=1e-4)
+        for node in model.graph.node:
+            if node.name == "input_var":
+                np_tensor = onnx.numpy_helper.to_array(node.attribute[0].t)
+                assert np.all(np_tensor == np.ones_like(np_tensor))
+            if node.name == "input_mean":
+                np_tensor = onnx.numpy_helper.to_array(node.attribute[0].t)
+                assert np.all(np_tensor == np.zeros_like(np_tensor))
