add special workaround for relu encoding min for qat (#2739)
* add special workaround for relu encoding min for qat

Signed-off-by: Matthew Ernst <quic_ernst@quicinc.com>
quic-ernst authored Feb 14, 2024
1 parent a894f6b commit b26a141
Showing 2 changed files with 72 additions and 29 deletions.
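
The change hinges on one detail of quantization-aware training (QAT): a ReLU only produces non-negative activations, so its output quantizer's encoding min should stay pinned at 0 rather than being learned. Below is a rough sketch, under assumptions, of the Keras QAT flow that would exercise this code path; the toy model, random data, and the calibrate() callback are placeholders, and exact AIMET argument names and entry points may differ by version.

import tensorflow as tf
from aimet_tensorflow.keras.quantsim import QuantizationSimModel

model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, input_shape=(8,)),
    tf.keras.layers.ReLU(),   # the output quantizer whose encoding min the workaround pins at 0
    tf.keras.layers.Dense(1),
])

sim = QuantizationSimModel(model, quant_scheme='tf_enhanced',
                           default_param_bw=8, default_output_bw=8)

def calibrate(sim_model, _):
    # Forward passes on representative data so activation encodings can be computed.
    sim_model(tf.random.normal((32, 8)))

sim.compute_encodings(calibrate, None)

# QuantizationSimModel subclasses tf.keras.Model (see the diff below), so it can be
# compiled and fine-tuned directly; the encoding min/max gradient handling shown in
# the diff (including the ReLU workaround) is meant to take effect during this step.
sim.compile(optimizer='adam', loss='mse')
sim.fit(tf.random.normal((128, 8)), tf.random.normal((128, 1)), epochs=1, verbose=0)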
@@ -67,6 +67,7 @@
tf.keras.layers.MultiHeadAttention: QcQuantizableMultiHeadAttention
}


@dataclass
class QuantizationSimModelParams:
"""
@@ -80,6 +81,7 @@ class QuantizationSimModelParams:
config_file: str = None
default_data_type: QuantizationDataType = QuantizationDataType.int


# pylint: disable=too-many-ancestors
# pylint: disable=too-many-instance-attributes
class QuantizationSimModel(tf.keras.Model):
@@ -116,7 +118,7 @@ def __init__(self, model, quant_scheme: Union[QuantScheme, str] = 'tf_enhanced',
n_weights = len(self._model_without_wrappers.weights)
self._model_without_wrappers.set_weights(model.get_weights()[:n_weights])
self._layer_name_to_quant_wrapper = {}
self._substituted_layer = {} # to hold the substituted layers
self._substituted_layer = {} # to hold the substituted layers
self._validate_model()
self.connected_graph = ConnectedGraph(self._model_without_wrappers)
self._quantsim_configurator = self._initialize_quantsim_configurator(quant_scheme, rounding_mode,
@@ -536,7 +538,8 @@ def load_encodings_to_sim(self, encoding_file_path: str):
if not input_quantizer.is_enabled():
_logger.info("Not loading encodings for quantizer: %s as it is disabled", tensor_name)
continue
encoding, is_symmetric = keras_common_utils.create_encoding_from_dict(activation_encodings[tensor_name][0])
encoding, is_symmetric = keras_common_utils.create_encoding_from_dict(
activation_encodings[tensor_name][0])
input_quantizer.tensor_quantizer.isEncodingValid = True
input_quantizer.set_quantizer_encodings(encoding.bw, is_symmetric, encoding,
libpymo.TensorQuantizerOpMode.quantizeDequantize)
@@ -554,12 +557,14 @@ def load_encodings_to_sim(self, encoding_file_path: str):
_logger.info("Not loading encodings for parameter: %s as quantizer is disabled", param_name)
continue
if isinstance(param_quantizer, StaticGridPerChannelQuantizer):
encoding, is_symmetric = keras_common_utils.create_encoding_from_dict(param_encodings[param_name])
encoding, is_symmetric = keras_common_utils.create_encoding_from_dict(
param_encodings[param_name])
for tensor_quantizer in param_quantizer.tensor_quantizer:
tensor_quantizer.isEncodingValid = True
bw = encoding[0].bw
else:
encoding, is_symmetric = keras_common_utils.create_encoding_from_dict(param_encodings[param_name][0])
encoding, is_symmetric = keras_common_utils.create_encoding_from_dict(
param_encodings[param_name][0])
param_quantizer.tensor_quantizer.isEncodingValid = True
bw = encoding.bw
param_quantizer.set_quantizer_encodings(bw, is_symmetric, encoding,
@@ -568,7 +573,8 @@ def load_encodings_to_sim(self, encoding_file_path: str):
else:
if param_quantizer.is_enabled():
param_quantizer.disable()
_logger.info("Encoding for parameter: %s not present thus disabling this quantizer.", param_name)
_logger.info("Encoding for parameter: %s not present thus disabling this quantizer.",
param_name)

# Loading encodings means that compute encodings was called. Therefore, these two lines set the correct
# op mode for the correct quant scheme and if the quantization was per channel or not.
@@ -587,7 +593,8 @@ def load_encodings_to_sim(self, encoding_file_path: str):
if not output_quantizer.is_enabled():
_logger.info("Not loading encodings for quantizer: %s as it is disabled", tensor_name)
continue
encoding, is_symmetric = keras_common_utils.create_encoding_from_dict(activation_encodings[tensor_name][0])
encoding, is_symmetric = keras_common_utils.create_encoding_from_dict(
activation_encodings[tensor_name][0])
output_quantizer.tensor_quantizer.isEncodingValid = True
output_quantizer.set_quantizer_encodings(encoding.bw, is_symmetric, encoding,
libpymo.TensorQuantizerOpMode.quantizeDequantize)
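
For reference, load_encodings_to_sim reads a JSON encodings file with activation_encodings keyed by tensor name and param_encodings keyed by parameter name; per the branches above, a per-channel parameter quantizer consumes the whole list of entries while per-tensor quantizers and activations use only the first entry. A hedged sketch of that structure follows — the field names are the usual AIMET encoding-file keys, but the tensor names and values here are invented for illustration.

# Illustrative encodings-file content for load_encodings_to_sim (values invented).
example_encodings = {
    "activation_encodings": {
        "conv2d/Relu:0": [
            {"bitwidth": 8, "dtype": "int", "is_symmetric": "False",
             "min": 0.0, "max": 5.9765625, "scale": 0.0234375, "offset": 0}
        ]
    },
    "param_encodings": {
        "conv2d/kernel:0": [
            # One dict per output channel for StaticGridPerChannelQuantizer;
            # a single-entry list for per-tensor quantizers.
            {"bitwidth": 8, "dtype": "int", "is_symmetric": "True",
             "min": -0.5, "max": 0.49609375, "scale": 0.00390625, "offset": -128}
        ]
    },
}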
@@ -631,6 +638,7 @@ def get_quant_wrapper_for_layer_name(self, layer_name: str) -> QcQuantizeWrapper
"""
return self._layer_name_to_quant_wrapper.get(layer_name)

# pylint: disable=too-many-locals
def _fill_missing_encoding_min_max_gradients(self, gradients: list):
"""
Computes the encoding min/max gradients and populates the gradients list
@@ -673,6 +681,21 @@ def _find_weight_in_layer(weight_name: str, model_layer: tf.keras.layers.Layer):
gradients[enc_min_index] = dloss_by_dmin
gradients[enc_max_index] = dloss_by_dmax

# TODO: Remove this logic once this has been resolved in QNN/SNPE
# Go through activation quantizers (Input/Output) and set any ReLU's encoding min to 0
relu_quantize_wrappers = [
_layer for _layer in self.model.layers
if isinstance(_layer, QcQuantizeWrapper) and isinstance(_layer.original_layer, tf.keras.layers.ReLU)
]

def _set_encoding_min_grad_to_None(quantizer):
enc_min_index = weight_name_to_index[quantizer.encoding_min.name]
gradients[enc_min_index] = None

for relu_quantizer in relu_quantize_wrappers:
for output_quantizer in relu_quantizer.output_quantizers:
_set_encoding_min_grad_to_None(output_quantizer)

# pylint: disable=useless-super-delegation
def get_config(self):
return super().get_config()
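
The workaround added above works by nulling out the gradient entry for each ReLU output quantizer's encoding_min: Keras optimizers skip variables whose gradient is None, so encoding_min stays at its initial value of 0, which is correct because a ReLU never produces negative activations. A minimal, self-contained illustration of that mechanism with toy variables and a toy loss (not AIMET code) follows.

import tensorflow as tf

# Two stand-ins for a ReLU output quantizer's learnable encoding range.
enc_min = tf.Variable(0.0, name="relu_output_encoding_min")
enc_max = tf.Variable(6.0, name="relu_output_encoding_max")
opt = tf.keras.optimizers.SGD(learning_rate=0.1)

with tf.GradientTape() as tape:
    # Toy loss that would pull both encoding_min and encoding_max around.
    loss = tf.square(enc_max - 5.0) + tf.square(enc_min + 1.0)

grads = tape.gradient(loss, [enc_min, enc_max])
grads[0] = None  # emulate the workaround: drop the encoding-min gradient

# Keras optimizers skip (grad, var) pairs whose gradient is None (usually with
# a warning), so enc_min keeps its value while enc_max is still updated.
opt.apply_gradients(zip(grads, [enc_min, enc_max]))
print(enc_min.numpy(), enc_max.numpy())  # 0.0, ~5.8

In the actual change, the same effect is achieved inside _fill_missing_encoding_min_max_gradients by writing None into the gradients list at the encoding-min index of every ReLU output quantizer, and the TODO marks it as temporary until the underlying issue is resolved in QNN/SNPE.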