diff --git a/TrainingExtensions/onnx/src/python/aimet_onnx/batch_norm_fold.py b/TrainingExtensions/onnx/src/python/aimet_onnx/batch_norm_fold.py
index e8fe7b93d7d..992e8d03c34 100644
--- a/TrainingExtensions/onnx/src/python/aimet_onnx/batch_norm_fold.py
+++ b/TrainingExtensions/onnx/src/python/aimet_onnx/batch_norm_fold.py
@@ -442,24 +442,20 @@ def _update_standalone_batchnorm_ops(model: ModelProto):
     Update weight and bias of standalone batchnorm ops in the model.
     :param model: onnx Model for which batchnorm parameters are to be updated.
     """
-    initalizer_dict = {initializer.name: idx for idx, initializer in enumerate(model.graph.initializer)}
 
     for node in model.graph.node:
         if node.op_type in BatchNormType:
 
             # get parameter names and indices
             weight_name, bias_name, running_mean_name, running_var_name = node.input[1:]
-            idx_w, idx_b = initalizer_dict[weight_name], initalizer_dict[bias_name]
-            idx_rm, idx_rv = initalizer_dict[running_mean_name], initalizer_dict[running_var_name]
-
-            init_w = model.graph.initializer[idx_w]
-            init_b = model.graph.initializer[idx_b]
-            init_rm = model.graph.initializer[idx_rm]
-            init_rv = model.graph.initializer[idx_rv]
-            attr = get_node_attribute(node, "epsilon")
-            if attr is None:
+            init_w, init_b, init_rm, init_rv = [ParamUtils.get_param(model, node, idx) for idx in range(1, 5)]
+
+            attr = [item for item in node.attribute if item.name == "epsilon"]
+            if not attr:
                 attr = onnx.helper.make_attribute("epsilon", 1e-5) # Default epsilon value
                 node.attribute.append(attr)
+            else:
+                attr = attr[0]
 
             epsilon = attr.f
             tensor_w = numpy_helper.to_array(init_w)
@@ -475,13 +471,13 @@ def _update_standalone_batchnorm_ops(model: ModelProto):
             tensor_rv = np.ones(tensor_w.shape, tensor_w.dtype)
             attr.f = 0.
 
-            init_w = numpy_helper.from_array(tensor_w, weight_name)
-            init_b = numpy_helper.from_array(tensor_b, bias_name)
-            init_rm = numpy_helper.from_array(tensor_rm, running_mean_name)
-            init_rv = numpy_helper.from_array(tensor_rv, running_var_name)
+            init_w_ = numpy_helper.from_array(tensor_w, weight_name)
+            init_b_ = numpy_helper.from_array(tensor_b, bias_name)
+            init_rm_ = numpy_helper.from_array(tensor_rm, running_mean_name)
+            init_rv_ = numpy_helper.from_array(tensor_rv, running_var_name)
 
             # update initializers
-            model.graph.initializer[idx_w].CopyFrom(init_w)
-            model.graph.initializer[idx_b].CopyFrom(init_b)
-            model.graph.initializer[idx_rm].CopyFrom(init_rm)
-            model.graph.initializer[idx_rv].CopyFrom(init_rv)
+            init_w.CopyFrom(init_w_)
+            init_b.CopyFrom(init_b_)
+            init_rm.CopyFrom(init_rm_)
+            init_rv.CopyFrom(init_rv_)
diff --git a/TrainingExtensions/onnx/test/python/models/models_for_tests.py b/TrainingExtensions/onnx/test/python/models/models_for_tests.py
index c1da74cb4ae..17f05baa7dc 100644
--- a/TrainingExtensions/onnx/test/python/models/models_for_tests.py
+++ b/TrainingExtensions/onnx/test/python/models/models_for_tests.py
@@ -2515,3 +2515,31 @@ def batchnorm_model():
     onnx.checker.check_model(model, True)
     return model
 
+def batchnorm_model_constants():
+    model = helper.make_model(
+        graph=helper.make_graph(
+            name='BatchnormModel',
+            inputs=[helper.make_tensor_value_info('model_input', TensorProto.FLOAT, shape=[10, 10, 8, 8])],
+            outputs=[helper.make_tensor_value_info('model_output', TensorProto.FLOAT, shape=[10, 10, 8, 8])],
+            initializer=[],
+            nodes=[
+                helper.make_node('Constant', inputs=[], outputs=["batchnorm.weight"],
+                                 value=numpy_helper.from_array(np.abs(np.random.randn(10, )).astype('float32')), name='weight'),
+                helper.make_node('Constant', inputs=[], outputs=["batchnorm.bias"],
+                                 value=numpy_helper.from_array(np.random.randn(10, ).astype('float32')), name='bias'),
+                helper.make_node('Constant', inputs=[], outputs=["batchnorm.input_mean"],
+                                 value=numpy_helper.from_array(np.random.randn(10, ).astype('float32')), name="input_mean"),
+                helper.make_node('Constant', inputs=[], outputs=["batchnorm.input_var"],
+                                 value=numpy_helper.from_array(np.abs(np.random.randn(10, )).astype('float32')), name='input_var'),
+                helper.make_node(
+                    'BatchNormalization',
+                    inputs=['model_input', 'batchnorm.weight', 'batchnorm.bias', 'batchnorm.input_mean', 'batchnorm.input_var'],
+                    outputs=['model_output'],
+                    name='batchnorm'
+                ),
+            ]
+        )
+    )
+    onnx.checker.check_model(model, True)
+    return model
+
diff --git a/TrainingExtensions/onnx/test/python/test_bn_fold.py b/TrainingExtensions/onnx/test/python/test_bn_fold.py
index 606a7c20115..0df8f3b7e71 100644
--- a/TrainingExtensions/onnx/test/python/test_bn_fold.py
+++ b/TrainingExtensions/onnx/test/python/test_bn_fold.py
@@ -534,3 +534,19 @@ def test_single_batchnorm_layer(self):
             if tensor.name == "batchnorm.input_mean":
                 np_tensor = onnx.numpy_helper.to_array(tensor)
                 assert np.all(np_tensor == np.zeros_like(np_tensor))
+
+    def test_single_bn_layer_with_constants(self):
+        np.random.seed(0)
+        model = models_for_tests.batchnorm_model_constants()
+        dummy_input = make_dummy_input(model)
+        output = rt.InferenceSession(model.SerializeToString(), providers=providers).run(None, dummy_input)[0]
+        _update_standalone_batchnorm_ops(model)
+        output_after_update = rt.InferenceSession(model.SerializeToString(), providers=providers).run(None, dummy_input)[0]
+        assert np.allclose(output, output_after_update, atol=1e-4)
+        for node in model.graph.node:
+            if node.name == "input_var":
+                np_tensor = onnx.numpy_helper.to_array(node.attribute[0].t)
+                assert np.all(np_tensor == np.ones_like(np_tensor))
+            if node.name == "input_mean":
+                np_tensor = onnx.numpy_helper.to_array(node.attribute[0].t)
+                assert np.all(np_tensor == np.zeros_like(np_tensor))