From 6e5bec9192147f60142510cf43e62edef15c6130 Mon Sep 17 00:00:00 2001
From: yardeny-sony
Date: Wed, 25 Sep 2024 10:46:53 +0300
Subject: [PATCH] batch norm bug fix

---
 tests/pytorch_tests/model_tests/base_pytorch_test.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/tests/pytorch_tests/model_tests/base_pytorch_test.py b/tests/pytorch_tests/model_tests/base_pytorch_test.py
index e612ca6cc..86ffbc6af 100644
--- a/tests/pytorch_tests/model_tests/base_pytorch_test.py
+++ b/tests/pytorch_tests/model_tests/base_pytorch_test.py
@@ -97,10 +97,9 @@ def compare(self, quantized_models, float_model, input_x=None, quantization_info
         # Check if we have a BatchNorm or MultiheadAttention layer in the model.
         # If so, the outputs will not be the same, since the sqrt function in the
         # Decomposition is not exactly like the sqrt in the C implementation of PyTorch.
-        # float_model_operators = [type(module) for name, module in float_model.named_modules()]
-        # if torch.nn.BatchNorm2d or torch.nn.MultiheadAttention in float_model_operators\
-        #     or self.use_fuzzy_validation: # todo: add flag to batch norm and MHA
-        if torch.nn.BatchNorm2d or torch.nn.MultiheadAttention in [type(module) for name, module in float_model.named_modules()]:
+        float_model_operators = [type(module) for name, module in float_model.named_modules()]
+        if (torch.nn.BatchNorm2d in float_model_operators or
+                torch.nn.MultiheadAttention in float_model_operators or self.use_fuzzy_validation):
             self.unit_test.assertTrue(np.all(np.isclose(torch_tensor_to_numpy(f), torch_tensor_to_numpy(q),
                                                         atol=self.float_reconstruction_error)))
         else:
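
Note on the bug being fixed: the removed condition `torch.nn.BatchNorm2d or torch.nn.MultiheadAttention in [...]` parses as `torch.nn.BatchNorm2d or (torch.nn.MultiheadAttention in [...])`, and since the class object `torch.nn.BatchNorm2d` is always truthy, the branch was taken unconditionally. Below is a minimal standalone sketch of the pitfall, not part of the test suite; the `operators` list is a made-up example:

```python
import torch

# A hypothetical module list that contains neither BatchNorm2d nor
# MultiheadAttention, so a correct membership check should be False.
operators = [torch.nn.Linear, torch.nn.ReLU]

# Buggy pattern: `A or B in seq` is `A or (B in seq)`.
# torch.nn.BatchNorm2d is a class object and therefore truthy,
# so the whole expression is True regardless of `operators`.
buggy = torch.nn.BatchNorm2d or torch.nn.MultiheadAttention in operators
print(bool(buggy))  # True -- the model contents were never actually checked

# Fixed pattern from the patch: test membership for each class explicitly.
fixed = (torch.nn.BatchNorm2d in operators or
         torch.nn.MultiheadAttention in operators)
print(fixed)  # False -- now the result depends on the model's layers
```

In other words, before this patch every model hit the loose `atol=self.float_reconstruction_error` comparison; after it, only models that actually contain BatchNorm/MultiheadAttention layers (or tests opting in via `use_fuzzy_validation`) do.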