Keras - Multi Output Causing ConnectedGraph Issues (#2645)
* Updated to handle multi-output layers and fixed the logic for the number of input quantizers


Signed-off-by: Matthew Ernst <quic_ernst@quicinc.com>
quic-ernst authored Jan 10, 2024
1 parent 9ff81dc commit df3cb8c
Showing 6 changed files with 149 additions and 31 deletions.
@@ -156,7 +156,12 @@ def _parse_layer(self, layer: tf.keras.layers.Layer):

op = self._generate_op(op_type, layer)

self._name_to_layer[layer.output.name] = layer
if isinstance(layer.output, typing.List):
for output in layer.output:
self._name_to_layer[output.name] = layer
else:
self._name_to_layer[layer.output.name] = layer

self._op_name_to_layer[op.name] = layer
self._layer_to_op[layer] = op
self._ops[op.name] = op
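For context, a minimal sketch (assuming TF 2.x tracing behavior) of why `layer.output` can be a list here: a traced op such as tf.split becomes a single TFOpLambda layer with one KerasTensor per output, so each output name must map back to the same layer.

```python
import tensorflow as tf

inputs = tf.keras.Input(shape=(4,))
x, y = tf.split(inputs, 2, axis=1)            # one TFOpLambda layer, two outputs
split_layer = tf.keras.Model(inputs, [x, y]).layers[1]

print(isinstance(split_layer.output, list))   # True
print([t.name for t in split_layer.output])   # ['tf.split/split:0', 'tf.split/split:1']
```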
@@ -237,7 +242,9 @@ def _generate_usual_product(
:param inbound_layer: tf.keras.layer related to producer Op
"""
consumer_op = self._layer_to_op.get(target_layer)
producer_op = self.get_op_from_module_name(inbound_layer.output.name)
inbound_layer_output_name = \
inbound_layer.output[0].name if isinstance(inbound_layer.output, typing.List) else inbound_layer.output.name
producer_op = self.get_op_from_module_name(inbound_layer_output_name)

if producer_op is None:
raise RuntimeError("Producer Op must exist")
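A hedged sketch of the resulting lookup behavior; the `aimet_tensorflow.keras.connectedgraph` module path and passing a tensor name to `get_op_from_module_name` are assumptions based on this diff and the tests further down.

```python
import tensorflow as tf
from aimet_tensorflow.keras.connectedgraph import ConnectedGraph  # assumed path

inputs = tf.keras.Input(shape=(4,))
x, y = tf.split(inputs, 2, axis=1)
model = tf.keras.Model(inputs, [x, y])

# A multi-output producer is registered under its first output tensor's name,
# which is what the fallback to output[0].name above relies on.
cg = ConnectedGraph(model)
assert cg.get_op_from_module_name("tf.split/split:0") is not None
```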
@@ -322,10 +322,18 @@ def call(self, inputs, *args, **kwargs):
name for name, tensor in kwargs.items()
if _is_keras_or_tensor_input(tensor) or isinstance(tensor, np.ndarray)
]

# TF functions like tf.concat can take multiple inputs as a single list. Other layers matching
# TFOpLambda may instead receive one input via `inputs` and the remaining one(s) via a kwargs dict
num_inputs_to_quantize = len(inputs) if isinstance(inputs, List) else 1
# Quantize the input directly first
inputs = self._quantize_activation(inputs, [self.input_quantizers[0]], True)
inputs = self._quantize_activation(
inputs,
quantizers=self.input_quantizers[:num_inputs_to_quantize],
is_input_quantization=True
)
# Quantize any subsequent arguments
for tensor_name, input_quantizer in zip(kwargs_keys_for_keras_tensors, self.input_quantizers[1:]):
for tensor_name, input_quantizer in zip(kwargs_keys_for_keras_tensors, self.input_quantizers[num_inputs_to_quantize:]):
kwargs[tensor_name] = self._quantize_activation(kwargs[tensor_name], [input_quantizer], True)
else:
inputs = self._quantize_activation(inputs, self.input_quantizers, True)
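A small illustration of the two TFOpLambda input conventions the comment above describes; the kwargs routing for `tf.matmul` is an assumption drawn from that comment, not verified against Keras dispatch internals.

```python
import tensorflow as tf

a = tf.keras.Input(shape=(4,))
b = tf.keras.Input(shape=(4,))

cat = tf.concat([a, b], axis=-1)        # both tensors arrive as one positional
                                        # list -> len(inputs) == 2
mm = tf.matmul(a, b, transpose_b=True)  # first tensor positional, second routed
                                        # through kwargs -> len(inputs) == 1
```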
@@ -340,7 +348,8 @@ def _quantize_params(self):
try:
idx_param_quantizer = 0
for idx, param in enumerate(self._layer_to_wrap.weights):
# check and break if idx_param_quantizer is out of range (Batchnorm fold will update bias tensor,
# even in case there was no existing bias add op in given conv2D op, use_bias=False)
if idx_param_quantizer == len(self.param_quantizers):
break
if self._layer_to_wrap.weights[idx].dtype in QUANT_ALLOWED_DTYPES:
@@ -373,15 +373,18 @@ def get_encodings_dict(self) -> Dict[str, Union[str, Dict]]:
# inbound_nodes parameter populated, so the name of the quantizer is used instead
if not wrapper._layer_to_wrap.inbound_nodes:
tensor_name = "multi_head_attention/" + wrapper.name + "/" + output_quantizer.name
elif isinstance(wrapper._layer_to_wrap.output, List):
tensor_name = wrapper._layer_to_wrap.output[idx].name
else:
tensor_name = wrapper._layer_to_wrap.output.name
encoding_dict = self._get_encoding_dict_for_quantizer(output_quantizer)
activation_encodings[tensor_name] = encoding_dict
encodings_dict = {'version': encoding_version,
'activation_encodings': activation_encodings,
'param_encodings': param_encodings,
'quantizer_args': self.quant_args if hasattr(self, "quant_args") else {}}
return encodings_dict
return {
'version': encoding_version,
'activation_encodings': activation_encodings,
'param_encodings': param_encodings,
'quantizer_args': self.quant_args if hasattr(self, "quant_args") else {}
}
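For reference, a hedged sketch of the dictionary this method now returns; field values are illustrative only, and the per-output tensor-name keys follow the multi-output branch above.

```python
encodings_dict = {
    "version": "0.6.1",                 # assumption; actual value comes from encoding_version
    "activation_encodings": {
        # a multi-output layer now contributes one entry per output tensor
        "tf.split/split:0": {"bitwidth": 8, "dtype": "int"},   # inner fields illustrative
        "tf.split/split:1": {"bitwidth": 8, "dtype": "int"},
    },
    "param_encodings": {},              # per-parameter encodings, same inner shape
    "quantizer_args": {},               # quant_args when present, else {}
}
```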

def compute_encodings(self, forward_pass_callback, forward_pass_callback_args):
"""
@@ -37,8 +37,16 @@
""" Utilities for parsing and applying quantsim configurations from json config file """
from typing import List, Tuple, Dict, Union

import tensorflow as tf
from tensorflow.keras import layers

from packaging import version
if version.parse(tf.version.VERSION) >= version.parse("2.10.1"):
from keras.layers.core.tf_op_layer import TFOpLambda # pylint: disable=import-error
else:
from tensorflow.python.keras.layers.core import TFOpLambda # pylint: disable=ungrouped-imports

# pylint: disable=wrong-import-position
from aimet_common.connected_graph.connectedgraph_utils import get_all_input_ops, get_all_output_ops
from aimet_common.connected_graph.operation import Op
from aimet_common.defs import QuantScheme, QuantizationDataType
@@ -161,8 +169,11 @@ def _initialize_input_quantizers(layer: layers.Layer, quant_settings: QuantizerS
:param enabled: Flag for quantized or not
:return: Input quantizers corresponding to layer
"""
layer_input_list = layer.inbound_nodes[0].keras_inputs
num_inputs = len(layer_input_list)

num_inputs = len(layer.inbound_nodes[0].keras_inputs)
# Special case for TFOpLambda layers as the input can be other Keras layers, tf operations, or static tf.tensors
if isinstance(layer, TFOpLambda):
num_inputs = len(layer.input) if isinstance(layer.input, List) else num_inputs
input_quantizers = []
for i in range(num_inputs):
activation_tensor_quantizer = ActivationTensorQuantizer(layer,
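A sketch of why the special case is needed; it assumes TF 2.x Node semantics where `layer.input` reflects the full call args while `keras_inputs` tracks only KerasTensors.

```python
import tensorflow as tf

inp = tf.keras.Input(shape=(4,))
static = tf.ones((1, 4))                     # plain tf.Tensor, not a KerasTensor
out = tf.concat([inp, static], axis=-1)      # TFOpLambda with a static input
concat_layer = tf.keras.Model(inp, out).layers[1]

print(len(concat_layer.inbound_nodes[0].keras_inputs))  # 1 -> would miss an input
print(len(concat_layer.input))                           # 2 -> num_inputs used here
```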
@@ -189,18 +200,23 @@ def _initialize_output_quantizers(layer: layers.Layer, quant_settings: Quantizer
:param enabled: Flag for quantized or not
:return: Output quantizers corresponding to layer
"""

# `layer.output` will be a list if there is more than one output; otherwise it's a single tensor
num_outputs = len(layer.output) if isinstance(layer.output, List) else 1
output_quantizers = []
activation_tensor_quantizer = ActivationTensorQuantizer(layer,
f"{layer.name}_output_quantizer_0",
quant_settings.quant_scheme,
quant_settings.round_mode,
quant_settings.bitwidth,
quant_settings.data_type,
quant_settings.is_symmetric,
quant_settings.use_strict_symmetric,
quant_settings.use_unsigned_symmetric,
enabled and layer.output.dtype in QUANT_ALLOWED_DTYPES)
output_quantizers.append(activation_tensor_quantizer)
for idx in range(num_outputs):
layer_output_dtype = layer.output[idx].dtype if isinstance(layer.output, List) else layer.output.dtype
activation_tensor_quantizer = ActivationTensorQuantizer(layer,
f"{layer.name}_output_quantizer_{idx}",
quant_settings.quant_scheme,
quant_settings.round_mode,
quant_settings.bitwidth,
quant_settings.data_type,
quant_settings.is_symmetric,
quant_settings.use_strict_symmetric,
quant_settings.use_unsigned_symmetric,
enabled and layer_output_dtype in QUANT_ALLOWED_DTYPES)
output_quantizers.append(activation_tensor_quantizer)
return output_quantizers
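A minimal sketch of the per-output quantizer naming this loop produces, using a two-output tf.split layer.

```python
import tensorflow as tf

inputs = tf.keras.Input(shape=(4,))
x, y = tf.split(inputs, 2, axis=1)
layer = tf.keras.Model(inputs, [x, y]).layers[1]

num_outputs = len(layer.output) if isinstance(layer.output, list) else 1
print([f"{layer.name}_output_quantizer_{idx}" for idx in range(num_outputs)])
# ['tf.split_output_quantizer_0', 'tf.split_output_quantizer_1']
```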


@@ -528,7 +528,7 @@ def test_multi_output_only_lambda():
orig_output = original_model(random_input)

functional_model = prepare_model(original_model)

ConnectedGraph(functional_model)
functional_model_output = functional_model(random_input)
model_weights_in_correct_order = _get_original_models_weights_in_functional_model_order(
original_model, functional_model, class_names=set())
@@ -104,16 +104,34 @@ def model_with_lambda_operators():
model = tf.keras.Model(inputs=(inp, inp_2), outputs=x, name="model_with_lambda_operators")
return model

def model_with_tf_op_lambda_operators():

def model_with_tf_op_lambda_operators_multi_tf_keras_input():
input_layer = tf.keras.Input(batch_input_shape=(1, 16, 32, 3))
x1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)(input_layer)
x2 = tf.transpose(x1, perm=[0, 1, 3, 2])
output = tf.matmul(x1, x2)

model = tf.keras.Model(inputs=input_layer, outputs=output, name="model_with_tf_op_lambda_layers")
out = model(tf.random.uniform((1, 16, 32, 3)))
return tf.keras.Model(
inputs=input_layer,
outputs=output,
name="model_with_tf_op_lambda_operators_multi_tf_keras_input"
)

return model

def model_with_tf_op_lambda_operators_multi_tf_static_inputs():
input_1 = tf.keras.Input(shape=(10,))
input_2 = tf.keras.Input(shape=(20,))
keras_concat = tf.keras.layers.Concatenate(axis=-1)([input_1, input_2])
const = tf.constant(3.14, dtype=tf.float32, shape=(1,))
const_expanded = tf.expand_dims(const, axis=0)
tf_concat = tf.concat([keras_concat, const_expanded], axis=-1)
output = tf.keras.layers.Dense(3)(tf_concat)

return tf.keras.Model(
inputs=[input_1, input_2],
outputs=output,
name="model_with_tf_op_lambda_operators_multi_tf_static_inputs"
)
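A quick usage sketch of this fixture (shapes inferred from the definitions above):

```python
import tensorflow as tf

model = model_with_tf_op_lambda_operators_multi_tf_static_inputs()
out = model([tf.ones((1, 10)), tf.ones((1, 20))])
print(out.shape)  # (1, 3): 10 + 20 features plus the static scalar, into Dense(3)
```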

def model_with_reused_layer():
relu = tf.keras.layers.ReLU()
@@ -384,8 +402,8 @@ def test_model_with_lambda_operators():
assert len(encodings['param_encodings']) == 2


def test_model_with_tf_op_lambda_operators():
model = model_with_tf_op_lambda_operators()
def test_model_with_tf_op_lambda_operators_multi_tf_keras_input():
model = model_with_tf_op_lambda_operators_multi_tf_keras_input()
random_input = tf.random.uniform((1, 16, 32, 3))

with tempfile.TemporaryDirectory() as temp_dir:
@@ -405,6 +423,34 @@ def test_model_with_tf_op_lambda_operators():
assert len(encodings['activation_encodings']) == 4
assert len(encodings['param_encodings']) == 1, "Only the Dense layer in this model should have param_encoding"


def test_model_with_tf_op_lambda_operators_multi_tf_static_inputs():
model = model_with_tf_op_lambda_operators_multi_tf_static_inputs()
random_input = [tf.random.uniform(shape=(1, *shape[1:])) for shape in model.input_shape]

qsim = QuantizationSimModel(model, quant_scheme="tf")
qsim.compute_encodings(lambda m, _: m(random_input), None)
with tempfile.TemporaryDirectory() as temp_dir:
qsim.export(temp_dir, model.name, convert_to_pb=False)

with open(os.path.join(temp_dir, f"{model.name}.encodings"), "r") as encodings_file:
encodings = json.load(encodings_file)

assert isinstance(qsim.model.layers[2].original_layer, tf.keras.layers.Concatenate), \
"This wrapper's original_layer should be a Concatenate layer"

assert qsim.model.layers[3].original_layer.get_config()["name"] == "tf.concat", \
"This wrapper's original_layer should be the tf.concat TFOpLambda"

assert len(qsim.model.layers[2].input_quantizers) == 2, \
"The QcQuantizeWrapper around tf.keras.layers.Concatenate should have 2 input_quantizers"

assert len(qsim.model.layers[3].input_quantizers) == 2, \
"The QcQuantizeWrapper around the tf.concat TFOpLambda layer should have 2 input_quantizers"

assert len(encodings["activation_encodings"]) == 5
assert len(encodings['param_encodings']) == 1, "Only the Dense layer in this model should have param_encoding"

def test_qat():
if version.parse(tf.version.VERSION) >= version.parse("2.00"):
model = dense_functional()
@@ -1266,7 +1312,6 @@ def test_quant_scheme_percentile():
assert np.allclose(quantizer.get_percentile_value(), 99.99)



def test_quant_scheme_percentile_setting_using_str():
"""
This test case ensures that the quantization is working fine with percentile scheme
@@ -1282,3 +1327,41 @@ def test_quant_scheme_percentile_setting_using_str():
assert quantizer.quant_scheme == QuantScheme.post_training_percentile


def test_multi_output_model():
"""
Test Quantsim with a model that has multi output layers
"""

inputs = tf.keras.Input(shape=(480, 1088, 3))
x = tf.keras.layers.Conv2D(
filters=3,
kernel_size=(3, 3),
activation="relu",
padding="same"
)(inputs)

# tf operators -> TFOpLambda
x = tf.reshape(x, [1, -1, 2])
x, y = tf.split(x, 2, axis=2)
x = tf.concat([x, y], axis=2)

output_1 = tf.keras.layers.Dense(units=32)(x)
output_2 = tf.keras.layers.Dense(units=32)(x)

model = tf.keras.Model(inputs=[inputs], outputs=[output_1, output_2])

sim = QuantizationSimModel(model)

# Check ConnectedGraph
assert sim.connected_graph._split_count == 1
for idx in range(2):
assert sim.connected_graph._name_to_layer[f"tf.split/split:{idx}"].get_config()["name"] == "tf.split"

assert len(sim.model.layers[3].output_quantizers) == 2

sim.compute_encodings(
lambda m, _: m(tf.random.uniform(shape=(1, *model.input_shape[1:]))), None
)

with tempfile.TemporaryDirectory() as tmp_dir:
sim.export(tmp_dir, "multi_output_model")
