From db41d922302d7755eadf63b022f0f0dbe06b816b Mon Sep 17 00:00:00 2001 From: Eden Lumbroso Date: Mon, 11 Dec 2023 13:47:09 +0200 Subject: [PATCH] Adding functional batch_norm to BatchNorm2d substitution (#868) * Adding functional batch_norm to BatchNorm2d substitution --- .../core/common/graph/base_graph.py | 27 +++++- .../substitutions/functional_batch_norm.py | 94 +++++++++++++++++++ .../core/pytorch/pytorch_implementation.py | 5 +- .../feature_models/bn_folding_test.py | 52 +++++----- .../model_tests/test_feature_models_runner.py | 46 ++++++--- 5 files changed, 183 insertions(+), 41 deletions(-) create mode 100644 model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/functional_batch_norm.py diff --git a/model_compression_toolkit/core/common/graph/base_graph.py b/model_compression_toolkit/core/common/graph/base_graph.py index 0d8ed1f0e..b18b767fe 100644 --- a/model_compression_toolkit/core/common/graph/base_graph.py +++ b/model_compression_toolkit/core/common/graph/base_graph.py @@ -299,19 +299,24 @@ def get_next_nodes(self, return [edges_list.sink_node for edges_list in self.out_edges(node_obj)] def get_prev_nodes(self, - node_obj: BaseNode) -> List[BaseNode]: + node_obj: BaseNode, + sink_index_sorted: bool = False) -> List[BaseNode]: """ Get previous nodes (in a topological order) of a node. Args: node_obj: Node to get its previous nodes. + sink_index_sorted: Whether to sort the returned list by the sink_index of the edges. Returns: List of input nodes objects. """ - - return [edges_list.source_node for edges_list in self.incoming_edges(node_obj)] + if sink_index_sorted: + sort_attr = 'sink_index' + else: + sort_attr = None + return [edges_list.source_node for edges_list in self.incoming_edges(node_obj, sort_by_attr=sort_attr)] def reconnect_out_edges(self, current_node: BaseNode, @@ -705,3 +710,19 @@ def is_single_activation_cfg(self): """ return all([n.is_all_activation_candidates_equal() for n in self.nodes]) + + def replace_node(self, node_to_replace: BaseNode, new_node: BaseNode): + """ + Replaces a node in the graph with a new node. + + Args: + node_to_replace: The node to replace. + new_node: The new node to replace with. + + """ + self.add_node(new_node) + self.reconnect_out_edges(node_to_replace, new_node) + self.reconnect_in_edges(node_to_replace, new_node) + self.replace_output_node(node_to_replace, new_node) + self.replace_input_node(node_to_replace, new_node) + self.remove_node(node_to_replace) diff --git a/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/functional_batch_norm.py b/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/functional_batch_norm.py new file mode 100644 index 000000000..23db177a6 --- /dev/null +++ b/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/functional_batch_norm.py @@ -0,0 +1,94 @@ +# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+from torch import nn
+import torch.nn.functional as F
+
+from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
+from model_compression_toolkit.core import common
+from model_compression_toolkit.core.common import BaseNode, Graph
+from model_compression_toolkit.core.pytorch.constants import *
+from model_compression_toolkit.logger import Logger
+
+
+class FunctionalBatchNorm(common.BaseSubstitution):
+    """
+    Replace functional batch_norm with BatchNorm2d.
+    """
+
+    def __init__(self):
+        """
+        Matches: functional batch_norm
+        """
+        bn_node = NodeOperationMatcher(F.batch_norm)
+        super().__init__(matcher_instance=bn_node)
+
+    def get_attributes_from_inputs(self, graph: Graph, node: BaseNode) -> dict:
+        input_nodes = graph.get_prev_nodes(node, sink_index_sorted=True)
+
+        if len(input_nodes) == 5:
+            return {
+                MOVING_MEAN: list(input_nodes[1].weights.values())[0],
+                MOVING_VARIANCE: list(input_nodes[2].weights.values())[0],
+                GAMMA: list(input_nodes[3].weights.values())[0],
+                BETA: list(input_nodes[4].weights.values())[0]
+            }
+        else:
+            Logger.warning(f'functional batch_norm is only folded in the 5-input case (input, mean, var, gamma, beta), '
+                           f'got {len(input_nodes)} inputs')
+            return {}
+
+    def substitute(self,
+                   graph: Graph,
+                   node: BaseNode) -> Graph:
+        """
+        Substitute functional.batch_norm and its inputs with BatchNorm2d.
+        Args:
+            graph: Graph we apply the substitution on.
+            node: Node that matches the pattern in the substitution init.
+
+        Returns:
+            Graph after applying the substitution.
+        """
+        # If the input is not a 4D tensor, we can't substitute it with BatchNorm2d.
+        if len(node.input_shape[0]) != 4:
+            return graph
+        out_channels = node.output_shape[0][1]
+
+        bn_node_weights = self.get_attributes_from_inputs(graph, node)
+        if not bn_node_weights:
+            return graph
+        new_batchnorm2d = BaseNode(name=node.name + '_into_BatchNorm2d',
+                                   framework_attr={NUM_FEATURES: out_channels,
+                                                   EPSILON: EPSILON_VAL,
+                                                   MOMENTUM: MOMENTUM_VAL},
+                                   input_shape=node.output_shape,
+                                   output_shape=node.output_shape,
+                                   weights=bn_node_weights,
+                                   layer_class=nn.BatchNorm2d)
+
+        num_nodes_before_substitution = len(graph.nodes)
+        num_edges_before_substitution = len(graph.edges)
+
+        batch_norm_consts = graph.get_prev_nodes(node)[1:]
+        for const in batch_norm_consts:
+            graph.remove_edge(const, node)
+            graph.remove_node(const)
+
+        graph.replace_node(node, new_batchnorm2d)
+
+        assert num_nodes_before_substitution - len(graph.nodes) == len(batch_norm_consts)
+        assert num_edges_before_substitution - len(graph.edges) == len(batch_norm_consts)
+
+        return graph
diff --git a/model_compression_toolkit/core/pytorch/pytorch_implementation.py b/model_compression_toolkit/core/pytorch/pytorch_implementation.py
index 5947c01c4..8004f6e78 100644
--- a/model_compression_toolkit/core/pytorch/pytorch_implementation.py
+++ b/model_compression_toolkit/core/pytorch/pytorch_implementation.py
@@ -48,6 +48,8 @@
     pytorch_batchnorm_reconstruction
 from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.batchnorm_refusing import \
     pytorch_batchnorm_refusing
+from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.functional_batch_norm import \
+    FunctionalBatchNorm
 from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.linear_collapsing import \
     pytorch_linear_collapsing
 from 
model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.multi_head_attention_decomposition \ @@ -243,7 +245,8 @@ def get_substitutions_prepare_graph(self, fw_info: FrameworkInfo = None) -> List return [ReshapeWithStaticShapes(), MultiHeadAttentionDecomposition(), PermuteCallMethod(), - ConstantHolderConv(fw_info)] + ConstantHolderConv(fw_info), + FunctionalBatchNorm()] def get_substitutions_pre_statistics_collection(self, quant_config: QuantizationConfig diff --git a/tests/pytorch_tests/model_tests/feature_models/bn_folding_test.py b/tests/pytorch_tests/model_tests/feature_models/bn_folding_test.py index bd2533ff2..540b30f9b 100644 --- a/tests/pytorch_tests/model_tests/feature_models/bn_folding_test.py +++ b/tests/pytorch_tests/model_tests/feature_models/bn_folding_test.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== import torch +from torch import nn import numpy as np from model_compression_toolkit.core.pytorch.utils import set_model, to_torch_tensor, \ torch_tensor_to_numpy @@ -22,16 +23,21 @@ """ This test checks the BatchNorm folding feature, plus adding a residual connection. """ -class BNFoldingNet(torch.nn.Module): - def __init__(self, test_layer, fold_applied): +class BNFoldingNet(nn.Module): + def __init__(self, test_layer, functional, fold_applied): super(BNFoldingNet, self).__init__() self.conv1 = test_layer self.fold_applied = fold_applied - self.bn = torch.nn.BatchNorm2d(test_layer.out_channels) + self.bn = nn.BatchNorm2d(test_layer.out_channels) + self.functional = functional def forward(self, inp): x1 = self.conv1(inp) - x = self.bn(x1) + if self.functional: + x = nn.functional.batch_norm(x1, self.bn.running_mean, self.bn.running_var, self.bn.weight, self.bn.bias, + training=self.bn.training, momentum=self.bn.momentum, eps=self.bn.eps) + else: + x = self.bn(x1) x = torch.relu(x) if not self.fold_applied: x = x + x1 @@ -42,13 +48,14 @@ class BNFoldingNetTest(BasePytorchTest): """ This test checks the BatchNorm folding feature, plus adding a residual connection. 
""" - def __init__(self, unit_test, test_layer, fold_applied=True, float_reconstruction_error=1e-6): + def __init__(self, unit_test, test_layer, functional, fold_applied=True, float_reconstruction_error=1e-6): super().__init__(unit_test, float_reconstruction_error) self.test_layer = test_layer self.fold_applied = fold_applied + self.functional = functional def create_feature_network(self, input_shape): - return BNFoldingNet(self.test_layer, self.fold_applied) + return BNFoldingNet(self.test_layer, self.functional, self.fold_applied) def get_tpc(self): return {'no_quantization': super().get_tpc()['no_quantization']} @@ -63,29 +70,29 @@ def compare(self, quantized_models, float_model, input_x=None, quantization_info out_float = torch_tensor_to_numpy(float_model(*input_x)) out_quant = torch_tensor_to_numpy(quant_model(*input_x)) - is_bn_in_model = torch.nn.BatchNorm2d in [type(module) for name, module in quant_model.named_modules()] + is_bn_in_model = nn.BatchNorm2d in [type(module) for name, module in quant_model.named_modules()] self.unit_test.assertTrue(self.fold_applied is not is_bn_in_model) self.unit_test.assertTrue(np.isclose(out_quant, out_float, atol=1e-6, rtol=1e-4).all()) -class BNForwardFoldingNet(torch.nn.Module): +class BNForwardFoldingNet(nn.Module): def __init__(self, test_layer, add_bn=False, is_dw=False): super(BNForwardFoldingNet, self).__init__() if is_dw: - self.bn = torch.nn.Conv2d(3, 3, 1, groups=3) + self.bn = nn.Conv2d(3, 3, 1, groups=3) else: - self.bn = torch.nn.BatchNorm2d(3) - torch.nn.init.uniform_(self.bn.weight, 0.02, 1.05) - torch.nn.init.uniform_(self.bn.bias, -1.2, 1.05) - torch.nn.init.uniform_(self.bn.running_var, 0.02, 1.05) - torch.nn.init.uniform_(self.bn.running_mean, -1.2, 1.05) + self.bn = nn.BatchNorm2d(3) + nn.init.uniform_(self.bn.weight, 0.02, 1.05) + nn.init.uniform_(self.bn.bias, -1.2, 1.05) + nn.init.uniform_(self.bn.running_var, 0.02, 1.05) + nn.init.uniform_(self.bn.running_mean, -1.2, 1.05) self.conv = test_layer if add_bn: - self.bn2 = torch.nn.BatchNorm2d(test_layer.out_channels) - torch.nn.init.uniform_(self.bn2.weight, 0.02, 1.05) - torch.nn.init.uniform_(self.bn2.bias, -1.2, 1.05) - torch.nn.init.uniform_(self.bn2.running_var, 0.02, 1.05) - torch.nn.init.uniform_(self.bn2.running_mean, -1.2, 1.05) + self.bn2 = nn.BatchNorm2d(test_layer.out_channels) + nn.init.uniform_(self.bn2.weight, 0.02, 1.05) + nn.init.uniform_(self.bn2.bias, -1.2, 1.05) + nn.init.uniform_(self.bn2.running_var, 0.02, 1.05) + nn.init.uniform_(self.bn2.running_mean, -1.2, 1.05) else: self.bn2 = None @@ -106,6 +113,7 @@ class BNForwardFoldingNetTest(BasePytorchTest): def __init__(self, unit_test, test_layer, fold_applied=True, add_bn=False, is_dw=False): super().__init__(unit_test, float_reconstruction_error=1e-6, val_batch_size=2) self.test_layer = test_layer + self.bn_layer = nn.BatchNorm2d self.fold_applied = fold_applied self.add_bn = add_bn self.is_dw = is_dw @@ -125,10 +133,10 @@ def compare(self, quantized_models, float_model, input_x=None, quantization_info set_model(quant_model) if self.is_dw: - is_bn_in_model = (sum([type(module) is torch.nn.Conv2d for name, module in float_model.named_modules()]) == - sum([type(module) is torch.nn.Conv2d for name, module in quant_model.named_modules()])) + is_bn_in_model = (sum([type(module) is nn.Conv2d for name, module in float_model.named_modules()]) == + sum([type(module) is nn.Conv2d for name, module in quant_model.named_modules()])) else: - is_bn_in_model = torch.nn.BatchNorm2d in [type(module) for name, module in 
quant_model.named_modules()] + is_bn_in_model = nn.BatchNorm2d in [type(module) for name, module in quant_model.named_modules()] self.unit_test.assertTrue(self.fold_applied is not is_bn_in_model) diff --git a/tests/pytorch_tests/model_tests/test_feature_models_runner.py b/tests/pytorch_tests/model_tests/test_feature_models_runner.py index b5d100471..e3249165e 100644 --- a/tests/pytorch_tests/model_tests/test_feature_models_runner.py +++ b/tests/pytorch_tests/model_tests/test_feature_models_runner.py @@ -13,6 +13,9 @@ # limitations under the License. # ============================================================================== import unittest +from functools import partial + +import torch from torch import nn import model_compression_toolkit as mct from model_compression_toolkit.gptq.common.gptq_config import RoundingType @@ -23,7 +26,8 @@ MixedPrecisionBopsAndWeightsKPITest, MixedPrecisionBopsAndActivationKPITest, MixedPrecisionBopsAndTotalKPITest, \ MixedPrecisionBopsWeightsActivationKPITest, MixedPrecisionBopsMultipleOutEdgesTest from tests.pytorch_tests.model_tests.feature_models.qat_test import QuantizationAwareTrainingTest, \ - QuantizationAwareTrainingMixedPrecisionCfgTest, QuantizationAwareTrainingMixedPrecisionKpiCfgTest, QuantizationAwareTrainingQuantizerHolderTest + QuantizationAwareTrainingMixedPrecisionCfgTest, QuantizationAwareTrainingMixedPrecisionKpiCfgTest, \ + QuantizationAwareTrainingQuantizerHolderTest from tests.pytorch_tests.model_tests.feature_models.relu_replacement_test import SingleLayerReplacementTest, \ ReluReplacementTest, ReluReplacementWithAddBiasTest from tests.pytorch_tests.model_tests.feature_models.remove_assert_test import AssertNetTest @@ -32,7 +36,8 @@ from tests.pytorch_tests.model_tests.feature_models.bn_folding_test import BNFoldingNetTest, BNForwardFoldingNetTest from tests.pytorch_tests.model_tests.feature_models.linear_collapsing_test import TwoConv2DCollapsingTest, \ ThreeConv2DCollapsingTest, FourConv2DCollapsingTest, SixConv2DCollapsingTest -from tests.pytorch_tests.model_tests.feature_models.residual_collapsing_test import ResidualCollapsingTest1, ResidualCollapsingTest2 +from tests.pytorch_tests.model_tests.feature_models.residual_collapsing_test import ResidualCollapsingTest1, \ + ResidualCollapsingTest2 from tests.pytorch_tests.model_tests.feature_models.dynamic_size_inputs_test import ReshapeNetTest from tests.pytorch_tests.model_tests.feature_models.mixed_precision_activation_test import \ MixedPercisionActivationSearch8Bit, MixedPercisionActivationSearch2Bit, MixedPercisionActivationSearch4Bit, \ @@ -71,7 +76,8 @@ from tests.pytorch_tests.model_tests.feature_models.split_concat_net_test import SplitConcatNetTest from tests.pytorch_tests.model_tests.feature_models.torch_tensor_attr_net_test import TorchTensorAttrNetTest from tests.pytorch_tests.model_tests.feature_models.bn_function_test import BNFNetTest -from tests.pytorch_tests.model_tests.feature_models.gptq_test import GPTQAccuracyTest, GPTQWeightsUpdateTest, GPTQLearnRateZeroTest +from tests.pytorch_tests.model_tests.feature_models.gptq_test import GPTQAccuracyTest, GPTQWeightsUpdateTest, \ + GPTQLearnRateZeroTest from tests.pytorch_tests.model_tests.feature_models.uniform_activation_test import \ UniformActivationTest from tests.pytorch_tests.model_tests.feature_models.old_api_test import OldApiTest @@ -132,12 +138,14 @@ def test_bn_folding(self): """ This test checks the BatchNorm folding feature. 
""" - BNFoldingNetTest(self, nn.Conv2d(3, 2, kernel_size=1)).run_test() - BNFoldingNetTest(self, nn.Conv2d(3, 3, kernel_size=3, groups=3)).run_test() # DW-Conv test - BNFoldingNetTest(self, nn.ConvTranspose2d(3, 2, kernel_size=(2, 1))).run_test() - BNFoldingNetTest(self, nn.Conv2d(3, 2, kernel_size=2), fold_applied=False).run_test() - BNFoldingNetTest(self, nn.Conv2d(3, 3, kernel_size=(3, 1), groups=3), fold_applied=False).run_test() # DW-Conv test - BNFoldingNetTest(self, nn.ConvTranspose2d(3, 2, kernel_size=(1, 3)), fold_applied=False).run_test() + for functional in [True, False]: + BNFoldingNetTest(self, nn.Conv2d(3, 2, kernel_size=1), functional).run_test() + BNFoldingNetTest(self, nn.Conv2d(3, 3, kernel_size=3, groups=3), functional).run_test() # DW-Conv test + BNFoldingNetTest(self, nn.ConvTranspose2d(3, 2, kernel_size=(2, 1)), functional).run_test() + BNFoldingNetTest(self, nn.Conv2d(3, 2, kernel_size=2), functional, fold_applied=False).run_test() + BNFoldingNetTest(self, nn.Conv2d(3, 3, kernel_size=(3, 1), groups=3), + functional, fold_applied=False).run_test() # DW-Conv test + BNFoldingNetTest(self, nn.ConvTranspose2d(3, 2, kernel_size=(1, 3)), functional, fold_applied=False).run_test() def test_bn_forward_folding(self): """ @@ -147,7 +155,8 @@ def test_bn_forward_folding(self): BNForwardFoldingNetTest(self, nn.Conv2d(3, 3, 1, groups=3), is_dw=True).run_test() # DW-Conv test BNForwardFoldingNetTest(self, nn.ConvTranspose2d(3, 2, 1), is_dw=True).run_test() BNForwardFoldingNetTest(self, nn.Conv2d(3, 2, 2), fold_applied=False, is_dw=True).run_test() - BNForwardFoldingNetTest(self, nn.Conv2d(3, 3, (3, 1), groups=3), fold_applied=False, is_dw=True).run_test() # DW-Conv test + BNForwardFoldingNetTest(self, nn.Conv2d(3, 3, (3, 1), groups=3), fold_applied=False, + is_dw=True).run_test() # DW-Conv test BNForwardFoldingNetTest(self, nn.ConvTranspose2d(3, 2, (1, 3)), fold_applied=False, is_dw=True).run_test() BNForwardFoldingNetTest(self, nn.Conv2d(3, 2, 1), add_bn=True, is_dw=True).run_test() @@ -468,7 +477,8 @@ def test_gptq(self): GPTQAccuracyTest(self, per_channel=True, hessian_weights=False).run_test() GPTQAccuracyTest(self, per_channel=True, log_norm_weights=False).run_test() GPTQWeightsUpdateTest(self).run_test() - GPTQLearnRateZeroTest(self, experimental_exporter=False).run_test() # TODO: check why weights are different between gptq and ptq when using experimental exporter flag. May be due to different quantization ways (fake-quant pytorch layer vs numpy quantinzation in core/common/quantization/quantizers/quantizers_helpers) + GPTQLearnRateZeroTest(self, + experimental_exporter=False).run_test() # TODO: check why weights are different between gptq and ptq when using experimental exporter flag. 
May be due to different quantization ways (fake-quant pytorch layer vs numpy quantization in core/common/quantization/quantizers/quantizers_helpers)
         GPTQAccuracyTest(self, rounding_type=RoundingType.SoftQuantizer).run_test()
         GPTQAccuracyTest(self, rounding_type=RoundingType.SoftQuantizer, per_channel=False,
@@ -480,11 +490,17 @@ def test_gptq(self):
         GPTQWeightsUpdateTest(self, rounding_type=RoundingType.SoftQuantizer).run_test()
         GPTQLearnRateZeroTest(self, rounding_type=RoundingType.SoftQuantizer,
                               experimental_exporter=False).run_test()
-        GPTQAccuracyTest(self, rounding_type=RoundingType.SoftQuantizer, weights_quant_method=QuantizationMethod.UNIFORM).run_test()
-        GPTQAccuracyTest(self, rounding_type=RoundingType.SoftQuantizer, weights_quant_method=QuantizationMethod.UNIFORM, per_channel=False, params_learning=False).run_test()
-        GPTQAccuracyTest(self, rounding_type=RoundingType.SoftQuantizer, weights_quant_method=QuantizationMethod.UNIFORM,
+        GPTQAccuracyTest(self, rounding_type=RoundingType.SoftQuantizer,
+                         weights_quant_method=QuantizationMethod.UNIFORM).run_test()
+        GPTQAccuracyTest(self, rounding_type=RoundingType.SoftQuantizer,
+                         weights_quant_method=QuantizationMethod.UNIFORM, per_channel=False,
+                         params_learning=False).run_test()
+        GPTQAccuracyTest(self, rounding_type=RoundingType.SoftQuantizer,
+                         weights_quant_method=QuantizationMethod.UNIFORM,
                          per_channel=True, hessian_weights=True, log_norm_weights=True, scaled_log_norm=True).run_test()
-        GPTQWeightsUpdateTest(self, rounding_type=RoundingType.SoftQuantizer, weights_quant_method=QuantizationMethod.UNIFORM, params_learning=False).run_test() #TODO: When params learning is True, the uniform quantizer gets a min value > max value
+        GPTQWeightsUpdateTest(self, rounding_type=RoundingType.SoftQuantizer,
+                              weights_quant_method=QuantizationMethod.UNIFORM,
+                              params_learning=False).run_test()  # TODO: When params learning is True, the uniform quantizer gets a min value > max value
 
     def test_qat(self):
         """