From 9ff81dc15f198b3fa1cb9cba89cf66e3ffce2f4d Mon Sep 17 00:00:00 2001 From: Matthew Ernst Date: Tue, 9 Jan 2024 10:50:01 -0800 Subject: [PATCH] Keras QAT Error with tf.IndexedSlices (#2640) * Add decorator to convert tf.IndexSlices to tf.tensor for grads Signed-off-by: Matthew Ernst --- .../keras/quant_sim/tensor_quantizer.py | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/quant_sim/tensor_quantizer.py b/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/quant_sim/tensor_quantizer.py index ee49b6bd977..612a3c436a6 100644 --- a/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/quant_sim/tensor_quantizer.py +++ b/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/quant_sim/tensor_quantizer.py @@ -37,7 +37,7 @@ """ Tensor quantizer for tf 2 keras """ import abc import functools -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Callable import tensorflow as tf import tensorflow.keras.backend as K @@ -84,6 +84,21 @@ def _handle(cls, tensor): return _handle +def _update_grad_to_tf_tensor_if_needed(grad_func: Callable): + """ + Decorator function to convert gradient tensors that are represented as tf.IndexedSlices into + tf.Tensor's. This typically occurs with operations such as tf.gather and keras.Embedding layers. + + :param grad_func: The gradient function that will be called after conversion + """ + def wrapper(*args, **kwargs): + grad = args[0] + if isinstance(grad, tf.IndexedSlices): + _logger.debug("Converting %s from tf.IndexedSlices to tf.Tensor", grad.name) + args = (tf.convert_to_tensor(grad),) + args[1:] + return grad_func(*args, **kwargs) + return wrapper + class TensorQuantizer(tf.keras.layers.Layer, abc.ABC): """Tensor quantizer class containing the bare bones of a given Quantizer""" @@ -421,6 +436,7 @@ def call_quantize_straight_through_estimator_grad(self, tensor: tf.Tensor): :param tensor: Tensor to quantize """ + @_update_grad_to_tf_tensor_if_needed def grad(upstream: tf.Tensor, variables: List): """ Straight through estimator grad function @@ -428,6 +444,7 @@ def grad(upstream: tf.Tensor, variables: List): :param variables: Variables used in forward pass to return gradients for """ assert len(variables) == 2, 'len variables is ' + str(len(variables)) + return qc_straight_through_estimator_grad(tensor, self._encoding_min, self._encoding_max, self._quantizer_mode, upstream) @@ -447,6 +464,7 @@ def call_quantsim_custom_grad_learned_grid(self, tensor: tf.Tensor): :param tensor: Tensor to quantize """ + @_update_grad_to_tf_tensor_if_needed def grad(upstream: tf.Tensor, variables: List): """ Range learning grad function