add substitution for functional linear (#1266)
add PyTorch substitution for functional linear and related tests
itai-berman authored Nov 20, 2024
1 parent a4626e8 commit 4e76be8
Showing 4 changed files with 144 additions and 0 deletions.
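
For context, a minimal before/after sketch (illustrative only, class names hypothetical) of what the new substitution does: a functional F.linear call inside a module's forward is replaced in MCT's internal graph by an equivalent nn.Linear node, so later stages can treat it like any other linear layer.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Before: the op is a functional call whose weights live on the module.
class Before(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(10, 50))
        self.bias = nn.Parameter(torch.randn(10))

    def forward(self, x):
        return F.linear(x, self.weight, self.bias)

# After: the graph node is replaced by an equivalent module node.
after = nn.Linear(in_features=50, out_features=10, bias=True)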
83 changes: 83 additions & 0 deletions model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/functional_linear.py
@@ -0,0 +1,83 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from torch import nn
import torch.nn.functional as F

from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
from model_compression_toolkit.core.common import BaseNode, Graph, BaseSubstitution
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
from model_compression_toolkit.core.pytorch.constants import *
from model_compression_toolkit.logger import Logger


class FunctionalLinear(BaseSubstitution):
"""
Replace the functional F.linear with an equivalent nn.Linear layer.
"""

def __init__(self):
"""
Matches: functional linear
"""
func_node = NodeOperationMatcher(F.linear)
super().__init__(matcher_instance=func_node)

def substitute(self,
graph: Graph,
func_node: FunctionalNode) -> Graph:
"""
Substitute F.linear and its weight and bias inputs with an nn.Linear layer.
Args:
graph: Graph we apply the substitution on.
func_node: Node that matches the pattern in the substitution init.
Returns:
Graph after applying the substitution.
"""

# Create new node of layer Linear
if 1 not in func_node.weights:
Logger.critical(f'Weight input missing for node {func_node.name}.') # pragma: no cover
# Extract the kernel and bias indices from tensor_input_allocs if they were passed as kwargs.
# If they were passed as positional args, use their fixed positions (1 and 2).
weight_index = func_node.tensor_input_allocs.index(KERNEL) if KERNEL in func_node.tensor_input_allocs else 1
bias_index = func_node.tensor_input_allocs.index(BIAS) if BIAS in func_node.tensor_input_allocs else 2
if weight_index not in func_node.weights:
Logger.critical(f'Mismatch between tensor_input_allocs and weight index in node {func_node.name}.') # pragma: no cover
weight = func_node.weights[weight_index]
bias = func_node.weights.get(bias_index)

framework_attr = {
IN_FEATURES: func_node.input_shape[0][-1],
OUT_FEATURES: func_node.output_shape[0][-1],
BIAS: bias is not None,
}

weights = {KERNEL: weight} if bias is None else {KERNEL: weight, BIAS: bias}

new_node = BaseNode(
name=func_node.name,
framework_attr=framework_attr,
input_shape=func_node.input_shape[0],
output_shape=func_node.output_shape,
weights=weights,
layer_class=nn.Linear,
has_activation=func_node.has_activation,
reuse=func_node.reuse,
reuse_group=func_node.reuse_group
)

graph.replace_node(func_node, new_node)
return graph
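
A quick numerical check (not part of the commit) of why this swap preserves model behavior: F.linear(x, W, b) computes x @ W.T + b, which is exactly what nn.Linear's forward computes with the same weight and bias tensors.

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(4, 1000)
fc = nn.Linear(in_features=1000, out_features=100, bias=True)
# Both paths compute x @ fc.weight.T + fc.bias, so the outputs match exactly.
assert torch.equal(F.linear(x, fc.weight, fc.bias), fc(x))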
@@ -50,6 +50,8 @@
FunctionalBatchNorm
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.functional_layer_norm import \
FunctionalLayerNorm
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.functional_linear import \
FunctionalLinear
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.linear_collapsing import \
pytorch_linear_collapsing
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.multi_head_attention_decomposition \
@@ -266,6 +268,7 @@ def get_substitutions_prepare_graph(self, fw_info: FrameworkInfo = None) -> List
FunctionalConvSubstitution(fw_info),
FunctionalBatchNorm(),
FunctionalLayerNorm(),
FunctionalLinear(),
RemoveIdentity()]

def get_substitutions_pre_statistics_collection(self,
51 changes: 51 additions & 0 deletions tests/pytorch_tests/model_tests/feature_models/linear_function_test.py
@@ -0,0 +1,51 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import torch
import torch.nn.functional as F
from tests.pytorch_tests.model_tests.base_pytorch_test import BasePytorchTest
from model_compression_toolkit.core.pytorch.pytorch_device_config import get_working_device

"""
This test checks the substitution of functional linear (F.linear) with nn.Linear.
"""


class LinearFNet(torch.nn.Module):
def __init__(self):
super(LinearFNet, self).__init__()
self.fc1 = torch.nn.Linear(in_features=1000, out_features=100, bias=False)
self.fc2 = torch.nn.Linear(in_features=100, out_features=50, bias=True)
self.fc3 = torch.nn.Linear(in_features=50, out_features=10, bias=False)

def forward(self, x):
x = F.linear(x, self.fc1.weight, self.fc1.bias)
x = F.linear(x, bias=self.fc2.bias, weight=self.fc2.weight)
y = F.linear(x, self.fc3.weight, bias=None)
return y


class LinearFNetTest(BasePytorchTest):
"""
This test checks the functional linear substitution.
"""

def __init__(self, unit_test):
super().__init__(unit_test)

def create_inputs_shape(self):
return [[self.val_batch_size, 1000]]

def create_feature_network(self, input_shape):
return LinearFNet()
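
Note that the three forward calls in LinearFNet deliberately mix call styles: positional weight and bias, keyword weight and bias, and bias=None, so the test exercises both the tensor_input_allocs lookup and the fixed-position fallback in the substitution. A quick shape sanity check (assuming LinearFNet above is importable):

import torch

net = LinearFNet()
out = net(torch.randn(8, 1000))
print(out.shape)  # torch.Size([8, 10])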
7 changes: 7 additions & 0 deletions tests/pytorch_tests/model_tests/test_feature_models_runner.py
@@ -55,6 +55,7 @@
from tests.pytorch_tests.model_tests.feature_models.layer_norm_net_test import LayerNormNetTest
from tests.pytorch_tests.model_tests.feature_models.linear_collapsing_test import TwoConv2DCollapsingTest, \
ThreeConv2DCollapsingTest, FourConv2DCollapsingTest, SixConv2DCollapsingTest
from tests.pytorch_tests.model_tests.feature_models.linear_function_test import LinearFNetTest
from tests.pytorch_tests.model_tests.feature_models.lut_quantizer_test import LUTWeightsQuantizerTest, \
LUTActivationQuantizerTest
from tests.pytorch_tests.model_tests.feature_models.manual_bit_selection import ManualBitWidthByLayerTypeTest, \
@@ -239,6 +240,12 @@ def test_bn_function(self):
"""
BNFNetTest(self).run_test()

def test_linear_function(self):
"""
This test checks the functional linear substitution.
"""
LinearFNetTest(self).run_test()

def test_broken_net(self):
"""
This test checks that the "broken" node (node without output) is being
