
Fix activation in matmul sub (sony#884)
1. Fix the MatMul->Dense substitution to include an activation attribute
2. Add a warning to the ActivationDecomposition substitution when a node's activation attribute is missing
elad-c authored Dec 10, 2023
1 parent fd740d2 commit 0eac225
Showing 2 changed files with 8 additions and 2 deletions.
File 1 of 2: ActivationDecomposition substitution

@@ -16,6 +16,7 @@
 
 from tensorflow.keras.layers import Dense, DepthwiseConv2D, Conv2D, Conv2DTranspose, Activation, SeparableConv2D
 
+from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.core import common
 from model_compression_toolkit.constants import FLOAT_32, DATA_TYPE
 from model_compression_toolkit.core.common.graph.base_graph import Graph
@@ -62,6 +63,11 @@ def substitute(self,
             Graph after applying the substitution.
         """
 
+        if ACTIVATION not in op2d_node.framework_attr:
+            Logger.warning(f'Op2d node {op2d_node.name} of type {op2d_node.type} is missing an "{ACTIVATION}"'
+                           f' attribute -> Skipping substitution ActivationDecomposition')  # pragma: no cover
+            return graph
+
         activation_node_name = op2d_node.name + '_post_activation'
 
         activation_fw_attr = {
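
For context, the equivalence behind ActivationDecomposition can be sketched outside MCT: a Keras layer with a fused activation is numerically identical to the same layer with a linear activation followed by a standalone Activation layer. The following is a minimal illustrative sketch in plain TensorFlow, not MCT code; the Conv2D shapes and the 'relu' choice are arbitrary.

import numpy as np
import tensorflow as tf

x = np.random.randn(1, 8, 8, 3).astype(np.float32)

# A layer with a fused activation.
fused = tf.keras.layers.Conv2D(4, 3, activation='relu')
y_fused = fused(x)

# Decomposed form: the same layer with a linear activation, followed by a
# standalone Activation node (what the substitution inserts).
linear = tf.keras.layers.Conv2D(4, 3, activation='linear')
linear.build(x.shape)
linear.set_weights(fused.get_weights())  # reuse the same kernel and bias
y_decomposed = tf.keras.layers.Activation('relu')(linear(x))

print(np.allclose(y_fused.numpy(), y_decomposed.numpy()))  # True

The new guard matters because the substitution reads the node's activation framework attribute to build the Activation node; a node whose framework_attr lacks that key (such as the Dense node the MatMul substitution created before this commit) would otherwise fail, so the substitution now warns and returns the graph unchanged.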
File 2 of 2: MatMul-to-Dense substitution

@@ -22,7 +22,7 @@
 from model_compression_toolkit.core.common.graph.base_node import BaseNode
 from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
 from model_compression_toolkit.core.keras.constants import TRANSPOSE_A, TRANSPOSE_B, \
-    ADJOINT_A, ADJOINT_B, UNITS, USE_BIAS, KERNEL
+    ADJOINT_A, ADJOINT_B, UNITS, USE_BIAS, KERNEL, ACTIVATION, LINEAR
 
 
 class MatmulToDenseSubstitution(common.BaseSubstitution):
@@ -89,7 +89,7 @@ def substitute(self,
             w = w.transpose()
 
         dense_node = BaseNode(matmul_node.name,
-                              {UNITS: w.shape[1], USE_BIAS: False},
+                              {UNITS: w.shape[1], USE_BIAS: False, ACTIVATION: LINEAR},
                               matmul_node.input_shape, matmul_node.output_shape,
                               {KERNEL: w}, tf.keras.layers.Dense,
                               reuse=matmul_node.reuse, reuse_group=matmul_node.reuse_group)
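
The equivalence behind this substitution: tf.linalg.matmul with a constant right-hand operand computes x @ W, which is exactly what a Dense layer with kernel W, no bias, and a linear activation computes. A minimal sketch in plain TensorFlow, not MCT code; the shapes are arbitrary:

import numpy as np
import tensorflow as tf

x = np.random.randn(2, 5).astype(np.float32)
w = np.random.randn(5, 3).astype(np.float32)

y_matmul = tf.linalg.matmul(x, w)

# The substituted form: a Dense layer mirroring the attributes set above.
dense = tf.keras.layers.Dense(units=w.shape[1], use_bias=False, activation='linear')
dense.build(x.shape)
dense.set_weights([w])  # the matmul weight becomes the Dense kernel

print(np.allclose(y_matmul.numpy(), dense(x).numpy()))  # True

Setting ACTIVATION: LINEAR on the new Dense node keeps its framework_attr consistent with Dense layers coming from real Keras models, so later substitutions that inspect the activation attribute (such as ActivationDecomposition above) treat it like any other linear Dense layer instead of hitting the missing-attribute path.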
