Debug BN fold with shared stats

Signed-off-by: Michael Tuttle <quic_mtuttle@quicinc.com>
quic-mtuttle committed Nov 19, 2024
1 parent fe63f10 commit f235aa6

Showing 4 changed files with 90 additions and 9 deletions.
@@ -363,6 +363,8 @@ def get_bn_params(model: ModelProto, bn: NodeProto, channels: int) -> libpymo.BN
    runningVar = numpy_helper.to_array(ParamUtils.get_param(model, bn, RUNNING_VAR_INDEX))

    epsilon = get_node_attribute(bn, "epsilon")
    if epsilon is None:
        epsilon = 1e-5  # Default ONNX epsilon value
    sigma = np.sqrt(runningVar + epsilon)
    bn_params.runningVar = np.repeat(sigma.reshape(-1), resize)

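BatchNormalization's epsilon attribute is optional in ONNX; when it is absent, get_node_attribute returns None and the hunk above falls back to the spec default of 1e-5. A minimal standalone sketch of that fallback (the safe_bn_sigma name is illustrative, not part of the commit):

import numpy as np

def safe_bn_sigma(running_var: np.ndarray, epsilon) -> np.ndarray:
    # Illustrative helper mirroring the fix: fall back to the ONNX
    # default epsilon (1e-5) when the attribute is missing.
    if epsilon is None:
        epsilon = 1e-5
    return np.sqrt(running_var + epsilon)

sigma = safe_bn_sigma(np.array([0.5, 2.0, 1.0], dtype=np.float32), None)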
29 changes: 20 additions & 9 deletions TrainingExtensions/onnx/src/python/aimet_onnx/utils.py
@@ -86,7 +86,21 @@ def remove_nodes_with_type(node_type: str, onnx_graph: onnx.GraphProto):
                node.output[0] = outputs.name


def remove_node(node: ModelProto, onnx_graph: onnx.GraphProto):
def _prune_unused_initializer(graph: onnx.GraphProto, init_name):
    """
    Remove initializer from graph if it is unused
    """
    for node in graph.node:
        if init_name in node.input:
            return  # Don't prune if initializer is still used

    for initializer in graph.initializer:
        if initializer.name == init_name:
            graph.initializer.remove(initializer)
            break


def remove_node(node: NodeProto, onnx_graph: onnx.GraphProto):
"""
Remove a specific node from graph along with associated initializers
@@ -105,13 +119,9 @@ def remove_node(node: ModelProto, onnx_graph: onnx.GraphProto):
        for outputs in onnx_graph.output:
            if outputs.name == node.output[0] and other_node.output[0] == node.input[0]:
                other_node.output[0] = outputs.name
    inits_to_remove = []
    # Remove the node's initializers
    for item in onnx_graph.initializer:
        if item.name in node.input:
            inits_to_remove.append(item)
    for item in inits_to_remove:
        onnx_graph.initializer.remove(item)

    for input_name in node.input:
        _prune_unused_initializer(onnx_graph, input_name)


def transpose_tensor(t: TensorProto, axes: Union[List, Tuple]) -> TensorProto:
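How the new pruning path behaves on a toy graph (this example is illustrative, not part of the commit, and assumes the _prune_unused_initializer helper above is importable): an initializer consumed by any remaining node is kept, and it is only dropped once its last consumer is gone.

import numpy as np
from onnx import TensorProto, helper, numpy_helper

# Two nodes share the 'shared' initializer.
graph = helper.make_graph(
    nodes=[
        helper.make_node('Add', ['x', 'shared'], ['y'], name='add_1'),
        helper.make_node('Mul', ['y', 'shared'], ['z'], name='mul_1'),
    ],
    name='PruneDemo',
    inputs=[helper.make_tensor_value_info('x', TensorProto.FLOAT, [1])],
    outputs=[helper.make_tensor_value_info('z', TensorProto.FLOAT, [1])],
    initializer=[numpy_helper.from_array(np.ones(1, dtype=np.float32), name='shared')],
)

_prune_unused_initializer(graph, 'shared')
assert len(graph.initializer) == 1  # kept: both nodes still consume it

del graph.node[1]                   # drop mul_1; add_1 still consumes 'shared'
_prune_unused_initializer(graph, 'shared')
assert len(graph.initializer) == 1  # kept

del graph.node[0]                   # no consumers remain
_prune_unused_initializer(graph, 'shared')
assert len(graph.initializer) == 0  # pruned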
@@ -345,9 +355,10 @@ def find_param_in_model_constants(param_name: str, model: ModelProto):
                    param = attribute.t
                    param.name = param_name
                    return param
        if node.op_type == 'Identity' and param_name == node.output[0]:
            return ParamUtils.get_param(model, node, 0)
    return None

        assert node.op_type in OP_TYPES_WITH_PARAMS, "Node type not in allowed op types with param list"
        if len(node.input) >= param_index + 1:
            param_name = node.input[param_index]
            param = find_param_in_model_initializers(param_name, model)
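The new Identity branch is what makes the shared-stats model below foldable: when two BatchNormalization nodes read the same statistics tensor through separate Identity nodes, the parameter lookup has to follow each Identity back to the tensor feeding it. A simplified sketch of that idea (not the AIMET implementation):

from onnx import GraphProto

def resolve_through_identity(graph: GraphProto, tensor_name: str) -> str:
    # If `tensor_name` is produced by an Identity node, return the name of
    # the tensor feeding that Identity, so the caller can look the parameter
    # up among the graph initializers; otherwise return the name unchanged.
    for node in graph.node:
        if node.op_type == 'Identity' and tensor_name in node.output:
            return node.input[0]
    return tensor_name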
58 changes: 58 additions & 0 deletions TrainingExtensions/onnx/test/python/models/models_for_tests.py
@@ -2577,3 +2577,61 @@ def integer_concat_model():
    onnx.checker.check_model(model, True)
    return model


def shared_stat_batchnorm_model():
    """ Returns a model with two BatchNormalization layers whose input_mean is shared through Identity nodes """
    model = helper.make_model(
        graph=helper.make_graph(
            name='BatchnormModel',
            inputs=[helper.make_tensor_value_info('model_input', TensorProto.FLOAT, shape=[10, 10, 8, 8])],
            outputs=[helper.make_tensor_value_info('model_output', TensorProto.FLOAT, shape=[10, 10, 8, 8])],
            initializer=[
                numpy_helper.from_array(np.random.randn(10, 10, 1, 1).astype('float32'), name='conv_1.weight'),
                numpy_helper.from_array(np.random.randn(10, 10, 1, 1).astype('float32'), name='conv_2.weight'),
                numpy_helper.from_array(np.abs(np.random.randn(10, )).astype('float32'), name='batchnorm.weight'),
                numpy_helper.from_array(np.random.randn(10, ).astype('float32'), name='batchnorm.bias'),
                numpy_helper.from_array(np.random.randn(10, ).astype('float32'), name='batchnorm.input_mean'),
                numpy_helper.from_array(np.abs(np.random.randn(10, )).astype('float32'), name='batchnorm.input_var')
            ],
            nodes=[
                helper.make_node(
                    'Conv',
                    inputs=['model_input', 'conv_1.weight'],
                    outputs=['conv_1.output'],
                    name='conv_1'
                ),
                helper.make_node(
                    'Identity',
                    inputs=['batchnorm.input_mean'],
                    outputs=['batchnorm1.input_mean'],
                    name='identity_1'
                ),
                helper.make_node(
                    'BatchNormalization',
                    inputs=['conv_1.output', 'batchnorm.weight', 'batchnorm.bias', 'batchnorm1.input_mean', 'batchnorm.input_var'],
                    outputs=['batch_norm_1.output'],
                    name='batchnorm_1'
                ),
                helper.make_node(
                    'Conv',
                    inputs=['batch_norm_1.output', 'conv_2.weight'],
                    outputs=['conv_2.output'],
                    name='conv_2'
                ),
                helper.make_node(
                    'Identity',
                    inputs=['batchnorm.input_mean'],
                    outputs=['batchnorm2.input_mean'],
                    name='identity_2'
                ),
                helper.make_node(
                    'BatchNormalization',
                    inputs=['conv_2.output', 'batchnorm.weight', 'batchnorm.bias', 'batchnorm2.input_mean',
                            'batchnorm.input_var'],
                    outputs=['model_output'],
                    name='batchnorm_2'
                ),
            ]
        )
    )
    onnx.checker.check_model(model, True)
    return model
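The model deliberately routes batchnorm.input_mean through identity_1 and identity_2, so both BatchNormalization nodes share one statistics tensor, which is exactly the case the fix targets. A quick way to sanity-check the model stands alone, assuming onnxruntime is installed:

import numpy as np
import onnxruntime as ort

model = shared_stat_batchnorm_model()
session = ort.InferenceSession(model.SerializeToString(), providers=['CPUExecutionProvider'])
out = session.run(None, {'model_input': np.random.randn(10, 10, 8, 8).astype(np.float32)})
print(out[0].shape)  # (10, 10, 8, 8)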
10 changes: 10 additions & 0 deletions TrainingExtensions/onnx/test/python/test_bn_fold.py
@@ -550,3 +550,13 @@ def test_single_bn_layer_with_constants(self):
            if node.name == "input_mean":
                np_tensor = onnx.numpy_helper.to_array(node.attribute[0].t)
                assert np.all(np_tensor == np.zeros_like(np_tensor))

    def test_fold_with_shared_stats(self):
        torch.manual_seed(0)
        model = models_for_tests.shared_stat_batchnorm_model()
        test_data = np.random.randn(10, 10, 8, 8).astype(np.float32)
        baseline_output, folded_output, pairs = get_outputs_after_fold(ONNXModel(model), test_data)

        # A list, not a set: protobuf NodeProto messages are unhashable.
        bns_after_fold = [node for node in model.graph.node if node.op_type == "BatchNormalization"]
        assert len(bns_after_fold) == 0
        assert np.allclose(baseline_output[0], folded_output[0], rtol=1e-2, atol=1e-6)
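For reference, the identity those rtol/atol tolerances are guarding: folding BN (scale gamma, bias beta, mean mu, variance var, epsilon eps) into a preceding conv rescales each output channel of the weight by gamma / sqrt(var + eps) and adjusts the bias accordingly. A self-contained numpy check of that arithmetic (not the AIMET implementation):

import numpy as np

rng = np.random.default_rng(0)
c_out = 4
w = rng.standard_normal((c_out, 3, 1, 1)).astype(np.float32)   # conv weight
b = rng.standard_normal(c_out).astype(np.float32)              # conv bias
gamma, beta = rng.standard_normal(c_out), rng.standard_normal(c_out)
mean, var, eps = rng.standard_normal(c_out), np.abs(rng.standard_normal(c_out)), 1e-5

sigma = np.sqrt(var + eps)
w_fold = w * (gamma / sigma).reshape(-1, 1, 1, 1)  # per-output-channel rescale
b_fold = (b - mean) * gamma / sigma + beta

# BN applied to a 1x1-conv output y = Wx + b equals the folded conv W'x + b'
x = rng.standard_normal((3,)).astype(np.float32)
y = (w.reshape(c_out, 3) @ x) + b
bn_y = (y - mean) * gamma / sigma + beta
folded_y = (w_fold.reshape(c_out, 3) @ x) + b_fold
assert np.allclose(bn_y, folded_y, atol=1e-5)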
