batch norm bug fix
yardeny-sony committed Sep 25, 2024
1 parent 0648665 commit 6e5bec9
Showing 1 changed file with 3 additions and 4 deletions.
tests/pytorch_tests/model_tests/base_pytorch_test.py
@@ -97,10 +97,9 @@ def compare(self, quantized_models, float_model, input_x=None, quantization_info
         # Check if we have a BatchNorm or MultiheadAttention layer in the model.
         # If so, the outputs will not be the same, since the sqrt function in the
         # Decomposition is not exactly like the sqrt in the C implementation of PyTorch.
-        # float_model_operators = [type(module) for name, module in float_model.named_modules()]
-        # if torch.nn.BatchNorm2d or torch.nn.MultiheadAttention in float_model_operators\
-        #     or self.use_fuzzy_validation: # todo: add flag to batch norm and MHA
-        if torch.nn.BatchNorm2d or torch.nn.MultiheadAttention in [type(module) for name, module in float_model.named_modules()]:
+        float_model_operators = [type(module) for name, module in float_model.named_modules()]
+        if (torch.nn.BatchNorm2d in float_model_operators or
+                torch.nn.MultiheadAttention in float_model_operators or self.use_fuzzy_validation):
             self.unit_test.assertTrue(np.all(np.isclose(torch_tensor_to_numpy(f), torch_tensor_to_numpy(q),
                                                         atol=self.float_reconstruction_error)))
         else:
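For context, the deleted condition is a classic Python operator-precedence pitfall: `in` binds tighter than `or`, so the old line parsed as `(torch.nn.BatchNorm2d) or (torch.nn.MultiheadAttention in [...])`. A class object is always truthy, so the branch ran for every model and the loose `atol` tolerance was applied to all comparisons. A minimal standalone sketch of the difference (the `operators` list here is hypothetical, standing in for `float_model_operators`):

    import torch

    # Hypothetical: a model containing neither BatchNorm2d nor MultiheadAttention.
    operators = []

    # Buggy form: parsed as (torch.nn.BatchNorm2d) or (torch.nn.MultiheadAttention in operators).
    # The class object torch.nn.BatchNorm2d is always truthy, so the result is always truthy.
    buggy = torch.nn.BatchNorm2d or torch.nn.MultiheadAttention in operators
    print(bool(buggy))   # True  -> fuzzy comparison ran for every model

    # Fixed form: an explicit membership test per layer type.
    fixed = (torch.nn.BatchNorm2d in operators or
             torch.nn.MultiheadAttention in operators)
    print(fixed)         # False -> exact comparison is used when neither layer is present

The fix also reinstates the `self.use_fuzzy_validation` escape hatch sketched in the commented-out draft, so tests can still opt in to the looser comparison explicitly.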
