Skip to content

Commit

Permalink
Remove identity - Keras (sony#1047)
Browse files Browse the repository at this point in the history
* Add remove identety substitution for Keras.
* Use 'common' function in Pytorch substitution to avoid code duplication.

---------

Co-authored-by: reuvenp <reuvenp@altair-semi.com>
  • Loading branch information
reuvenperetz and reuvenp authored Apr 17, 2024
1 parent 9fa60d4 commit 6cbffa7
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 20 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from model_compression_toolkit.core.common.graph.base_graph import Graph
from model_compression_toolkit.core.common.graph.base_node import BaseNode


def remove_identity_node(graph: Graph,
node: BaseNode) -> Graph:
"""
The method to perform the substitution of the identity node by
reconnecting its input directly to its output, effectively removing the node
from the graph.
Args:
graph: The current graph of operations where the node resides.
node: The specific `BaseNode` that is matched to be an Identity operation.
Returns:
Graph: The updated graph after removing the identity node.
"""
# Retrieve the predecessor nodes of the identity node.
prev_identity_nodes = graph.get_prev_nodes(node)
# Ensure there is exactly one predecessor; otherwise, do nothing.
if len(prev_identity_nodes) != 1:
return graph

# Reconnect the output edges of the identity node to its predecessor,
# effectively bypassing the identity node.
graph.reconnect_out_edges(current_node=node, new_node=prev_identity_nodes[0])
# Remove the edge from the predecessor to the identity node.
graph.remove_edge(prev_identity_nodes[0], node)
# Remove the identity node from the graph.
graph.remove_node(node_to_remove=node)

return graph
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import keras
import tensorflow as tf

from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
from model_compression_toolkit.core import common
from model_compression_toolkit.core.common.graph.base_graph import Graph
from model_compression_toolkit.core.common.graph.base_node import BaseNode
from model_compression_toolkit.core.common.substitutions.remove_identity import remove_identity_node


class RemoveIdentity(common.BaseSubstitution):
"""
Remove Identity layers from the graph.
"""

def __init__(self):
nodes = NodeOperationMatcher(keras.layers.Identity) | NodeOperationMatcher(tf.identity)
super().__init__(matcher_instance=nodes)

def substitute(self,
graph: Graph,
node: BaseNode) -> Graph:
"""
The method to perform the substitution of the identity keras node by
reconnecting its input directly to its output, effectively removing the node
from the graph.
Args:
graph: The current graph of operations where the node resides.
node: The specific `BaseNode` that is matched to be an Identity operation.
Returns:
Graph: The updated graph after removing the identity node.
"""
return remove_identity_node(graph, node)

4 changes: 3 additions & 1 deletion model_compression_toolkit/core/keras/keras_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS
from model_compression_toolkit.core.common.hessian import TraceHessianRequest, HessianMode, HessianInfoService
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.remove_identity import RemoveIdentity
from model_compression_toolkit.core.keras.hessian.activation_trace_hessian_calculator_keras import \
ActivationTraceHessianCalculatorKeras
from model_compression_toolkit.core.keras.hessian.weights_trace_hessian_calculator_keras import WeightsTraceHessianCalculatorKeras
Expand Down Expand Up @@ -246,7 +247,8 @@ def get_substitutions_prepare_graph(self, fw_info: FrameworkInfo = None) -> List
MatmulToDenseSubstitution(),
MultiHeadAttentionDecomposition(),
ActivationDecomposition(),
DwconvToConv()]
DwconvToConv(),
RemoveIdentity()]

def get_substitutions_pre_statistics_collection(self, quant_config: QuantizationConfig) -> \
List[common.BaseSubstitution]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from torch import reshape
import torch

from model_compression_toolkit.logger import Logger
from model_compression_toolkit.core.common.substitutions.remove_identity import remove_identity_node
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
from model_compression_toolkit.core import common
from model_compression_toolkit.core.common.graph.base_graph import Graph
from model_compression_toolkit.core.common.graph.base_node import BaseNode
from model_compression_toolkit.core.pytorch.constants import BATCH_DIM_VALUE


class RemoveIdentity(common.BaseSubstitution):
Expand All @@ -47,20 +45,6 @@ def substitute(self,
Returns:
Graph: The updated graph after removing the identity node.
"""
return remove_identity_node(graph, node)

# Retrieve the predecessor nodes of the identity node.
prev_identity_nodes = graph.get_prev_nodes(node)
# Ensure there is exactly one predecessor; otherwise, do nothing.
if len(prev_identity_nodes) != 1:
return graph

# Reconnect the output edges of the identity node to its predecessor,
# effectively bypassing the identity node.
graph.reconnect_out_edges(current_node=node, new_node=prev_identity_nodes[0])
# Remove the edge from the predecessor to the identity node.
graph.remove_edge(prev_identity_nodes[0], node)
# Remove the identity node from the graph.
graph.remove_node(node_to_remove=node)

return graph

Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from tests.keras_tests.feature_networks_tests.base_keras_feature_test import BaseKerasFeatureNetworkTest
import keras
import tensorflow as tf

class RemoveIdentityTest(BaseKerasFeatureNetworkTest):
def __init__(self, unit_test):
super().__init__(unit_test)

def create_networks(self):
inputs = keras.layers.Input(shape=self.get_input_shapes()[0][1:])
x = keras.layers.Conv2D(3, 3)(inputs)
x = keras.layers.Identity()(x)
x = tf.identity(x)
outputs = keras.layers.BatchNormalization()(x)
return keras.Model(inputs=inputs, outputs=outputs)

def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
# Make sure identity and bn layers are not in the final model.
# there should be 4 layers: input, input_quantizer, conv, conv_quantizer
self.unit_test.assertTrue(len(quantized_model.layers)==4)

Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
QuantizationAwareTrainingQuantizerHolderTest
from tests.keras_tests.feature_networks_tests.feature_networks.relu_replacement_test import ReluReplacementTest, \
SingleReluReplacementTest, ReluReplacementWithAddBiasTest
from tests.keras_tests.feature_networks_tests.feature_networks.remove_identity_test import RemoveIdentityTest
from tests.keras_tests.feature_networks_tests.feature_networks.residual_collapsing_test import ResidualCollapsingTest1, \
ResidualCollapsingTest2
from tests.keras_tests.feature_networks_tests.feature_networks.reused_layer_mixed_precision_test import \
Expand Down Expand Up @@ -139,7 +140,10 @@


class FeatureNetworkTest(unittest.TestCase):


def test_remove_identity(self):
RemoveIdentityTest(self).run_test()

def test_per_tensor_weight_quantization(self):
PerTensorWeightQuantizationTest(self).run_test()

Expand Down

0 comments on commit 6cbffa7

Please sign in to comment.