Skip to content

Commit

Permalink
Keras model preparer and BN fold bug fix (#2619)
Browse files Browse the repository at this point in the history
* Bug fix on model preparer and BN fold for keras

* Added function to update CALL_ARGS of model layer connections and _Merge import statement changed according to tf version

* Added typehints and docstring to the newly added function and fixed pylint issues.

Signed-off-by: Sayanta Mukherjee <quic_ssayanta@quicinc.com>
  • Loading branch information
quic-ssayanta authored Dec 26, 2023
1 parent 229ec6c commit 8b646fa
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@

import aimet_common.libpymo as libpymo
from aimet_common.utils import AimetLogger
from aimet_tensorflow.keras.model_preparer import _handle_normal_keras_layer
from aimet_tensorflow.keras.model_preparer import _handle_normal_keras_layer, _update_output_tensors_in_model_layers_connections
from aimet_tensorflow.keras.quant_sim.qc_quantize_wrapper import QcQuantizeWrapper
from aimet_tensorflow.keras.quant_sim.tensor_quantizer import ParamPerTensorQuantizer
from aimet_tensorflow.keras.quantsim import QuantizationSimModel
Expand Down Expand Up @@ -542,6 +542,10 @@ def wrapped_bn_layer_in_bns_to_remove(layer: tf.keras.layers.Layer) -> bool:
# Special case for when there is a Lambda opertaion with multiple inputs. For example, z = x + y.
if isinstance(current_layer, TFOpLambda):
x = _handle_normal_keras_layer(current_layer, model_layer_connections)
current_layer._outbound_nodes = [] # pylint: disable=protected-access
# Updating the Model layer connections
_update_output_tensors_in_model_layers_connections(current_layer, x, model, model_layer_connections,
current_layer._outbound_nodes)
else:
x = current_layer(layer_input)
current_layer._outbound_nodes = [] # pylint: disable=protected-access
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@
import re
import numpy as np
import tensorflow as tf
from tensorflow.python.keras.layers.merge import _Merge as MergeLayersParentClass
import tensorflow.keras.backend as K
import tensorflow.keras.backend as K # pylint: disable=ungrouped-imports
from packaging import version

if version.parse(tf.version.VERSION) >= version.parse("2.10"):
Expand All @@ -51,12 +50,14 @@
from keras.engine.functional import Functional # pylint: disable=import-error
from keras.engine.keras_tensor import KerasTensor # pylint: disable=import-error
from keras.layers.core.tf_op_layer import TFOpLambda # pylint: disable=import-error
from keras.layers.merging.base_merge import _Merge as MergeLayersParentClass # pylint: disable=ungrouped-imports
else:
# Ignore pylint errors due to conditional imports
from tensorflow.python.keras.engine.base_layer_utils import is_subclassed # pylint: disable=ungrouped-imports
from tensorflow.python.keras.engine.keras_tensor import KerasTensor # pylint: disable=ungrouped-imports
from tensorflow.python.keras.engine.functional import Functional # pylint: disable=ungrouped-imports
from tensorflow.python.keras.layers.core import TFOpLambda # pylint: disable=ungrouped-imports
from tensorflow.python.keras.layers.merge import _Merge as MergeLayersParentClass # pylint: disable=ungrouped-imports

# pylint: disable=wrong-import-position
from aimet_tensorflow.keras.utils.model_connection_utils import ModelLayerConnections, ModelLayerConnectionsProperties
Expand Down Expand Up @@ -328,13 +329,37 @@ def _get_call_kwargs(layer: tf.keras.layers.Layer, model_layers_connections: Mod
return {}
return call_kwargs

def _update_call_args_in_model_layer_connections(model_layer_connections: ModelLayerConnectionsProperties.TYPE,
layer: tf.keras.layers.Layer, new_output_tensor: KerasTensor):
"""
Helper function to update the call args in model layer connections dictionary.
:param model_layer_connections: The model layers connections dictionary
:param layer: The layer to update the output tensors of
:param new_output_tensor: The new output tensor to update with
"""
KERAS_SYMBOLIC_TENSORS_INDEX = 0
# pylint: disable=protected-access
for layer_name, keras_tensor in model_layer_connections[ModelLayerConnectionsProperties.CALL_ARGS].items():
keras_tensor = keras_tensor[KERAS_SYMBOLIC_TENSORS_INDEX]
if isinstance(keras_tensor, list):
for idx, each_keras_tensor in enumerate(keras_tensor):
if isinstance(each_keras_tensor, KerasTensor) and \
each_keras_tensor._keras_history.layer.name == layer.name:
model_layer_connections[ModelLayerConnectionsProperties.CALL_ARGS][layer_name]\
[KERAS_SYMBOLIC_TENSORS_INDEX][idx] = new_output_tensor
else:
if isinstance(keras_tensor, KerasTensor) and keras_tensor._keras_history.layer.name == layer.name:
model_layer_connections[ModelLayerConnectionsProperties.CALL_ARGS][layer_name] = (new_output_tensor,)


def _update_output_tensors_in_model_layers_connections(layer: tf.keras.layers.Layer, new_output_tensor: KerasTensor,
model: tf.keras.Model,
model_layers_connections: ModelLayerConnectionsProperties.TYPE,
model_outputs: List[KerasTensor]):
"""
Helper function to update the output tensors in the model layers connections dictionary.
Helper function to update the output tensors in the model layers connections dictionary. It also updates
the call_args of model layer connections
:param layer: The layer to update the output tensors of
:param new_output_tensor: The new output tensor to update with
Expand All @@ -361,6 +386,9 @@ def _update_output_tensors_in_model_layers_connections(layer: tf.keras.layers.La
model_layers_connections[ModelLayerConnectionsProperties.OUTPUT_TENSORS].update(
{layer.name: new_output_tensor})

# Updating the CALL_ARGS of model layer connections with new output tensor
_update_call_args_in_model_layer_connections(model_layers_connections, layer, new_output_tensor)

# Save tensor in output list if it is output in the initial model
# TODO: Update so that the last conditional is only checked when it's not the last layer.
if model.output_names and layer.name in model.output_names:
Expand Down

0 comments on commit 8b646fa

Please sign in to comment.