diff --git a/model_compression_toolkit/core/keras/graph_substitutions/substitutions/activation_decomposition.py b/model_compression_toolkit/core/keras/graph_substitutions/substitutions/activation_decomposition.py
index d7c8cecb9..a22d0cb7b 100644
--- a/model_compression_toolkit/core/keras/graph_substitutions/substitutions/activation_decomposition.py
+++ b/model_compression_toolkit/core/keras/graph_substitutions/substitutions/activation_decomposition.py
@@ -16,6 +16,7 @@
 
 from tensorflow.keras.layers import Dense, DepthwiseConv2D, Conv2D, Conv2DTranspose, Activation, SeparableConv2D
 
+from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.core import common
 from model_compression_toolkit.constants import FLOAT_32, DATA_TYPE
 from model_compression_toolkit.core.common.graph.base_graph import Graph
@@ -62,6 +63,11 @@ def substitute(self,
             Graph after applying the substitution.
         """
 
+        if ACTIVATION not in op2d_node.framework_attr:
+            Logger.warning(f'Op2d node {op2d_node.name} of type {op2d_node.type} is missing an "{ACTIVATION}"'
+                           f' attribute -> Skipping substitution ActivationDecomposition')  # pragma: no cover
+            return graph
+
         activation_node_name = op2d_node.name + '_post_activation'
 
         activation_fw_attr = {
diff --git a/model_compression_toolkit/core/keras/graph_substitutions/substitutions/matmul_substitution.py b/model_compression_toolkit/core/keras/graph_substitutions/substitutions/matmul_substitution.py
index a376f8752..b1fb12af6 100644
--- a/model_compression_toolkit/core/keras/graph_substitutions/substitutions/matmul_substitution.py
+++ b/model_compression_toolkit/core/keras/graph_substitutions/substitutions/matmul_substitution.py
@@ -22,7 +22,7 @@
 from model_compression_toolkit.core.common.graph.base_node import BaseNode
 from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
 from model_compression_toolkit.core.keras.constants import TRANSPOSE_A, TRANSPOSE_B, \
-    ADJOINT_A, ADJOINT_B, UNITS, USE_BIAS, KERNEL
+    ADJOINT_A, ADJOINT_B, UNITS, USE_BIAS, KERNEL, ACTIVATION, LINEAR
 
 
 class MatmulToDenseSubstitution(common.BaseSubstitution):
@@ -89,7 +89,7 @@ def substitute(self,
             w = w.transpose()
 
         dense_node = BaseNode(matmul_node.name,
-                              {UNITS: w.shape[1], USE_BIAS: False},
+                              {UNITS: w.shape[1], USE_BIAS: False, ACTIVATION: LINEAR},
                               matmul_node.input_shape, matmul_node.output_shape,
                               {KERNEL: w}, tf.keras.layers.Dense,
                               reuse=matmul_node.reuse, reuse_group=matmul_node.reuse_group)
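
For context, here is a minimal standalone sketch of the interaction the two changes address. It appears that a node created without an `activation` framework attribute (such as the Dense node the matmul substitution builds, pre-fix) would fail the attribute lookup during activation decomposition; the guard skips such nodes with a warning, and the matmul substitution now sets `ACTIVATION: LINEAR` explicitly. The names below (`Node`, `should_decompose`) are illustrative stand-ins, not MCT APIs.

```python
# Standalone sketch (assumed behavior, not MCT code): `Node` stands in for
# MCT's BaseNode, and `should_decompose` mimics the decision the guarded
# substitution now makes. ACTIVATION/LINEAR mirror the Keras constants.
import logging
from dataclasses import dataclass, field

ACTIVATION = 'activation'
LINEAR = 'linear'

@dataclass
class Node:
    name: str
    framework_attr: dict = field(default_factory=dict)

def should_decompose(node: Node) -> bool:
    """Decide whether to split the node's activation into its own layer."""
    # The new guard: a node built without an 'activation' attribute is
    # skipped with a warning instead of failing on the lookup below.
    if ACTIVATION not in node.framework_attr:
        logging.warning('Node %s has no "%s" attribute; skipping '
                        'activation decomposition', node.name, ACTIVATION)
        return False
    return node.framework_attr[ACTIVATION] != LINEAR

# With the matmul fix, the rebuilt Dense carries an explicit linear
# activation, so both paths now behave sensibly:
assert should_decompose(Node('matmul_as_dense', {ACTIVATION: LINEAR})) is False
assert should_decompose(Node('legacy_node')) is False   # guarded, no KeyError
assert should_decompose(Node('conv', {ACTIVATION: 'relu'})) is True
```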