Skip to content

Commit

Permalink
Keras QAT Error with tf.IndexedSlices (#2640)
Browse files Browse the repository at this point in the history
* Add decorator to convert tf.IndexSlices to tf.tensor for grads

Signed-off-by: Matthew Ernst <quic_ernst@quicinc.com>
  • Loading branch information
quic-ernst authored Jan 9, 2024
1 parent c18aade commit 9ff81dc
Showing 1 changed file with 19 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
""" Tensor quantizer for tf 2 keras """
import abc
import functools
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Callable
import tensorflow as tf
import tensorflow.keras.backend as K

Expand Down Expand Up @@ -84,6 +84,21 @@ def _handle(cls, tensor):
return _handle


def _update_grad_to_tf_tensor_if_needed(grad_func: Callable):
"""
Decorator function to convert gradient tensors that are represented as tf.IndexedSlices into
tf.Tensor's. This typically occurs with operations such as tf.gather and keras.Embedding layers.
:param grad_func: The gradient function that will be called after conversion
"""
def wrapper(*args, **kwargs):
grad = args[0]
if isinstance(grad, tf.IndexedSlices):
_logger.debug("Converting %s from tf.IndexedSlices to tf.Tensor", grad.name)
args = (tf.convert_to_tensor(grad),) + args[1:]
return grad_func(*args, **kwargs)
return wrapper

class TensorQuantizer(tf.keras.layers.Layer, abc.ABC):
"""Tensor quantizer class containing the bare bones of a given Quantizer"""

Expand Down Expand Up @@ -421,13 +436,15 @@ def call_quantize_straight_through_estimator_grad(self, tensor: tf.Tensor):
:param tensor: Tensor to quantize
"""

@_update_grad_to_tf_tensor_if_needed
def grad(upstream: tf.Tensor, variables: List):
"""
Straight through estimator grad function
:param upstream: Gradient from child layers
:param variables: Variables used in forward pass to return gradients for
"""
assert len(variables) == 2, 'len variables is ' + str(len(variables))

return qc_straight_through_estimator_grad(tensor, self._encoding_min, self._encoding_max,
self._quantizer_mode, upstream)

Expand All @@ -447,6 +464,7 @@ def call_quantsim_custom_grad_learned_grid(self, tensor: tf.Tensor):
:param tensor: Tensor to quantize
"""

@_update_grad_to_tf_tensor_if_needed
def grad(upstream: tf.Tensor, variables: List):
"""
Range learning grad function
Expand Down

0 comments on commit 9ff81dc

Please sign in to comment.