Skip to content

Commit

Permalink
BatchNorm Folding Bug fix (#2587)
Browse files Browse the repository at this point in the history
Signed-off-by: Sayanta Mukherjee <quic_ssayanta@quicinc.com>
  • Loading branch information
quic-ssayanta authored and quic-bharathr committed Sep 13, 2024
1 parent 25ebb11 commit f8ef91f
Showing 1 changed file with 4 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,8 @@ def _delete_bn_from_functional(model: tf.keras.Model,
def wrapped_bn_layer_in_bns_to_remove(layer: tf.keras.layers.Layer) -> bool:
return isinstance(layer, QcQuantizeWrapper) and layer._layer_to_wrap in bn_layers_to_remove

tf.keras.backend.clear_session() # clear session to not have tensor name conflicts

# Step 1: Get the inbound and outbound connections for each layer in the model
model_layer_connections = ModelLayerConnections.get_model_layers_connection_properties(model)

Expand Down Expand Up @@ -492,7 +494,8 @@ def wrapped_bn_layer_in_bns_to_remove(layer: tf.keras.layers.Layer) -> bool:

KERAS_SYMBOLIC_TENSORS_INDEX = 0
# Check if we need to change layer_input order. If there is just one input, there is no order.
if isinstance(layer_input, List):
# Special case when there is a Lambda layer with multiple inputs is handled seperately
if isinstance(layer_input, List) and not isinstance(current_layer, TFOpLambda):
# Original models keras symbolic tensor order
original_keras_symbolic_tensors_order = model_layer_connections[ModelLayerConnectionsProperties.CALL_ARGS][
current_layer.name][KERAS_SYMBOLIC_TENSORS_INDEX]
Expand Down Expand Up @@ -550,7 +553,6 @@ def wrapped_bn_layer_in_bns_to_remove(layer: tf.keras.layers.Layer) -> bool:
if current_layer.name in model.output_names:
model_outputs.append(x)

tf.keras.backend.clear_session() # clear session to not have tensor name conflicts
return tf.keras.Model(inputs=model.inputs, outputs=model_outputs)


Expand Down

0 comments on commit f8ef91f

Please sign in to comment.