diff --git a/tests/pytorch_tests/model_tests/feature_models/linear_function_test.py b/tests/pytorch_tests/model_tests/feature_models/linear_function_test.py index a14fa25b4..30dae8e22 100644 --- a/tests/pytorch_tests/model_tests/feature_models/linear_function_test.py +++ b/tests/pytorch_tests/model_tests/feature_models/linear_function_test.py @@ -26,11 +26,13 @@ class LinearFNet(torch.nn.Module): def __init__(self): super(LinearFNet, self).__init__() self.fc1 = torch.nn.Linear(in_features=1000, out_features=100, bias=False) - self.fc2 = torch.nn.Linear(in_features=100, out_features=10, bias=True) + self.fc2 = torch.nn.Linear(in_features=100, out_features=50, bias=True) + self.fc3 = torch.nn.Linear(in_features=50, out_features=10, bias=False) def forward(self, x): x = F.linear(x, self.fc1.weight, self.fc1.bias) - y = F.linear(x, bias=self.fc2.bias, weight=self.fc2.weight) + x = F.linear(x, bias=self.fc2.bias, weight=self.fc2.weight) + y = F.linear(x, self.fc3.weight, bias=None) return y