diff --git a/model_compression_toolkit/core/common/framework_implementation.py b/model_compression_toolkit/core/common/framework_implementation.py index ec0f64979..f7c2f5171 100644 --- a/model_compression_toolkit/core/common/framework_implementation.py +++ b/model_compression_toolkit/core/common/framework_implementation.py @@ -235,6 +235,14 @@ def get_linear_collapsing_substitution(self) -> common.BaseSubstitution: raise NotImplemented(f'{self.__class__.__name__} have to implement the ' f'framework\'s get_linear_collapsing_substitution method.') # pragma: no cover + @abstractmethod + def get_op2d_add_const_collapsing_substitution(self) -> common.BaseSubstitution: + """ + Returns: conv2d add const collapsing substitution + """ + raise NotImplemented(f'{self.__class__.__name__} have to implement the ' + f'framework\'s get_op2d_add_const_collapsing_substitution method.') # pragma: no cover + @abstractmethod def get_substitutions_statistics_correction(self, quant_config: QuantizationConfig) -> \ List[common.BaseSubstitution]: diff --git a/model_compression_toolkit/core/common/graph/base_node.py b/model_compression_toolkit/core/common/graph/base_node.py index e7757af3a..1feceafad 100644 --- a/model_compression_toolkit/core/common/graph/base_node.py +++ b/model_compression_toolkit/core/common/graph/base_node.py @@ -79,7 +79,8 @@ def __init__(self, def type(self): """ A function to get the node's layer_class op for convenient comparison - :return: the node's layer_class + Returns: + the node's layer_class """ return self.layer_class @@ -130,6 +131,14 @@ def __repr__(self): """ return f'{self.type.__name__}:{self.name}' + def is_reused(self) -> bool: + """ + Check whether the node is reused or not + Returns: + True if node is reused, else False + """ + return self.reuse or self.reuse_group is not None + def get_weights_by_keys(self, name: str) -> np.ndarray: """ Get a node's weight by its name. diff --git a/model_compression_toolkit/core/common/substitutions/batchnorm_folding.py b/model_compression_toolkit/core/common/substitutions/batchnorm_folding.py index fd347d25b..7d7383c1c 100644 --- a/model_compression_toolkit/core/common/substitutions/batchnorm_folding.py +++ b/model_compression_toolkit/core/common/substitutions/batchnorm_folding.py @@ -93,7 +93,7 @@ def substitute(self, # If the linear operator is part of a reused group (it is the "base" node, or a reused node), # we should skip the substitution. - if conv_node.reuse or conv_node.reuse_group is not None: + if conv_node.is_reused(): return graph bn_node = edge_nodes[1] @@ -230,7 +230,7 @@ def substitute(self, # If the linear operator is part of a reused group (it is the "base" node, or a reused node), # we should skip the substitution. - if conv_node.reuse or conv_node.reuse_group is not None or bn_node.reuse or bn_node.reuse_group is not None: + if conv_node.is_reused() or bn_node.is_reused(): return graph if len(graph.get_next_nodes(bn_node)) > 1 or len(graph.get_prev_nodes(conv_node)) > 1: diff --git a/model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py b/model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py index abba4b435..861bb5d3a 100644 --- a/model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +++ b/model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py @@ -79,7 +79,7 @@ def substitute(self, # If the linear operator is part of a reused group (it is the "base" node, or a reused node), # we should skip the substitution. - if source_node.reuse or source_node.reuse_group is not None: + if source_node.is_reused(): for qc in source_node.candidates_quantization_cfg: qc.weights_quantization_cfg.weights_second_moment_correction = False return graph diff --git a/model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py b/model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py index 7be6a0f2f..d8abf3765 100644 --- a/model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +++ b/model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py @@ -102,7 +102,7 @@ def substitute(self, # If the linear operator is part of a reused group (it is the "base" node, or a reused node), # we should skip the substitution. - if source_node.reuse or source_node.reuse_group is not None: + if source_node.is_reused(): Logger.exception("If the linear operator is part of a reused group we should skip the the BN folding " "substitution and SMC feature") # pragma: no cover diff --git a/model_compression_toolkit/core/common/substitutions/linear_collapsing.py b/model_compression_toolkit/core/common/substitutions/linear_collapsing.py index 5f0abbbfd..d1395a6e6 100644 --- a/model_compression_toolkit/core/common/substitutions/linear_collapsing.py +++ b/model_compression_toolkit/core/common/substitutions/linear_collapsing.py @@ -91,14 +91,11 @@ def substitute(self, Graph after applying the substitution. """ - first_node = edge_nodes[0] - second_node = edge_nodes[1] + first_node, second_node, _ = edge_nodes # If the linear operator is part of a reused group (it is the "base" node, or a reused node), # we should skip the substitution. - if first_node.reuse or first_node.reuse_group is not None: - return graph - if second_node.reuse or second_node.reuse_group is not None: + if first_node.is_reused() or second_node.is_reused(): return graph # If there is an extra connection between these two nodes skip the substitution @@ -182,3 +179,83 @@ def substitute(self, assert num_edges_before_substition - len(graph.edges) == 1 return graph + + +class Op2DAddConstCollapsing(common.BaseSubstitution): + """ + Collapse Add-const into preceding Op2D (Not non-linear activation between them) + """ + def __init__(self, + first_node: NodeOperationMatcher, + second_node: NodeOperationMatcher, + op2d_collapsing_fn: Callable, + bias_str: str, + use_bias_str: str, + layer_name_str: str = None): + """ + Collapsing Add-const node (2nd node) to Op2D node (first node). + Args: + first_node: Node matcher for Op2d type nodes. + second_node: Node matcher for add type nodes. + op2d_collapsing_fn: Function for updating the convolution kernel and bias + bias_str: The framework specific attribute name of the convolution layer's bias. + use_bias_str: The framework specific attribute name of the convolution layer's bias flag. + layer_name_str: The framework specific attribute name of layer's name. + """ + super().__init__(matcher_instance=EdgeMatcher(first_node, second_node)) + self.op2d_collapsing_fn = op2d_collapsing_fn + self.bias_str = bias_str + self.use_bias_str = use_bias_str + self.layer_name_str = layer_name_str + + def substitute(self, + graph: Graph, + edge_nodes: Tuple[BaseNode, BaseNode]) -> Graph: + """ + Collapse linear layer into preceding linear layers. + Convolution condition: + |-------------------------| |------| + | Op2D | ---> | Add-const | -> | Op2D | + |-------------------------| |------| + Args: + graph: Graph we apply the substitution on. + edge_nodes: Tuple of linear node and add nodes + Returns: + Graph after applying the substitution. + """ + + first_node, second_node, _ = edge_nodes + + # If the linear operator is part of a reused group (it is the "base" node, or a reused node), + # we should skip the substitution. + if first_node.is_reused() or second_node.is_reused(): + return graph + + # If there is an extra connection between these two nodes skip the substitution + if len(graph.get_next_nodes(first_node)) > 1 or len(graph.get_prev_nodes(second_node)) > 1: + return graph + + # New collapsed bias + bias = self.op2d_collapsing_fn(first_node, second_node, self.bias_str) + + # New collapsed node + op2d_collapsed = copy.deepcopy(first_node) + op2d_collapsed_name = first_node.name + '_collapsed' + op2d_collapsed.name = op2d_collapsed_name + op2d_collapsed.framework_attr[self.use_bias_str] = True + op2d_collapsed.set_weights_by_keys(self.bias_str, bias) + + if self.layer_name_str is not None: + op2d_collapsed.framework_attr[self.layer_name_str] = op2d_collapsed_name + + # Update graph + graph.add_node(op2d_collapsed) + graph.reconnect_out_edges(current_node=second_node, new_node=op2d_collapsed) + graph.reconnect_in_edges(current_node=first_node, new_node=op2d_collapsed) + graph.replace_output_node(current_node=second_node, new_node=op2d_collapsed) + + graph.remove_edge(first_node, second_node) + graph.remove_node(first_node) + graph.remove_node(second_node) + + return graph diff --git a/model_compression_toolkit/core/common/substitutions/linear_collapsing_substitution.py b/model_compression_toolkit/core/common/substitutions/linear_collapsing_substitution.py index 2e59d3ef5..24693bf08 100644 --- a/model_compression_toolkit/core/common/substitutions/linear_collapsing_substitution.py +++ b/model_compression_toolkit/core/common/substitutions/linear_collapsing_substitution.py @@ -30,6 +30,9 @@ def linear_collapsing_substitute(graph: common.Graph, Returns: Transformed graph after applying all linear collapsing substitutions. """ + # TODO: remove this if after adding Op2d-add_const collapse substitution in PyTorch + if linear_collapsing_substitution is None: + return graph matched_nodes = graph.filter(linear_collapsing_substitution.matcher_instance) matched_nodes_list = [] match_indicator = True diff --git a/model_compression_toolkit/core/common/substitutions/residual_collapsing.py b/model_compression_toolkit/core/common/substitutions/residual_collapsing.py index 7eb57f509..1c7694a61 100644 --- a/model_compression_toolkit/core/common/substitutions/residual_collapsing.py +++ b/model_compression_toolkit/core/common/substitutions/residual_collapsing.py @@ -63,9 +63,7 @@ def substitute(self, # If the linear operator is part of a reused group (it is the "base" node, or a reused node), # we should skip the substitution. - if first_node.reuse or first_node.reuse_group is not None: - return graph - if second_node.reuse or second_node.reuse_group is not None: + if first_node.is_reused() or second_node.is_reused(): return graph # Check if convolution and residual satisfy the collapsing conditions, otherwise skip substitution diff --git a/model_compression_toolkit/core/graph_prep_runner.py b/model_compression_toolkit/core/graph_prep_runner.py index 4d405b727..397405889 100644 --- a/model_compression_toolkit/core/graph_prep_runner.py +++ b/model_compression_toolkit/core/graph_prep_runner.py @@ -129,6 +129,7 @@ def get_finalized_graph(initial_graph: Graph, transformed_graph = substitute(graph, fw_impl.get_substitutions_pre_statistics_collection(quant_config)) if quant_config.linear_collapsing: transformed_graph = linear_collapsing_substitute(transformed_graph, fw_impl.get_linear_collapsing_substitution()) + transformed_graph = linear_collapsing_substitute(transformed_graph, fw_impl.get_op2d_add_const_collapsing_substitution()) if quant_config.residual_collapsing: transformed_graph = substitute(transformed_graph, fw_impl.get_residual_collapsing_substitution()) diff --git a/model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py b/model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py index dc72d5059..5d4463d59 100644 --- a/model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py +++ b/model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py @@ -15,10 +15,14 @@ from typing import Tuple import numpy as np import tensorflow as tf -from tensorflow.keras.layers import Conv2D +if tf.__version__ < "2.6": + from tensorflow.keras.layers import Conv2D, DepthwiseConv2D, Conv2DTranspose, Dense +else: + from keras.layers import Conv2D, DepthwiseConv2D, Conv2DTranspose, Dense + from model_compression_toolkit.core.common import BaseNode from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher, NodeFrameworkAttrMatcher -from model_compression_toolkit.core.common.substitutions.linear_collapsing import Conv2DCollapsing +from model_compression_toolkit.core.common.substitutions.linear_collapsing import Conv2DCollapsing, Op2DAddConstCollapsing from model_compression_toolkit.core.keras.constants import KERNEL, KERNEL_SIZE, STRIDES, DILATIONS, LINEAR, \ ACTIVATION, BIAS, USE_BIAS, LAYER_NAME, FILTERS, PADDING, GROUPS, DATA_FORMAT from model_compression_toolkit.logger import Logger @@ -123,3 +127,69 @@ def keras_linear_collapsing() -> Conv2DCollapsing: FILTERS, data_format_str=DATA_FORMAT, layer_name_str=LAYER_NAME) + + +def op2d_add_const_collapsing_node_matchers() -> Tuple[NodeOperationMatcher, NodeOperationMatcher]: + """ + Function generates matchers for matching: + (Op2D, Add(const)) -> Op2D. (Op2D is one of [DepthwiseConv2D, Conv2D, Conv2DTranspose, Dense) + Returns: + Matcher for Op2D followed by Add const + """ + first_node = NodeOperationMatcher(DepthwiseConv2D) | \ + NodeOperationMatcher(Conv2D) | \ + NodeOperationMatcher(Conv2DTranspose) | \ + NodeOperationMatcher(Dense) + second_node = NodeOperationMatcher(tf.math.add) + return first_node, second_node + + +def op2d_add_const_collapsing_fn(op2d_node: BaseNode, + add_node: BaseNode, + bias_str: str) -> np.ndarray: + """ + Collapsing Add-Const to previous node's bias + Args: + op2d_node: Op2d layer node + add_node: Add layer to collapse + bias_str: The framework specific attribute name of the convolution layer's bias. + Returns: + The modified conv layer node's bias + """ + bias = op2d_node.get_weights_by_keys(bias_str) + + # read constant from add node + if len(add_node.op_call_args) > 0: + const = add_node.op_call_args[0] + elif 'y' in add_node.op_call_kwargs: + const = add_node.op_call_kwargs['y'] + else: + Logger.error(f'Unable to read constant from add node: {add_node.name}') # pragma: no cover + + # convert constant to numpy array + if isinstance(const, tf.Tensor): + const = const.numpy() + elif isinstance(const, list): + const = np.array(const) + else: + Logger.error(f'Unable to convert constant to numpy array: {add_node.name}') # pragma: no cover + + # return new bias + if bias is None: + return const + else: + return const + bias + + +def keras_op2d_add_const_collapsing() -> Op2DAddConstCollapsing: + """ + Returns: + An Op2DCollapsing initialized for Keras models. + """ + first_node, second_node = op2d_add_const_collapsing_node_matchers() + return Op2DAddConstCollapsing(first_node, + second_node, + op2d_add_const_collapsing_fn, + BIAS, + USE_BIAS, + layer_name_str=LAYER_NAME) diff --git a/model_compression_toolkit/core/keras/keras_implementation.py b/model_compression_toolkit/core/keras/keras_implementation.py index 39a8db930..dac938ee5 100644 --- a/model_compression_toolkit/core/keras/keras_implementation.py +++ b/model_compression_toolkit/core/keras/keras_implementation.py @@ -75,7 +75,7 @@ from model_compression_toolkit.core.keras.graph_substitutions.substitutions.batchnorm_refusing import \ keras_batchnorm_refusing from model_compression_toolkit.core.keras.graph_substitutions.substitutions.linear_collapsing import \ - keras_linear_collapsing + keras_linear_collapsing, keras_op2d_add_const_collapsing from model_compression_toolkit.core.keras.graph_substitutions.substitutions.residual_collapsing import \ keras_residual_collapsing from model_compression_toolkit.core.keras.graph_substitutions.substitutions.input_scaling import InputScaling, \ @@ -311,6 +311,12 @@ def get_linear_collapsing_substitution(self) -> common.BaseSubstitution: """ return keras_linear_collapsing() + def get_op2d_add_const_collapsing_substitution(self) -> common.BaseSubstitution: + """ + Returns: Op2d add-const collapsing substitution + """ + return keras_op2d_add_const_collapsing() + def get_substitutions_post_statistics_collection(self, quant_config: QuantizationConfig) \ -> List[common.BaseSubstitution]: """ diff --git a/model_compression_toolkit/core/pytorch/pytorch_implementation.py b/model_compression_toolkit/core/pytorch/pytorch_implementation.py index 6a39affa9..5947c01c4 100644 --- a/model_compression_toolkit/core/pytorch/pytorch_implementation.py +++ b/model_compression_toolkit/core/pytorch/pytorch_implementation.py @@ -289,6 +289,12 @@ def get_linear_collapsing_substitution(self) -> common.BaseSubstitution: """ return pytorch_linear_collapsing() + def get_op2d_add_const_collapsing_substitution(self) -> common.BaseSubstitution: + """ + Returns: None, as Op2d add-const substitution is not supported in torch yet + """ + return None + def get_substitutions_post_statistics_collection(self, quant_config: QuantizationConfig) -> List[common.BaseSubstitution]: """ diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/linear_collapsing_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/linear_collapsing_test.py index 5288e7e46..05169b0c6 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/linear_collapsing_test.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/linear_collapsing_test.py @@ -14,9 +14,13 @@ # ============================================================================== from abc import ABC +from packaging import version import model_compression_toolkit as mct import tensorflow as tf -from tensorflow.keras.layers import Conv2D +if version.parse(tf.__version__) >= version.parse("2.13"): + from keras.src.layers.core import TFOpLambda +else: + from keras.layers.core import TFOpLambda from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper from tests.common_tests.helpers.generate_test_tp_model import generate_test_tp_model @@ -51,7 +55,7 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info= y = float_model.predict(input_x) y_hat = quantized_model.predict(input_x) self.unit_test.assertTrue(y.shape == y_hat.shape, msg=f'out shape is not as expected!') - self.unit_test.assertTrue(len([l for l in quantized_model.layers if isinstance(l, KerasTrainableQuantizationWrapper) and isinstance(l.layer, Conv2D)]) < len([l for l in float_model.layers if isinstance(l, Conv2D)]), msg=f'fail number of layers should decrease!') + self.unit_test.assertTrue(len([l for l in quantized_model.layers if isinstance(l, KerasTrainableQuantizationWrapper) and isinstance(l.layer, layers.Conv2D)]) < len([l for l in float_model.layers if isinstance(l, layers.Conv2D)]), msg=f'fail number of layers should decrease!') cs = cosine_similarity(y, y_hat) self.unit_test.assertTrue(np.isclose(cs, 1), msg=f'fail cosine similarity check:{cs}') @@ -69,7 +73,7 @@ def create_networks(self): def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): super().compare(quantized_model, float_model, input_x, quantization_info) for layer in quantized_model.layers: - if type(layer) == Conv2D: + if type(layer) == layers.Conv2D: self.unit_test.assertTrue(len(layer.weights) == 2, msg=f'fail Bias should appear in weights!!') class ThreeConv2DCollapsingTest(BaseConv2DCollapsingTest): @@ -86,7 +90,7 @@ def create_networks(self): def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): super().compare(quantized_model, float_model, input_x, quantization_info) for layer in quantized_model.layers: - if type(layer) == Conv2D: + if type(layer) == layers.Conv2D: self.unit_test.assertTrue(len(layer.weights) == 1,msg=f'fail Bias should not appear in weights!!') @@ -105,9 +109,10 @@ def create_networks(self): def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): super().compare(quantized_model, float_model, input_x, quantization_info) for layer in quantized_model.layers: - if type(layer) == Conv2D: + if type(layer) == layers.Conv2D: self.unit_test.assertTrue(len(layer.weights) == 2,msg=f'fail Bias should appear in weights!!') + class SixConv2DCollapsingTest(BaseConv2DCollapsingTest): def __init__(self, unit_test): super().__init__(unit_test) @@ -125,5 +130,113 @@ def create_networks(self): def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): super().compare(quantized_model, float_model, input_x, quantization_info) for layer in quantized_model.layers: - if type(layer) == Conv2D: - self.unit_test.assertTrue(len(layer.weights) == 2,msg=f'fail Bias should appear in weights!!') \ No newline at end of file + if type(layer) == layers.Conv2D: + self.unit_test.assertTrue(len(layer.weights) == 2,msg=f'fail Bias should appear in weights!!') + + +class Op2DAddConstCollapsingTest(BaseConv2DCollapsingTest): + def __init__(self, unit_test): + super().__init__(unit_test) + + def create_networks(self): + inputs = layers.Input(shape=self.get_input_shapes()[0][1:]) + # ######## + # Cond2D # + # ######## + # Collapse Conv2D with bias + x = layers.Conv2D(filters=7, kernel_size=(5, 5), strides=(1, 1), padding='same', + use_bias=True, bias_initializer='glorot_uniform')(inputs) + x = tf.math.add(x, tf.constant(np.random.normal(size=x.shape[-1]), dtype=x.dtype)) + x = layers.ReLU()(x) + + # Collapse Conv2D without bias, const first argument of tf.math.add + x = layers.Conv2D(filters=5, kernel_size=(5, 5), strides=(1, 1), padding='same', + use_bias=False)(x) + x = tf.math.add(tf.constant(np.random.normal(size=x.shape[-1]), dtype=x.dtype), x) + x = layers.ReLU()(x) + + # Collapse + operator to Conv2D without bias + # TODO: replace add with + (currently using tf.math.add because below TF 2.14 creates TFOpLambda which fails ths node matcher) + x = layers.Conv2D(filters=9, kernel_size=(5, 5), strides=(1, 1), padding='same', + use_bias=False)(x) + x = tf.math.add(x, tf.constant(np.random.normal(size=x.shape[-1]), dtype=x.dtype)) + + # ################# + # DepthwiseConv2D # + # ################# + # Collapse DepthwiseConv2D with bias + x = layers.DepthwiseConv2D(kernel_size=(5, 5), strides=(1, 1), padding='same', + use_bias=True, bias_initializer='glorot_uniform')(x) + x = tf.math.add(x, tf.constant(np.random.normal(size=x.shape[-1]), dtype=x.dtype)) + x = layers.ReLU()(x) + + # Collapse DepthwiseConv2D without bias, const first argument of tf.math.add + x = layers.DepthwiseConv2D(kernel_size=(5, 5), strides=(1, 1), padding='same', + use_bias=False)(x) + x = tf.math.add(tf.constant(np.random.normal(size=x.shape[-1]), dtype=x.dtype), x) + x = layers.ReLU()(x) + + # Collapse + operator to DepthwiseConv2D without bias + # TODO: replace add with + (currently using tf.math.add because below TF 2.14 creates TFOpLambda which fails ths node matcher) + x = layers.DepthwiseConv2D(kernel_size=(5, 5), strides=(1, 1), padding='same', + use_bias=False)(x) + x = tf.math.add(x, tf.constant(np.random.normal(size=x.shape[-1]), dtype=x.dtype)) + + # ################# + # Conv2DTranspose # + # ################# + # Collapse Conv2DTranspose with bias + x = layers.Conv2DTranspose(filters=9, kernel_size=(5, 5), strides=(1, 1), padding='same', + use_bias=True, bias_initializer='glorot_uniform')(x) + x = tf.math.add(x, tf.constant(np.random.normal(size=x.shape[-1]), dtype=x.dtype)) + x = layers.ReLU()(x) + + # Collapse Conv2DTranspose without bias, const first argument of tf.math.add + x = layers.Conv2DTranspose(filters=9, kernel_size=(5, 5), strides=(1, 1), padding='same', + use_bias=False)(x) + x = tf.math.add(tf.constant(np.random.normal(size=x.shape[-1]), dtype=x.dtype), x) + x = layers.ReLU()(x) + + # Collapse + operator to Conv2DTranspose without bias + # TODO: replace add with + (currently using tf.math.add because below TF 2.14 creates TFOpLambda which fails ths node matcher) + x = layers.Conv2DTranspose(filters=9, kernel_size=(5, 5), strides=(1, 1), padding='same', + use_bias=False)(x) + x = tf.math.add(x, tf.constant(np.random.normal(size=x.shape[-1]), dtype=x.dtype)) + + # ####### + # Dense # + # ####### + x = layers.Reshape((-1,))(x) + # Collapse Dense with bias + x = layers.Dense(9, use_bias=True, bias_initializer='glorot_uniform')(x) + x = tf.math.add(x, tf.constant(np.random.normal(size=x.shape[-1]), dtype=x.dtype)) + x = layers.ReLU()(x) + + # Collapse Dense without bias, const first argument of tf.math.add + x = layers.Dense(9, use_bias=False)(x) + x = tf.math.add(tf.constant(np.random.normal(size=x.shape[-1]), dtype=x.dtype), x) + x = layers.ReLU()(x) + + # Collapse + operator to Conv2DTranspose without bias + # TODO: replace add with + (currently using tf.math.add because below TF 2.14 creates TFOpLambda which fails ths node matcher) + x = layers.Dense(9, use_bias=False)(x) + x = tf.math.add(x, tf.constant(np.random.normal(size=x.shape[-1]), dtype=x.dtype)) + + # Don't collapse + x2 = layers.Dense(9, use_bias=True, bias_initializer='glorot_uniform')(x) + x = tf.math.add(x2, x) + y = layers.ReLU()(x) + + return tf.keras.models.Model(inputs=inputs, outputs=y) + + def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): + super().compare(quantized_model, float_model, input_x, quantization_info) + num_adds = 0 + for layer in quantized_model.layers: + if type(layer) in [layers.Conv2D, layers.DepthwiseConv2D, layers.Conv2DTranspose, layers.Dense]: + self.unit_test.assertTrue(len(layer.weights) == 2, msg=f'fail Bias should appear in weights!!') + elif isinstance(layer, TFOpLambda) and layer.function is tf.add: + num_adds += 1 + + # check all "add"s were folded except the one with 2 tensor inputs + self.unit_test.assertTrue(num_adds == 1, msg=f'Only one add should remain in the quantized model') \ No newline at end of file diff --git a/tests/keras_tests/feature_networks_tests/test_features_runner.py b/tests/keras_tests/feature_networks_tests/test_features_runner.py index 47ac728d3..bbe8e5235 100644 --- a/tests/keras_tests/feature_networks_tests/test_features_runner.py +++ b/tests/keras_tests/feature_networks_tests/test_features_runner.py @@ -45,7 +45,7 @@ from tests.keras_tests.feature_networks_tests.feature_networks.input_scaling_test import InputScalingDenseTest, \ InputScalingConvTest, InputScalingDWTest, InputScalingZeroPadTest from tests.keras_tests.feature_networks_tests.feature_networks.linear_collapsing_test import TwoConv2DCollapsingTest, \ - ThreeConv2DCollapsingTest, FourConv2DCollapsingTest, SixConv2DCollapsingTest + ThreeConv2DCollapsingTest, FourConv2DCollapsingTest, SixConv2DCollapsingTest, Op2DAddConstCollapsingTest from tests.keras_tests.feature_networks_tests.feature_networks.lut_quantizer import LUTWeightsQuantizerTest, \ LUTActivationQuantizerTest from tests.keras_tests.feature_networks_tests.feature_networks.mixed_precision_bops_test import \ @@ -531,6 +531,7 @@ def test_linear_collapsing(self): ThreeConv2DCollapsingTest(self).run_test() FourConv2DCollapsingTest(self).run_test() SixConv2DCollapsingTest(self).run_test() + Op2DAddConstCollapsingTest(self).run_test() def test_second_moment(self): DepthwiseConv2DSecondMomentTest(self).run_test()