Adding functional batch_norm to BatchNorm2d substitution (sony#868)
edenlum authored Dec 11, 2023
1 parent 0eac225 commit db41d92
Showing 5 changed files with 183 additions and 41 deletions.
27 changes: 24 additions & 3 deletions model_compression_toolkit/core/common/graph/base_graph.py
@@ -299,19 +299,24 @@ def get_next_nodes(self,
         return [edges_list.sink_node for edges_list in self.out_edges(node_obj)]

     def get_prev_nodes(self,
-                       node_obj: BaseNode) -> List[BaseNode]:
+                       node_obj: BaseNode,
+                       sink_index_sorted: bool = False) -> List[BaseNode]:
         """
         Get previous nodes (in a topological order) of a node.

         Args:
             node_obj: Node to get its previous nodes.
+            sink_index_sorted: Whether to sort the returned list by the sink_index of the edges.

         Returns:
             List of input nodes objects.
         """

-        return [edges_list.source_node for edges_list in self.incoming_edges(node_obj)]
+        if sink_index_sorted:
+            sort_attr = 'sink_index'
+        else:
+            sort_attr = None
+        return [edges_list.source_node for edges_list in self.incoming_edges(node_obj, sort_by_attr=sort_attr)]

     def reconnect_out_edges(self,
                             current_node: BaseNode,
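
A brief usage sketch of the new flag (illustrative only; `graph` and `bn_node` are assumed to be an existing Graph and a node matched as functional batch_norm, as in the substitution added below):

    # With sink_index_sorted=True the positional inputs come back in call order:
    # index 0 is the activation input, indices 1..4 are the constant nodes that
    # feed running_mean, running_var, gamma and beta.
    inputs = graph.get_prev_nodes(bn_node, sink_index_sorted=True)
    activation_node, const_nodes = inputs[0], inputs[1:]
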
@@ -705,3 +710,19 @@ def is_single_activation_cfg(self):
         """
         return all([n.is_all_activation_candidates_equal() for n in self.nodes])

+    def replace_node(self, node_to_replace: BaseNode, new_node: BaseNode):
+        """
+        Replaces a node in the graph with a new node.
+
+        Args:
+            node_to_replace: The node to replace.
+            new_node: The new node to replace with.
+        """
+        self.add_node(new_node)
+        self.reconnect_out_edges(node_to_replace, new_node)
+        self.reconnect_in_edges(node_to_replace, new_node)
+        self.replace_output_node(node_to_replace, new_node)
+        self.replace_input_node(node_to_replace, new_node)
+        self.remove_node(node_to_replace)
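
The new helper composes the existing edge-rewiring utilities. A hedged sketch of how a substitution uses it (the names mirror the FunctionalBatchNorm substitution below and are illustrative, not part of this hunk):

    # Rewire every incoming/outgoing edge (and the graph's input/output bookkeeping)
    # from the functional batch_norm node to the freshly created BatchNorm2d node,
    # then drop the old node from the graph.
    graph.replace_node(func_bn_node, new_batchnorm2d)
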
94 changes: 94 additions & 0 deletions model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/functional_batch_norm.py
@@ -0,0 +1,94 @@
+# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from torch import nn
+import torch.nn.functional as F
+
+from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
+from model_compression_toolkit.core import common
+from model_compression_toolkit.core.common import BaseNode, Graph
+from model_compression_toolkit.core.pytorch.constants import *
+from model_compression_toolkit.logger import Logger
+
+
+class FunctionalBatchNorm(common.BaseSubstitution):
+    """
+    Replace functional batch_norm with BatchNorm2d.
+    """
+
+    def __init__(self):
+        """
+        Matches: functional batch_norm
+        """
+        bn_node = NodeOperationMatcher(F.batch_norm)
+        super().__init__(matcher_instance=bn_node)
+
+    def get_attributes_from_inputs(self, graph: Graph, node: BaseNode) -> dict:
+        input_nodes = graph.get_prev_nodes(node, sink_index_sorted=True)
+
+        if len(input_nodes) == 5:
+            return {
+                MOVING_MEAN: list(input_nodes[1].weights.values())[0],
+                MOVING_VARIANCE: list(input_nodes[2].weights.values())[0],
+                GAMMA: list(input_nodes[3].weights.values())[0],
+                BETA: list(input_nodes[4].weights.values())[0]
+            }
+        else:
+            Logger.warning(f'functional batch_norm is only folded in the 5 inputs case (input, mean, var, gamma, beta), '
+                           f'got {len(input_nodes)}')
+            return {}
+
+    def substitute(self,
+                   graph: Graph,
+                   node: BaseNode) -> Graph:
+        """
+        Substitute functional.batch_norm and its inputs with BatchNorm2d.
+
+        Args:
+            graph: Graph we apply the substitution on.
+            node: Node that matches the pattern in the substitution init.
+
+        Returns:
+            Graph after applying the substitution.
+        """
+        # If the input is not a 4D tensor, we can't substitute it with BatchNorm2d.
+        if len(node.input_shape[0]) != 4:
+            return graph
+        out_channels = node.output_shape[0][1]
+
+        bn_node_weights = self.get_attributes_from_inputs(graph, node)
+        if not bn_node_weights:
+            return graph
+        new_batchnorm2d = BaseNode(name=node.name + '_into_BatchNorm2d',
+                                   framework_attr={NUM_FEATURES: out_channels,
+                                                   EPSILON: EPSILON_VAL,
+                                                   MOMENTUM: MOMENTUM_VAL},
+                                   input_shape=node.output_shape,
+                                   output_shape=node.output_shape,
+                                   weights=bn_node_weights,
+                                   layer_class=nn.BatchNorm2d)
+
+        num_nodes_before_substitution = len(graph.nodes)
+        num_edges_before_substitution = len(graph.edges)
+
+        batch_norm_consts = graph.get_prev_nodes(node)[1:]
+        for const in batch_norm_consts:
+            graph.remove_edge(const, node)
+            graph.remove_node(const)
+
+        graph.replace_node(node, new_batchnorm2d)
+
+        assert num_nodes_before_substitution - len(graph.nodes) == len(batch_norm_consts)
+        assert num_edges_before_substitution - len(graph.edges) == len(batch_norm_consts)
+
+        return graph
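
For context, the folding relies on functional batch_norm in inference mode being numerically identical to an nn.BatchNorm2d module carrying the same statistics and affine parameters. A minimal standalone sketch of that equivalence (not part of the diff; all tensor names and shapes are illustrative):

    import torch
    from torch import nn
    import torch.nn.functional as F

    x = torch.randn(1, 8, 16, 16)                      # NCHW input, 8 channels
    mean, var = torch.zeros(8), torch.ones(8)
    gamma, beta = torch.rand(8) + 0.5, torch.randn(8)

    # Functional form, as matched by the substitution above.
    y_func = F.batch_norm(x, mean, var, gamma, beta, training=False, eps=1e-5)

    # Module form the substitution folds into.
    bn = nn.BatchNorm2d(8, eps=1e-5)
    with torch.no_grad():
        bn.running_mean.copy_(mean)
        bn.running_var.copy_(var)
        bn.weight.copy_(gamma)
        bn.bias.copy_(beta)
    bn.eval()
    y_module = bn(x)

    assert torch.allclose(y_func, y_module, atol=1e-6)
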
@@ -48,6 +48,8 @@
     pytorch_batchnorm_reconstruction
 from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.batchnorm_refusing import \
     pytorch_batchnorm_refusing
+from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.functional_batch_norm import \
+    FunctionalBatchNorm
 from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.linear_collapsing import \
     pytorch_linear_collapsing
 from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.multi_head_attention_decomposition \
@@ -243,7 +245,8 @@ def get_substitutions_prepare_graph(self, fw_info: FrameworkInfo = None) -> List
         return [ReshapeWithStaticShapes(),
                 MultiHeadAttentionDecomposition(),
                 PermuteCallMethod(),
-                ConstantHolderConv(fw_info)]
+                ConstantHolderConv(fw_info),
+                FunctionalBatchNorm()]

     def get_substitutions_pre_statistics_collection(self,
                                                     quant_config: QuantizationConfig
52 changes: 30 additions & 22 deletions tests/pytorch_tests/model_tests/feature_models/bn_folding_test.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 import torch
+from torch import nn
 import numpy as np
 from model_compression_toolkit.core.pytorch.utils import set_model, to_torch_tensor, \
     torch_tensor_to_numpy
@@ -22,16 +23,21 @@
 """
 This test checks the BatchNorm folding feature, plus adding a residual connection.
 """
-class BNFoldingNet(torch.nn.Module):
-    def __init__(self, test_layer, fold_applied):
+class BNFoldingNet(nn.Module):
+    def __init__(self, test_layer, functional, fold_applied):
         super(BNFoldingNet, self).__init__()
         self.conv1 = test_layer
         self.fold_applied = fold_applied
-        self.bn = torch.nn.BatchNorm2d(test_layer.out_channels)
+        self.bn = nn.BatchNorm2d(test_layer.out_channels)
+        self.functional = functional

     def forward(self, inp):
         x1 = self.conv1(inp)
-        x = self.bn(x1)
+        if self.functional:
+            x = nn.functional.batch_norm(x1, self.bn.running_mean, self.bn.running_var, self.bn.weight, self.bn.bias,
+                                         training=self.bn.training, momentum=self.bn.momentum, eps=self.bn.eps)
+        else:
+            x = self.bn(x1)
         x = torch.relu(x)
         if not self.fold_applied:
             x = x + x1
@@ -42,13 +48,14 @@ class BNFoldingNetTest(BasePytorchTest):
     """
     This test checks the BatchNorm folding feature, plus adding a residual connection.
     """
-    def __init__(self, unit_test, test_layer, fold_applied=True, float_reconstruction_error=1e-6):
+    def __init__(self, unit_test, test_layer, functional, fold_applied=True, float_reconstruction_error=1e-6):
         super().__init__(unit_test, float_reconstruction_error)
         self.test_layer = test_layer
         self.fold_applied = fold_applied
+        self.functional = functional

     def create_feature_network(self, input_shape):
-        return BNFoldingNet(self.test_layer, self.fold_applied)
+        return BNFoldingNet(self.test_layer, self.functional, self.fold_applied)

     def get_tpc(self):
         return {'no_quantization': super().get_tpc()['no_quantization']}
@@ -63,29 +70,29 @@ def compare(self, quantized_models, float_model, input_x=None, quantization_info
         out_float = torch_tensor_to_numpy(float_model(*input_x))
         out_quant = torch_tensor_to_numpy(quant_model(*input_x))

-        is_bn_in_model = torch.nn.BatchNorm2d in [type(module) for name, module in quant_model.named_modules()]
+        is_bn_in_model = nn.BatchNorm2d in [type(module) for name, module in quant_model.named_modules()]
         self.unit_test.assertTrue(self.fold_applied is not is_bn_in_model)
         self.unit_test.assertTrue(np.isclose(out_quant, out_float, atol=1e-6, rtol=1e-4).all())


-class BNForwardFoldingNet(torch.nn.Module):
+class BNForwardFoldingNet(nn.Module):
     def __init__(self, test_layer, add_bn=False, is_dw=False):
         super(BNForwardFoldingNet, self).__init__()
         if is_dw:
-            self.bn = torch.nn.Conv2d(3, 3, 1, groups=3)
+            self.bn = nn.Conv2d(3, 3, 1, groups=3)
         else:
-            self.bn = torch.nn.BatchNorm2d(3)
-            torch.nn.init.uniform_(self.bn.weight, 0.02, 1.05)
-            torch.nn.init.uniform_(self.bn.bias, -1.2, 1.05)
-            torch.nn.init.uniform_(self.bn.running_var, 0.02, 1.05)
-            torch.nn.init.uniform_(self.bn.running_mean, -1.2, 1.05)
+            self.bn = nn.BatchNorm2d(3)
+            nn.init.uniform_(self.bn.weight, 0.02, 1.05)
+            nn.init.uniform_(self.bn.bias, -1.2, 1.05)
+            nn.init.uniform_(self.bn.running_var, 0.02, 1.05)
+            nn.init.uniform_(self.bn.running_mean, -1.2, 1.05)
         self.conv = test_layer
         if add_bn:
-            self.bn2 = torch.nn.BatchNorm2d(test_layer.out_channels)
-            torch.nn.init.uniform_(self.bn2.weight, 0.02, 1.05)
-            torch.nn.init.uniform_(self.bn2.bias, -1.2, 1.05)
-            torch.nn.init.uniform_(self.bn2.running_var, 0.02, 1.05)
-            torch.nn.init.uniform_(self.bn2.running_mean, -1.2, 1.05)
+            self.bn2 = nn.BatchNorm2d(test_layer.out_channels)
+            nn.init.uniform_(self.bn2.weight, 0.02, 1.05)
+            nn.init.uniform_(self.bn2.bias, -1.2, 1.05)
+            nn.init.uniform_(self.bn2.running_var, 0.02, 1.05)
+            nn.init.uniform_(self.bn2.running_mean, -1.2, 1.05)
         else:
             self.bn2 = None

@@ -106,6 +113,7 @@ class BNForwardFoldingNetTest(BasePytorchTest):
     def __init__(self, unit_test, test_layer, fold_applied=True, add_bn=False, is_dw=False):
         super().__init__(unit_test, float_reconstruction_error=1e-6, val_batch_size=2)
         self.test_layer = test_layer
+        self.bn_layer = nn.BatchNorm2d
         self.fold_applied = fold_applied
         self.add_bn = add_bn
         self.is_dw = is_dw
@@ -125,10 +133,10 @@ def compare(self, quantized_models, float_model, input_x=None, quantization_info
         set_model(quant_model)

         if self.is_dw:
-            is_bn_in_model = (sum([type(module) is torch.nn.Conv2d for name, module in float_model.named_modules()]) ==
-                              sum([type(module) is torch.nn.Conv2d for name, module in quant_model.named_modules()]))
+            is_bn_in_model = (sum([type(module) is nn.Conv2d for name, module in float_model.named_modules()]) ==
+                              sum([type(module) is nn.Conv2d for name, module in quant_model.named_modules()]))
         else:
-            is_bn_in_model = torch.nn.BatchNorm2d in [type(module) for name, module in quant_model.named_modules()]
+            is_bn_in_model = nn.BatchNorm2d in [type(module) for name, module in quant_model.named_modules()]

         self.unit_test.assertTrue(self.fold_applied is not is_bn_in_model)

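With the new mandatory `functional` argument, the folding test can presumably be exercised for both the module and the functional form from the feature-test runner. A hedged sketch of such invocations (the layer shapes and the `run_test()` call follow the usual BasePytorchTest pattern and are assumptions, not part of this diff):

    # Module BatchNorm2d, folding expected:
    BNFoldingNetTest(self, nn.Conv2d(3, 2, kernel_size=1), functional=False).run_test()
    # Functional batch_norm, now also expected to fold:
    BNFoldingNetTest(self, nn.Conv2d(3, 2, kernel_size=1), functional=True).run_test()
    # Residual connection kept (fold_applied=False), so folding must not be applied:
    BNFoldingNetTest(self, nn.Conv2d(3, 3, kernel_size=3), functional=True, fold_applied=False).run_test()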
