Add Op2d->add_const collapse substitution #878

Merged · 3 commits · Dec 5, 2023
@@ -235,6 +235,14 @@ def get_linear_collapsing_substitution(self) -> common.BaseSubstitution:
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_linear_collapsing_substitution method.')  # pragma: no cover

@abstractmethod
def get_op2d_add_const_collapsing_substitution(self) -> common.BaseSubstitution:
"""
Returns: op2d add-const collapsing substitution
"""
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_op2d_add_const_collapsing_substitution method.')  # pragma: no cover

@abstractmethod
def get_substitutions_statistics_correction(self, quant_config: QuantizationConfig) -> \
List[common.BaseSubstitution]:
11 changes: 10 additions & 1 deletion model_compression_toolkit/core/common/graph/base_node.py
@@ -79,7 +79,8 @@ def __init__(self,
def type(self):
"""
A function to get the node's layer_class op for convenient comparison
- :return: the node's layer_class
+ Returns:
+     the node's layer_class
"""
return self.layer_class

@@ -130,6 +131,14 @@ def __repr__(self):
"""
return f'{self.type.__name__}:{self.name}'

def is_reused(self) -> bool:
"""
Check whether the node is reused or not
Returns:
True if node is reused, else False
"""
return self.reuse or self.reuse_group is not None

def get_weights_by_keys(self, name: str) -> np.ndarray:
"""
Get a node's weight by its name.
@@ -93,7 +93,7 @@ def substitute(self,

# If the linear operator is part of a reused group (it is the "base" node, or a reused node),
# we should skip the substitution.
- if conv_node.reuse or conv_node.reuse_group is not None:
+ if conv_node.is_reused():
return graph

bn_node = edge_nodes[1]
@@ -230,7 +230,7 @@ def substitute(self,

# If the linear operator is part of a reused group (it is the "base" node, or a reused node),
# we should skip the substitution.
- if conv_node.reuse or conv_node.reuse_group is not None or bn_node.reuse or bn_node.reuse_group is not None:
+ if conv_node.is_reused() or bn_node.is_reused():
return graph

if len(graph.get_next_nodes(bn_node)) > 1 or len(graph.get_prev_nodes(conv_node)) > 1:
@@ -79,7 +79,7 @@ def substitute(self,

# If the linear operator is part of a reused group (it is the "base" node, or a reused node),
# we should skip the substitution.
- if source_node.reuse or source_node.reuse_group is not None:
+ if source_node.is_reused():
for qc in source_node.candidates_quantization_cfg:
qc.weights_quantization_cfg.weights_second_moment_correction = False
return graph
@@ -102,7 +102,7 @@ def substitute(self,

# If the linear operator is part of a reused group (it is the "base" node, or a reused node),
# we should skip the substitution.
- if source_node.reuse or source_node.reuse_group is not None:
+ if source_node.is_reused():
Logger.exception("If the linear operator is part of a reused group we should skip the BN folding "
"substitution and SMC feature")  # pragma: no cover

@@ -91,14 +91,11 @@ def substitute(self,
Graph after applying the substitution.
"""

- first_node = edge_nodes[0]
- second_node = edge_nodes[1]
+ first_node, second_node, _ = edge_nodes

# If the linear operator is part of a reused group (it is the "base" node, or a reused node),
# we should skip the substitution.
- if first_node.reuse or first_node.reuse_group is not None:
-     return graph
- if second_node.reuse or second_node.reuse_group is not None:
+ if first_node.is_reused() or second_node.is_reused():
return graph

# If there is an extra connection between these two nodes skip the substitution
@@ -182,3 +179,83 @@ def substitute(self,
assert num_edges_before_substition - len(graph.edges) == 1

return graph


class Op2DAddConstCollapsing(common.BaseSubstitution):
"""
Collapse an Add-const node into the preceding Op2D node (assuming no non-linear activation between them)
"""
def __init__(self,
first_node: NodeOperationMatcher,
second_node: NodeOperationMatcher,
op2d_collapsing_fn: Callable,
bias_str: str,
use_bias_str: str,
layer_name_str: str = None):
"""
Collapse the Add-const node (second node) into the Op2D node (first node).
Args:
first_node: Node matcher for Op2d type nodes.
second_node: Node matcher for add type nodes.
op2d_collapsing_fn: Function for computing the collapsed node's bias.
bias_str: The framework-specific attribute name of the layer's bias.
use_bias_str: The framework-specific attribute name of the layer's bias flag.
layer_name_str: The framework-specific attribute name of the layer's name.
"""
super().__init__(matcher_instance=EdgeMatcher(first_node, second_node))
self.op2d_collapsing_fn = op2d_collapsing_fn
self.bias_str = bias_str
self.use_bias_str = use_bias_str
self.layer_name_str = layer_name_str

def substitute(self,
graph: Graph,
edge_nodes: Tuple[BaseNode, BaseNode]) -> Graph:
"""
Collapse linear layer into preceding linear layers.
Convolution condition:
|-------------------------| |------|
| Op2D | ---> | Add-const | -> | Op2D |
|-------------------------| |------|
Args:
graph: Graph we apply the substitution on.
edge_nodes: Tuple of linear node and add nodes
Returns:
Graph after applying the substitution.
"""

first_node, second_node, _ = edge_nodes

# If the linear operator is part of a reused group (it is the "base" node, or a reused node),
# we should skip the substitution.
if first_node.is_reused() or second_node.is_reused():
return graph

# If there is an extra connection between these two nodes skip the substitution
if len(graph.get_next_nodes(first_node)) > 1 or len(graph.get_prev_nodes(second_node)) > 1:
return graph

# New collapsed bias
bias = self.op2d_collapsing_fn(first_node, second_node, self.bias_str)

# New collapsed node
op2d_collapsed = copy.deepcopy(first_node)
op2d_collapsed_name = first_node.name + '_collapsed'
op2d_collapsed.name = op2d_collapsed_name
op2d_collapsed.framework_attr[self.use_bias_str] = True
op2d_collapsed.set_weights_by_keys(self.bias_str, bias)

if self.layer_name_str is not None:
op2d_collapsed.framework_attr[self.layer_name_str] = op2d_collapsed_name

# Update graph
graph.add_node(op2d_collapsed)
graph.reconnect_out_edges(current_node=second_node, new_node=op2d_collapsed)
graph.reconnect_in_edges(current_node=first_node, new_node=op2d_collapsed)
graph.replace_output_node(current_node=second_node, new_node=op2d_collapsed)

graph.remove_edge(first_node, second_node)
graph.remove_node(first_node)
graph.remove_node(second_node)

return graph
@@ -30,6 +30,9 @@ def linear_collapsing_substitute(graph: common.Graph,
Returns:
Transformed graph after applying all linear collapsing substitutions.
"""
# TODO: remove this check after adding the Op2d-add_const collapse substitution in PyTorch
if linear_collapsing_substitution is None:
return graph
matched_nodes = graph.filter(linear_collapsing_substitution.matcher_instance)
matched_nodes_list = []
match_indicator = True
@@ -63,9 +63,7 @@ def substitute(self,

# If the linear operator is part of a reused group (it is the "base" node, or a reused node),
# we should skip the substitution.
- if first_node.reuse or first_node.reuse_group is not None:
-     return graph
- if second_node.reuse or second_node.reuse_group is not None:
+ if first_node.is_reused() or second_node.is_reused():
return graph

# Check if convolution and residual satisfy the collapsing conditions, otherwise skip substitution
1 change: 1 addition & 0 deletions model_compression_toolkit/core/graph_prep_runner.py
@@ -129,6 +129,7 @@ def get_finalized_graph(initial_graph: Graph,
transformed_graph = substitute(graph, fw_impl.get_substitutions_pre_statistics_collection(quant_config))
if quant_config.linear_collapsing:
transformed_graph = linear_collapsing_substitute(transformed_graph, fw_impl.get_linear_collapsing_substitution())
+ transformed_graph = linear_collapsing_substitute(transformed_graph, fw_impl.get_op2d_add_const_collapsing_substitution())
if quant_config.residual_collapsing:
transformed_graph = substitute(transformed_graph, fw_impl.get_residual_collapsing_substitution())

@@ -15,10 +15,14 @@
from typing import Tuple
import numpy as np
import tensorflow as tf
- from tensorflow.keras.layers import Conv2D
+ if tf.__version__ < "2.6":
+     from tensorflow.keras.layers import Conv2D, DepthwiseConv2D, Conv2DTranspose, Dense
+ else:
+     from keras.layers import Conv2D, DepthwiseConv2D, Conv2DTranspose, Dense

from model_compression_toolkit.core.common import BaseNode
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher, NodeFrameworkAttrMatcher
- from model_compression_toolkit.core.common.substitutions.linear_collapsing import Conv2DCollapsing
+ from model_compression_toolkit.core.common.substitutions.linear_collapsing import Conv2DCollapsing, Op2DAddConstCollapsing
from model_compression_toolkit.core.keras.constants import KERNEL, KERNEL_SIZE, STRIDES, DILATIONS, LINEAR, \
ACTIVATION, BIAS, USE_BIAS, LAYER_NAME, FILTERS, PADDING, GROUPS, DATA_FORMAT
from model_compression_toolkit.logger import Logger
@@ -123,3 +127,69 @@ def keras_linear_collapsing() -> Conv2DCollapsing:
FILTERS,
data_format_str=DATA_FORMAT,
layer_name_str=LAYER_NAME)


def op2d_add_const_collapsing_node_matchers() -> Tuple[NodeOperationMatcher, NodeOperationMatcher]:
"""
Generates matchers for matching:
(Op2D, Add(const)) -> Op2D, where Op2D is one of [DepthwiseConv2D, Conv2D, Conv2DTranspose, Dense].
Returns:
Matchers for an Op2D node followed by an Add-const node
"""
first_node = NodeOperationMatcher(DepthwiseConv2D) | \
NodeOperationMatcher(Conv2D) | \
NodeOperationMatcher(Conv2DTranspose) | \
NodeOperationMatcher(Dense)
second_node = NodeOperationMatcher(tf.math.add)
return first_node, second_node


def op2d_add_const_collapsing_fn(op2d_node: BaseNode,
add_node: BaseNode,
bias_str: str) -> np.ndarray:
"""
Collapse the Add-const into the previous node's bias
Args:
op2d_node: Op2d layer node
add_node: Add layer to collapse
bias_str: The framework specific attribute name of the convolution layer's bias.
Returns:
The modified conv layer node's bias
"""
bias = op2d_node.get_weights_by_keys(bias_str)

# read constant from add node
if len(add_node.op_call_args) > 0:
const = add_node.op_call_args[0]
elif 'y' in add_node.op_call_kwargs:
const = add_node.op_call_kwargs['y']
else:
Logger.error(f'Unable to read constant from add node: {add_node.name}') # pragma: no cover

# convert constant to numpy array
if isinstance(const, tf.Tensor):
const = const.numpy()
elif isinstance(const, list):
const = np.array(const)
else:
Logger.error(f'Unable to convert constant to numpy array: {add_node.name}') # pragma: no cover

# return new bias
if bias is None:
return const
else:
return const + bias


def keras_op2d_add_const_collapsing() -> Op2DAddConstCollapsing:
"""
Returns:
An Op2DAddConstCollapsing initialized for Keras models.
"""
first_node, second_node = op2d_add_const_collapsing_node_matchers()
return Op2DAddConstCollapsing(first_node,
second_node,
op2d_add_const_collapsing_fn,
BIAS,
USE_BIAS,
layer_name_str=LAYER_NAME)
8 changes: 7 additions & 1 deletion model_compression_toolkit/core/keras/keras_implementation.py
@@ -75,7 +75,7 @@
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.batchnorm_refusing import \
keras_batchnorm_refusing
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.linear_collapsing import \
- keras_linear_collapsing
+ keras_linear_collapsing, keras_op2d_add_const_collapsing
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.residual_collapsing import \
keras_residual_collapsing
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.input_scaling import InputScaling, \
@@ -311,6 +311,12 @@ def get_linear_collapsing_substitution(self) -> common.BaseSubstitution:
"""
return keras_linear_collapsing()

def get_op2d_add_const_collapsing_substitution(self) -> common.BaseSubstitution:
"""
Returns: Op2d add-const collapsing substitution
"""
return keras_op2d_add_const_collapsing()

def get_substitutions_post_statistics_collection(self, quant_config: QuantizationConfig) \
-> List[common.BaseSubstitution]:
"""
@@ -289,6 +289,12 @@ def get_linear_collapsing_substitution(self) -> common.BaseSubstitution:
"""
return pytorch_linear_collapsing()

def get_op2d_add_const_collapsing_substitution(self) -> common.BaseSubstitution:
"""
Returns: None, as the Op2d add-const substitution is not supported in PyTorch yet
"""
return None

def get_substitutions_post_statistics_collection(self,
quant_config: QuantizationConfig) -> List[common.BaseSubstitution]:
"""