From 0eac225e3fa2062fcd12541ce3faf8a60934d66d Mon Sep 17 00:00:00 2001 From: Elad Cohen <78862769+elad-c@users.noreply.github.com> Date: Sun, 10 Dec 2023 15:49:13 +0200 Subject: [PATCH] Fix activation in matmul sub (#884) 1. Fix MatMul->Dense substitution to include an activation attribute 2. Add warning to ActivationDecomposition substitution if node's activation attribute is missing --- .../substitutions/activation_decomposition.py | 6 ++++++ .../substitutions/matmul_substitution.py | 4 ++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/model_compression_toolkit/core/keras/graph_substitutions/substitutions/activation_decomposition.py b/model_compression_toolkit/core/keras/graph_substitutions/substitutions/activation_decomposition.py index d7c8cecb9..a22d0cb7b 100644 --- a/model_compression_toolkit/core/keras/graph_substitutions/substitutions/activation_decomposition.py +++ b/model_compression_toolkit/core/keras/graph_substitutions/substitutions/activation_decomposition.py @@ -16,6 +16,7 @@ from tensorflow.keras.layers import Dense, DepthwiseConv2D, Conv2D, Conv2DTranspose, Activation, SeparableConv2D +from model_compression_toolkit.logger import Logger from model_compression_toolkit.core import common from model_compression_toolkit.constants import FLOAT_32, DATA_TYPE from model_compression_toolkit.core.common.graph.base_graph import Graph @@ -62,6 +63,11 @@ def substitute(self, Graph after applying the substitution. """ + if ACTIVATION not in op2d_node.framework_attr: + Logger.warning(f'Op2d node {op2d_node.name} of type {op2d_node.type} is missing an "{ACTIVATION}"' + f' attribute -> Skipping substitution ActivationDecomposition') # pragma: no cover + return graph + activation_node_name = op2d_node.name + '_post_activation' activation_fw_attr = { diff --git a/model_compression_toolkit/core/keras/graph_substitutions/substitutions/matmul_substitution.py b/model_compression_toolkit/core/keras/graph_substitutions/substitutions/matmul_substitution.py index a376f8752..b1fb12af6 100644 --- a/model_compression_toolkit/core/keras/graph_substitutions/substitutions/matmul_substitution.py +++ b/model_compression_toolkit/core/keras/graph_substitutions/substitutions/matmul_substitution.py @@ -22,7 +22,7 @@ from model_compression_toolkit.core.common.graph.base_node import BaseNode from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode from model_compression_toolkit.core.keras.constants import TRANSPOSE_A, TRANSPOSE_B, \ - ADJOINT_A, ADJOINT_B, UNITS, USE_BIAS, KERNEL + ADJOINT_A, ADJOINT_B, UNITS, USE_BIAS, KERNEL, ACTIVATION, LINEAR class MatmulToDenseSubstitution(common.BaseSubstitution): @@ -89,7 +89,7 @@ def substitute(self, w = w.transpose() dense_node = BaseNode(matmul_node.name, - {UNITS: w.shape[1], USE_BIAS: False}, + {UNITS: w.shape[1], USE_BIAS: False, ACTIVATION: LINEAR}, matmul_node.input_shape, matmul_node.output_shape, {KERNEL: w}, tf.keras.layers.Dense, reuse=matmul_node.reuse, reuse_group=matmul_node.reuse_group)