Keras - TFOpLambda Kwargs Input Quantizers (#2731)
* Fixes input quantizers from kwargs with TFOpLambda layers

Signed-off-by: Matthew Ernst <quic_ernst@quicinc.com>
quic-ernst authored Feb 12, 2024
1 parent f3e8487 commit 885bebb
Showing 2 changed files with 32 additions and 12 deletions.
@@ -324,17 +324,36 @@ def call(self, inputs, *args, **kwargs):
             ]

             # TF functions like tf.concat could have two inputs in List form. But other layers could match
-            # The TFOpLambda where one input is in inputs and the other(s) are in a kwargs dict
-            num_inputs_to_quantize = len(inputs) if isinstance(inputs, List) else 1
+            # the TFOpLambda where one input is in `inputs` and the other(s) are in the kwargs dict
+            input_quantizer_index = len(inputs) if isinstance(inputs, List) else 1

             # Quantize the input directly first
             inputs = self._quantize_activation(
                 inputs,
-                quantizers=self.input_quantizers[:num_inputs_to_quantize],
+                quantizers=self.input_quantizers[:input_quantizer_index],
                 is_input_quantization=True
             )
-            # Quantize any subsequent arguments
-            for tensor_name, input_quantizer in zip(kwargs_keys_for_keras_tensors, self.input_quantizers[num_inputs_to_quantize:]):
-                kwargs[tensor_name] = self._quantize_activation(kwargs[tensor_name], [input_quantizer], True)
+
+            # Quantize any subsequent arguments. We have to flatten the inputs here.
+            # Subsequent arguments could be a singular tensor or a list of tensors (e.g. tf.image.resize's `size`)
+            def kwarg_tensor_quantize(input_tensor):
+                nonlocal input_quantizer_index
+                if isinstance(input_tensor, List):
+                    output = []
+                    for inner_input in input_tensor:
+                        output.append(kwarg_tensor_quantize(inner_input))
+                else:
+                    output = self._quantize_activation(
+                        input_tensor,
+                        [self.input_quantizers[input_quantizer_index]],
+                        True
+                    )
+                    input_quantizer_index += 1
+                return output
+
+            for key in kwargs_keys_for_keras_tensors:
+                kwargs[key] = kwarg_tensor_quantize(kwargs[key])

         else:
             inputs = self._quantize_activation(inputs, self.input_quantizers, True)
         outputs = self._layer_to_wrap(inputs, *args, **kwargs)
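The situation this change handles is a TF op whose extra Keras-tensor arguments arrive through the call kwargs rather than the positional `inputs`, and where such a kwarg can itself be a list of tensors. Below is a minimal standalone sketch (TF 2.x functional API; the `count_tensors` helper is purely illustrative and not part of AIMET) showing how `tf.image.resize` becomes a TFOpLambda layer whose `size` kwarg carries two symbolic tensors, so three input quantizers are needed in total:

import tensorflow as tf

# tf.image.resize called on a symbolic tensor becomes a TFOpLambda layer.
# Its `size` argument is a list of two symbolic tensors (height and width
# taken from tf.shape), so it travels through the layer's call kwargs.
inp = tf.keras.Input(batch_input_shape=(1, 16, 32, 3))
x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inp)
out = tf.image.resize(x, [tf.shape(inp)[1], tf.shape(inp)[2]])
model = tf.keras.Model(inputs=inp, outputs=out)

# Illustrative helper mirroring the recursive flattening in the commit:
# one input quantizer is consumed per tensor, whether it arrives directly
# or nested inside a list-valued kwarg such as `size`.
def count_tensors(value):
    if isinstance(value, list):
        return sum(count_tensors(v) for v in value)
    return 1

size_arg = [tf.shape(inp)[1], tf.shape(inp)[2]]
print(1 + count_tensors(size_arg))  # 3: the image plus the two size tensors

The recursive `kwarg_tensor_quantize` in the commit consumes quantizers in the same flattened order, which is why a single index is threaded through with `nonlocal`.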
@@ -119,8 +119,8 @@ def model_with_tf_op_lambda_operators_multi_tf_keras_input():
     input_layer = tf.keras.Input(batch_input_shape=(1, 16, 32, 3))
     x1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)(input_layer)
     x2 = tf.transpose(x1, perm=[0, 1, 3, 2])
-    output = tf.matmul(x1, x2)
-
+    x = tf.matmul(x1, x2)
+    output = tf.image.resize(x, [tf.shape(input_layer)[1], tf.shape(input_layer)[2]])
     return tf.keras.Model(
         inputs=input_layer,
         outputs=output,
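As a quick smoke check of the updated fixture (a sketch, assuming `tf` is imported and the fixture function above is in scope), the resize target comes from `tf.shape` of the input, so the spatial size is preserved at run time even though the static shape may show `None`:

model = model_with_tf_op_lambda_operators_multi_tf_keras_input()
model.summary()

# Dense(4) -> transpose -> matmul yields (1, 16, 32, 32); resizing to
# (tf.shape(input)[1], tf.shape(input)[2]) = (16, 32) keeps that shape.
out = model(tf.random.uniform((1, 16, 32, 3)))
print(out.shape)  # expected (1, 16, 32, 32)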
@@ -425,12 +425,13 @@ def test_model_with_tf_op_lambda_operators_multi_tf_keras_input():
     encodings = json.load(encodings_file)

     assert "transpose" in qsim.model.layers[2].original_layer.name, "This QCQuantizeWrapper should wrap the `tf.transpose` TF Op Lambda Layer"
-    assert "matmul" in qsim.model.layers[3].original_layer.name, "This QCQuantizeWrapper should house the `tf.matmul` TF Op Lambda Layer"
+    assert "matmul" in qsim.model.layers[5].original_layer.name, "This QCQuantizeWrapper should house the `tf.matmul` TF Op Lambda Layer"
+    assert "image.resize" in qsim.model.layers[8].original_layer.name, "This QCQuantizeWrapper should house the `tf.image.resize` TF Op Lambda Layer"

     assert len(qsim.model.layers[2].input_quantizers) == 1, "tf.transpose should have only 1 input_quantizer"
-    assert len(qsim.model.layers[3].input_quantizers) == 2, "tf.matmul should have 2 input_quantizer for a @ b"
-
-    assert len(encodings['activation_encodings']) == 4
+    assert len(qsim.model.layers[5].input_quantizers) == 2, "tf.matmul should have 2 input_quantizers for a @ b"
+    assert len(qsim.model.layers[8].input_quantizers) == 3, "tf.image.resize should have 3 input_quantizers: its input plus the two tensors (from tf.shape) giving the new height and width"
+    assert len(encodings['activation_encodings']) == 5
     assert len(encodings['param_encodings']) == 1, "Only the Dense layer in this model should have param_encoding"

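The wrapper indices shift in these assertions (matmul moves from layers[3] to layers[5], resize lands at layers[8]) presumably because tf.shape and the [1]/[2] indexing introduce additional TFOpLambda and slicing layers into the sim model. If the ordering ever needs to be re-verified, a short sketch reusing the test's `qsim` object is:

for i, layer in enumerate(qsim.model.layers):
    print(i, layer.name, type(layer).__name__)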
