From b26a141a25c52834e2ffe346739f84dd19a59d96 Mon Sep 17 00:00:00 2001 From: Matthew Ernst Date: Tue, 13 Feb 2024 17:27:08 -0800 Subject: [PATCH] add special workaround for relu encoding min for qat (#2739) * add special workaround for relu encoding min for qat Signed-off-by: Matthew Ernst --- .../python/aimet_tensorflow/keras/quantsim.py | 35 ++++++++-- .../test/python/eager/test_quantsim_keras.py | 66 ++++++++++++------- 2 files changed, 72 insertions(+), 29 deletions(-) diff --git a/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/quantsim.py b/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/quantsim.py index 264161a9700..0ba09db23fc 100644 --- a/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/quantsim.py +++ b/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/quantsim.py @@ -67,6 +67,7 @@ tf.keras.layers.MultiHeadAttention: QcQuantizableMultiHeadAttention } + @dataclass class QuantizationSimModelParams: """ @@ -80,6 +81,7 @@ class QuantizationSimModelParams: config_file: str = None default_data_type: QuantizationDataType = QuantizationDataType.int + # pylint: disable=too-many-ancestors # pylint: disable=too-many-instance-attributes class QuantizationSimModel(tf.keras.Model): @@ -116,7 +118,7 @@ def __init__(self, model, quant_scheme: Union[QuantScheme, str] = 'tf_enhanced', n_weights = len(self._model_without_wrappers.weights) self._model_without_wrappers.set_weights(model.get_weights()[:n_weights]) self._layer_name_to_quant_wrapper = {} - self._substituted_layer = {} # to hold the substituted layers + self._substituted_layer = {} # to hold the substituted layers self._validate_model() self.connected_graph = ConnectedGraph(self._model_without_wrappers) self._quantsim_configurator = self._initialize_quantsim_configurator(quant_scheme, rounding_mode, @@ -536,7 +538,8 @@ def load_encodings_to_sim(self, encoding_file_path: str): if not input_quantizer.is_enabled(): _logger.info("Not loading encodings for 
quantizer: %s as it is disabled", tensor_name) continue - encoding, is_symmetric = keras_common_utils.create_encoding_from_dict(activation_encodings[tensor_name][0]) + encoding, is_symmetric = keras_common_utils.create_encoding_from_dict( + activation_encodings[tensor_name][0]) input_quantizer.tensor_quantizer.isEncodingValid = True input_quantizer.set_quantizer_encodings(encoding.bw, is_symmetric, encoding, libpymo.TensorQuantizerOpMode.quantizeDequantize) @@ -554,12 +557,14 @@ def load_encodings_to_sim(self, encoding_file_path: str): _logger.info("Not loading encodings for parameter: %s as quantizer is disabled", param_name) continue if isinstance(param_quantizer, StaticGridPerChannelQuantizer): - encoding, is_symmetric = keras_common_utils.create_encoding_from_dict(param_encodings[param_name]) + encoding, is_symmetric = keras_common_utils.create_encoding_from_dict( + param_encodings[param_name]) for tensor_quantizer in param_quantizer.tensor_quantizer: tensor_quantizer.isEncodingValid = True bw = encoding[0].bw else: - encoding, is_symmetric = keras_common_utils.create_encoding_from_dict(param_encodings[param_name][0]) + encoding, is_symmetric = keras_common_utils.create_encoding_from_dict( + param_encodings[param_name][0]) param_quantizer.tensor_quantizer.isEncodingValid = True bw = encoding.bw param_quantizer.set_quantizer_encodings(bw, is_symmetric, encoding, @@ -568,7 +573,8 @@ def load_encodings_to_sim(self, encoding_file_path: str): else: if param_quantizer.is_enabled(): param_quantizer.disable() - _logger.info("Encoding for parameter: %s not present thus disabling this quantizer.", param_name) + _logger.info("Encoding for parameter: %s not present thus disabling this quantizer.", + param_name) # Loading encodings means that compute encodings was called. Therefore, these two lines set the correct # op mode for the correct quant scheme and if the quantization was per channel or not. 
@@ -587,7 +593,8 @@ def load_encodings_to_sim(self, encoding_file_path: str): if not output_quantizer.is_enabled(): _logger.info("Not loading encodings for quantizer: %s as it is disabled", tensor_name) continue - encoding, is_symmetric = keras_common_utils.create_encoding_from_dict(activation_encodings[tensor_name][0]) + encoding, is_symmetric = keras_common_utils.create_encoding_from_dict( + activation_encodings[tensor_name][0]) output_quantizer.tensor_quantizer.isEncodingValid = True output_quantizer.set_quantizer_encodings(encoding.bw, is_symmetric, encoding, libpymo.TensorQuantizerOpMode.quantizeDequantize) @@ -631,6 +638,7 @@ def get_quant_wrapper_for_layer_name(self, layer_name: str) -> QcQuantizeWrapper """ return self._layer_name_to_quant_wrapper.get(layer_name) + # pylint: disable=too-many-locals def _fill_missing_encoding_min_max_gradients(self, gradients: list): """ Computes the encoding min/max gradients and populates the gradients list @@ -673,6 +681,21 @@ def _find_weight_in_layer(weight_name: str, model_layer: tf.keras.layers.Layer): gradients[enc_min_index] = dloss_by_dmin gradients[enc_max_index] = dloss_by_dmax + # TODO: Remove this logic once this has been resolved in QNN/SNPE + # Go through the output quantizers of every wrapped ReLU layer and set the encoding min gradient to None + relu_quantize_wrappers = [ + _layer for _layer in self.model.layers + if isinstance(_layer, QcQuantizeWrapper) and isinstance(_layer.original_layer, tf.keras.layers.ReLU) + ] + + def _set_encoding_min_grad_to_None(quantizer): + enc_min_index = weight_name_to_index[quantizer.encoding_min.name] + gradients[enc_min_index] = None + + for relu_quantizer in relu_quantize_wrappers: + for output_quantizer in relu_quantizer.output_quantizers: + _set_encoding_min_grad_to_None(output_quantizer) + # pylint: disable=useless-super-delegation def get_config(self): return super().get_config() diff --git a/TrainingExtensions/tensorflow/test/python/eager/test_quantsim_keras.py 
b/TrainingExtensions/tensorflow/test/python/eager/test_quantsim_keras.py index 9b90c166d12..ac31cba4d99 100644 --- a/TrainingExtensions/tensorflow/test/python/eager/test_quantsim_keras.py +++ b/TrainingExtensions/tensorflow/test/python/eager/test_quantsim_keras.py @@ -55,10 +55,12 @@ from aimet_tensorflow.keras.quantsim import QuantizationSimModel from test_models_keras import tiny_conv_net + def conv_functional(): input_shape = (128, 28, 28, 1) inp = tf.keras.Input(shape=input_shape[1:]) - x = tf.keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu")(inp) + x = tf.keras.layers.Conv2D(32, kernel_size=(3, 3))(inp) + x = tf.keras.layers.ReLU()(x) x = tf.keras.layers.Conv2DTranspose(32, kernel_size=(3, 3), activation="relu")(x) x = tf.keras.layers.DepthwiseConv2D(depth_multiplier=1, kernel_size=(3, 3), activation='relu')(x) x = tf.keras.layers.Flatten()(x) @@ -68,6 +70,7 @@ def conv_functional(): model = tf.keras.Model(inputs=inp, outputs=x, name='conv_functional') return model + def dense_functional(): inp = tf.keras.layers.Input(shape=(5,)) x = tf.keras.layers.Dense(units=2)(inp) @@ -75,6 +78,7 @@ def dense_functional(): model = tf.keras.Model(inputs=inp, outputs=x, name="dense_functional") return model + def dense_sequential(): model = tf.keras.Sequential() model.add(tf.keras.layers.Dense(units=2, input_shape=(5,))) @@ -102,6 +106,7 @@ def call(self, inputs, training=None, mask=None): x = self.softmax(x) return x + def model_with_lambda_operators(): inp = tf.keras.layers.Input(shape=(5,)) inp_2 = tf.keras.layers.Input(shape=(3,)) @@ -143,6 +148,7 @@ def model_with_tf_op_lambda_operators_multi_tf_static_inputs(): name="model_with_tf_op_lambda_operators_multi_tf_static_inputs" ) + def model_with_reused_layer(): relu = tf.keras.layers.ReLU() inp = tf.keras.layers.Input(shape=(5,)) @@ -153,6 +159,7 @@ def model_with_reused_layer(): model = tf.keras.Model(inputs=inp, outputs=x, name="model_with_reused_layer") return model + class 
DenseReluLayer(tf.keras.layers.Layer): def __init__(self, **kwargs): super(DenseReluLayer, self).__init__() @@ -164,6 +171,7 @@ def call(self, inputs): x = self.relu(x) return x + def test_quantsim_basic(): if version.parse(tf.version.VERSION) >= version.parse("2.00"): model = dense_functional() @@ -205,12 +213,14 @@ def test_quantsim_basic(): qsim.export('./data', 'test_export') + def test_quantsim_export_quantizer_args(): if version.parse(tf.version.VERSION) >= version.parse("2.00"): model = dense_functional() rand_inp = np.random.randn(100, 5) - qsim = QuantizationSimModel(model, quant_scheme=QuantScheme.post_training_tf_enhanced, default_param_bw=16, default_output_bw=16 ) + qsim = QuantizationSimModel(model, quant_scheme=QuantScheme.post_training_tf_enhanced, default_param_bw=16, + default_output_bw=16) qsim.export('./data', 'test_export_with_quant_args') @@ -226,6 +236,7 @@ def test_quantsim_export_quantizer_args(): assert quantizer_args["dtype"] == "int" assert quantizer_args["is_symmetric"] + def test_quantsim_with_custom_config_file(): quantsim_config = { "defaults": { @@ -473,7 +484,7 @@ def test_qat(): loss=tf.keras.losses.MeanSquaredError()) # Track weights for dense layer to check that they are updated during fit running_weights = [tf.keras.backend.get_value(param) for - param in qsim.model.layers[1]._layer_to_wrap.weights] + param in qsim.model.layers[1]._layer_to_wrap.weights] # Track encoding max for dense output quantizer to check that it is not updated during fit running_dense_output_quantizer_encoding_max = \ tf.keras.backend.get_value(qsim.model.layers[1].output_quantizers[0]._encoding_max) @@ -509,6 +520,7 @@ def test_qat(): assert file_count == 1, f"QAT Save Callback did not work" + def test_range_learning(): if version.parse(tf.version.VERSION) >= version.parse("2.00"): tf.keras.backend.clear_session() @@ -725,6 +737,7 @@ def test_quantizable_mha_with_mask(): # check that QcQuantizableMultiHeadAttention exists in QuantSim model.layers assert 
any(isinstance(layer, QcQuantizableMultiHeadAttention) for layer in quantized_model.model.layers) + def test_quantizable_mha_encodings(): B = 5 T = 8 @@ -762,6 +775,7 @@ def test_quantizable_mha_encodings(): assert all((quantized_model_tensor >= output_encoding_min - FLOAT_DELTA) & (quantized_model_tensor <= output_encoding_max + FLOAT_DELTA)) + def test_quantizable_mha_export_encodings(): B = 5 T = 8 @@ -800,6 +814,7 @@ def test_quantizable_mha_export_encodings(): assert param_name in encodings['param_encodings'] assert encodings['param_encodings'][param_name] == encoding_dict + def _common_stays_valid_after_export_helper(model, rand_inp, config=None): tf.keras.backend.clear_session() sim = QuantizationSimModel(model, quant_scheme='tf', config_file=config) @@ -845,7 +860,7 @@ def check_quantizers(original_quantizer, new_quantizer): if not isinstance(original_quantizer.tensor_quantizer, List): original_quantizer_tensor_quantizers = [original_quantizer.tensor_quantizer] new_quantizer_tensor_quantizers = [new_quantizer.tensor_quantizer] - else: + else: original_quantizer_tensor_quantizers = original_quantizer.tensor_quantizer new_quantizer_tensor_quantizers = new_quantizer.tensor_quantizer @@ -875,27 +890,32 @@ def check_encodings(original_encoding, new_encoding): if isinstance(layer, tf.keras.layers.InputLayer): continue - assert len(layer.input_quantizers) == len(original_layer_and_quantizers[layer.name]["input_quantizers"]), f"Not the same number of input quantizers for layer {layer.name}" + assert len(layer.input_quantizers) == len(original_layer_and_quantizers[layer.name][ + "input_quantizers"]), f"Not the same number of input quantizers for layer {layer.name}" for i, _ in enumerate(layer.input_quantizers): check_quantizers(original_layer_and_quantizers[layer.name]["input_quantizers"][i], - layer.input_quantizers[i]) + layer.input_quantizers[i]) check_encodings(original_layer_and_quantizers[layer.name]["input_quantizers"][i].encoding, 
layer.input_quantizers[i].encoding) - assert len(layer.output_quantizers) == len(original_layer_and_quantizers[layer.name]["output_quantizers"]), f"Not the same number of output quantizers for layer {layer.name}" + assert len(layer.output_quantizers) == len(original_layer_and_quantizers[layer.name][ + "output_quantizers"]), f"Not the same number of output quantizers for layer {layer.name}" for i, _ in enumerate(layer.output_quantizers): check_quantizers(original_layer_and_quantizers[layer.name]["output_quantizers"][i], - layer.output_quantizers[i]) + layer.output_quantizers[i]) check_encodings(original_layer_and_quantizers[layer.name]["output_quantizers"][i].encoding, layer.output_quantizers[i].encoding) - assert len(layer.param_quantizers) == len(original_layer_and_quantizers[layer.name]["param_quantizers"]), f"Not the same number of param quantizers for layer {layer.name}" + assert len(layer.param_quantizers) == len(original_layer_and_quantizers[layer.name][ + "param_quantizers"]), f"Not the same number of param quantizers for layer {layer.name}" for i, _ in enumerate(layer.param_quantizers): check_quantizers(original_layer_and_quantizers[layer.name]["param_quantizers"][i], - layer.param_quantizers[i]) + layer.param_quantizers[i]) check_encodings(original_layer_and_quantizers[layer.name]["param_quantizers"][i].encoding, layer.param_quantizers[i].encoding) - np.testing.assert_array_equal(original_sim_output, sim.model.predict(rand_inp), err_msg="Model output changed after export") + np.testing.assert_array_equal(original_sim_output, sim.model.predict(rand_inp), + err_msg="Model output changed after export") + def test_model_stays_valid_after_export_per_tensor(): model = conv_functional() @@ -998,7 +1018,7 @@ def test_load_encodings(): # For param expected_encoding = param_encodings['conv2d_1/kernel:0'][0] - actual_encoding = extracted_encoding["param_encodings"]['conv2d_1/kernel:0'][0] + actual_encoding = 
extracted_encoding["param_encodings"]['conv2d_1/kernel:0'][0] assert actual_encoding.get('bitwidth') == expected_encoding.get('bitwidth') assert actual_encoding.get('offset') == expected_encoding.get('offset') assert actual_encoding.get('is_symmetric') == expected_encoding.get('is_symmetric') @@ -1007,14 +1027,13 @@ def test_load_encodings(): # For activation expected_encoding = activation_encodings["conv2d_1/Tanh:0"][0] - actual_encoding = extracted_encoding["activation_encodings"]["conv2d_1/Tanh:0"][0] + actual_encoding = extracted_encoding["activation_encodings"]["conv2d_1/Tanh:0"][0] assert actual_encoding.get('bitwidth') == expected_encoding.get('bitwidth') assert actual_encoding.get('offset') == expected_encoding.get('offset') assert actual_encoding.get('is_symmetric') == expected_encoding.get('is_symmetric') assert np.allclose(actual_encoding.get('min'), expected_encoding.get('min'), atol=1e-5) assert np.allclose(actual_encoding.get('max'), expected_encoding.get('max'), atol=1e-5) - # Delete encodings JSON file if os.path.exists("./dummy.encodings"): os.remove("./dummy.encodings") @@ -1046,7 +1065,7 @@ def test_load_encodings_with_disabled_param(): model = keras_model() - sim = QuantizationSimModel(model,config_file='./quantsim_config.json') + sim = QuantizationSimModel(model, config_file='./quantsim_config.json') param_encodings = {'conv2d_1/kernel:0': [{'bitwidth': 4, 'is_symmetric': "False", 'max': 0.14584073424339294, 'min': -0.12761062383651733, @@ -1086,14 +1105,13 @@ def test_load_encodings_with_disabled_param(): # For activation expected_encoding = activation_encodings["conv2d_1/Tanh:0"][0] - actual_encoding = extracted_encoding["activation_encodings"]["conv2d_1/Tanh:0"][0] + actual_encoding = extracted_encoding["activation_encodings"]["conv2d_1/Tanh:0"][0] assert actual_encoding.get('bitwidth') == expected_encoding.get('bitwidth') assert actual_encoding.get('offset') == expected_encoding.get('offset') assert actual_encoding.get('is_symmetric') == 
expected_encoding.get('is_symmetric') assert np.allclose(actual_encoding.get('min'), expected_encoding.get('min'), atol=1e-5) assert np.allclose(actual_encoding.get('max'), expected_encoding.get('max'), atol=1e-5) - # Delete encodings JSON file if os.path.exists("./dummy.encodings"): os.remove("./dummy.encodings") @@ -1169,7 +1187,7 @@ def test_load_encodings_pcq(): # For param expected_encoding = param_encodings['conv2d_1/kernel:0'] - actual_encoding = extracted_encoding["param_encodings"]['conv2d_1/kernel:0'] + actual_encoding = extracted_encoding["param_encodings"]['conv2d_1/kernel:0'] for i in range(4): assert actual_encoding[i].get('bitwidth') == expected_encoding[i].get('bitwidth') assert actual_encoding[i].get('offset') == expected_encoding[i].get('offset') @@ -1179,18 +1197,18 @@ def test_load_encodings_pcq(): # For activation expected_encoding = activation_encodings["conv2d_1/Tanh:0"][0] - actual_encoding = extracted_encoding["activation_encodings"]["conv2d_1/Tanh:0"][0] + actual_encoding = extracted_encoding["activation_encodings"]["conv2d_1/Tanh:0"][0] assert actual_encoding.get('bitwidth') == expected_encoding.get('bitwidth') assert actual_encoding.get('offset') == expected_encoding.get('offset') assert actual_encoding.get('is_symmetric') == expected_encoding.get('is_symmetric') assert np.allclose(actual_encoding.get('min'), expected_encoding.get('min'), atol=1e-5) assert np.allclose(actual_encoding.get('max'), expected_encoding.get('max'), atol=1e-5) - # Delete encodings JSON file if os.path.exists("./dummy.encodings"): os.remove("./dummy.encodings") + @pytest.mark.cuda @pytest.mark.parametrize( "quant_scheme", @@ -1250,6 +1268,7 @@ def test_initialization_and_export_non_strict_symmetric(quant_scheme) -> None: assert np.isclose(encoding_min, scale * offset, atol=1e-6) assert np.isclose(encoding_max, encoding_min + scale * 255, atol=1e-6) + @pytest.mark.cuda @pytest.mark.parametrize( "quant_scheme", @@ -1339,6 +1358,7 @@ def 
test_initialization_and_export_non_strict_symmetric_per_channel(quant_scheme assert np.isclose(encoding_min, scale * offset, atol=1e-6) assert np.isclose(encoding_max, encoding_min + scale * 255, atol=1e-6) + def test_quant_scheme_percentile(): """ This test case ensures that the quantization is working fine with percentile scheme @@ -1347,7 +1367,8 @@ def test_quant_scheme_percentile(): if version.parse(tf.version.VERSION) >= version.parse("2.00"): model = dense_functional() - qsim = QuantizationSimModel(model, quant_scheme=QuantScheme.post_training_tf, default_param_bw=16, default_output_bw=16 ) + qsim = QuantizationSimModel(model, quant_scheme=QuantScheme.post_training_tf, default_param_bw=16, + default_output_bw=16) _, _, output_quantizers = qsim._get_quantizer_list() with pytest.raises(RuntimeError): for quantizer in output_quantizers: @@ -1367,13 +1388,12 @@ def test_quant_scheme_percentile_setting_using_str(): if version.parse(tf.version.VERSION) >= version.parse("2.00"): model = dense_functional() - qsim = QuantizationSimModel(model, quant_scheme="percentile", default_param_bw=16, default_output_bw=16 ) + qsim = QuantizationSimModel(model, quant_scheme="percentile", default_param_bw=16, default_output_bw=16) inp_quatizer, paramater_quantizer, output_quantizers = qsim._get_quantizer_list() for quantizer in inp_quatizer + paramater_quantizer + output_quantizers: assert quantizer.quant_scheme == QuantScheme.post_training_percentile - def test_multi_output_model(): """ Test Quantsim with a model that has multi output layers