diff --git a/TrainingExtensions/torch/src/python/aimet_torch/qc_quantize_op.py b/TrainingExtensions/torch/src/python/aimet_torch/qc_quantize_op.py
index bab54d13017..48612d00cf1 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/qc_quantize_op.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/qc_quantize_op.py
@@ -373,7 +373,7 @@ def _set_quantizer_encodings(type_of_quantizer: str, quantizers: List[TensorQuan
             :param type_of_quantizer: input or output
             :param quantizers: input or output quantizers
             """
-            if type_of_quantizer in activation_encodings[module_name]:
+            if module_name in activation_encodings and type_of_quantizer in activation_encodings[module_name]:
                 encodings = activation_encodings[module_name][type_of_quantizer]
                 # The number of quantizers and encodings might not be same. For example, suppose the 1st output
                 # quantizer is disabled out of 4. The number of encodings will be 3, but number of output quantizers
@@ -390,6 +390,11 @@ def _set_quantizer_encodings(type_of_quantizer: str, quantizers: List[TensorQuan
                         raise RuntimeError("The quantsim passed for loading encodings does not have the same "
                                            "configuration as the quantsim which was used to export the encodings")

+                    if quantizer._is_encoding_frozen:  # pylint: disable=protected-access
+                        _logger.debug("Encodings are frozen for module %s and quantizer type %s", module_name,
+                                      type_of_quantizer)
+                        continue
+
                     if encodings[ind]['dtype'] == 'int':
                         encoding, is_symmetric = utils.create_encoding_from_dict(encodings[ind])
                         quantizer.bitwidth = encoding.bw
@@ -414,7 +419,8 @@ def set_param_encoding(self, module_name: str, param_encodings: Dict):
         """
         for orig_param_name, param_quantizer in self.param_quantizers.items():
             param_name = module_name + '.' + orig_param_name
-            if param_name in param_encodings:
+            # pylint: disable=protected-access
+            if param_name in param_encodings and param_quantizer.enabled and not param_quantizer._is_encoding_frozen:
                 encodings = []
                 if param_encodings[param_name][0]['dtype'] == 'int':
                     is_symmetric = False
@@ -445,6 +451,22 @@ def freeze_param_encoding(self, module_name: str, param_encodings: Dict):
                 param_quantizer.freeze_encoding()
                 _logger.info("Freezing quantization encodings for parameter: %s", param_name)

+    def freeze_activation_encoding(self, name: str, activation_encoding: Dict):
+        """
+        Freeze encodings for activation
+
+        :param name: name of module
+        :param activation_encoding: activation encodings dictionary
+        """
+        for input_quantizer, output_quantizer in zip(self.input_quantizers, self.output_quantizers):
+            if name in activation_encoding:
+                if QUANTIZER_TYPE_INPUT in activation_encoding[name]:
+                    input_quantizer.freeze_encoding()
+                    _logger.info("Freezing quantization encodings for input activation: %s", name)
+                if QUANTIZER_TYPE_OUTPUT in activation_encoding[name]:
+                    output_quantizer.freeze_encoding()
+                    _logger.info("Freezing quantization encodings for output activation: %s", name)
+
     @staticmethod
     def should_perform_quant_dequant(tensor: torch.Tensor, tensor_quantizer: TensorQuantizer) -> bool:
         """
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/quantsim.py b/TrainingExtensions/torch/src/python/aimet_torch/quantsim.py
index c637adfb6d6..e8190ef72d9 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/quantsim.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/quantsim.py
@@ -1599,6 +1599,27 @@ def configure_quantization_ops(self, config_file: str, default_output_bw: int, d
         return QuantSimConfigurator(self.model, self.connected_graph, config_file,
                                     default_output_bw, default_param_bw, default_data_type)

+    def load_and_freeze_encodings(self, encoding_path: str):
+        """
+        Functionality to set encodings (both activation and parameter) as per the given encodings JSON file and
+        freeze them.
+
+        :param encoding_path: JSON file path from where to load the encodings.
+        """
+        with open(encoding_path, mode='r') as json_file:
+            encodings_dict = json.load(json_file)
+
+        params_encoding = encodings_dict['param_encodings']
+        activation_encoding = encodings_dict['activation_encodings']
+
+        for name, module in self.model.named_modules():
+            if isinstance(module, QcQuantizeWrapper):
+                module.set_param_encoding(name, params_encoding)
+                module.freeze_param_encoding(name, params_encoding)
+
+                module.set_activation_encoding(name, activation_encoding)
+                module.freeze_activation_encoding(name, activation_encoding)
+
     def set_and_freeze_param_encodings(self, encoding_path: str):
         """
         Set and freeze parameter encodings from encodings JSON file
diff --git a/TrainingExtensions/torch/test/python/test_quantizer.py b/TrainingExtensions/torch/test/python/test_quantizer.py
index 9472a9014d7..7834e5f7c6a 100644
--- a/TrainingExtensions/torch/test/python/test_quantizer.py
+++ b/TrainingExtensions/torch/test/python/test_quantizer.py
@@ -2834,6 +2834,183 @@ def forward_pass_callback(model_list, _):
         assert sim2.model.add.input_quantizers[0].encoding is not None
         assert sim2.model.add.input_quantizers[1].encoding is not None

+    def test_load_and_freeze_encodings(self):
+        model = SmallMnist()
+        dummy_input = torch.rand(1, 1, 28, 28)
+
+        partial_torch_encodings = {
+            "activation_encodings": {
+                "conv1": {
+                    "input": {
+                        "0": {
+                            "bitwidth": 8,
+                            "dtype": "int",
+                            "is_symmetric": "False",
+                            "max": 0.9978924989700317,
+                            "min": 0.0,
+                            "offset": 0,
+                            "scale": 0.003913303837180138
+                        }
+                    }
+                },
+                "conv2": {
+                    "output": {
+                        "0": {
+                            "bitwidth": 8,
+                            "dtype": "int",
+                            "is_symmetric": "False",
+                            "max": 0.4923851788043976,
+                            "min": -0.43767568469047546,
+                            "offset": -120,
+                            "scale": 0.0036472973879426718
+                        }
+                    }
+                },
+                "fc2": {
+                    "output": {
+                        "0": {
+                            "bitwidth": 8,
+                            "dtype": "int",
+                            "is_symmetric": "False",
+                            "max": 0.1948324590921402,
+                            "min": -0.15752412378787994,
+                            "offset": -114,
+                            "scale": 0.0013817904982715845
+                        }
+                    }
+                },
+                "relu1": {
+                    "output": {
+                        "0": {
+                            "bitwidth": 8,
+                            "dtype": "int",
+                            "is_symmetric": "False",
+                            "max": 1.0608084201812744,
+                            "min": 0.0,
+                            "offset": 0,
+                            "scale": 0.004160033073276281
+                        }
+                    }
+                },
+                "relu3": {
+                    "output": {
+                        "0": {
+                            "bitwidth": 8,
+                            "dtype": "int",
+                            "is_symmetric": "False",
+                            "max": 0.5247029066085815,
+                            "min": 0.0,
+                            "offset": 0,
+                            "scale": 0.0020576585084199905
+                        }
+                    }
+                }
+            },
+            "excluded_layers": [],
+            "param_encodings": {
+                "conv1.weight": [
+                    {
+                        "bitwidth": 4,
+                        "dtype": "int",
+                        "is_symmetric": "True",
+                        "max": 0.18757757544517517,
+                        "min": -0.2143743634223938,
+                        "offset": -8,
+                        "scale": 0.026796795427799225
+                    }
+                ],
+                "fc2.weight": [
+                    {
+                        "bitwidth": 4,
+                        "dtype": "int",
+                        "is_symmetric": "True",
+                        "max": 0.13095608353614807,
+                        "min": -0.14966410398483276,
+                        "offset": -8,
+                        "scale": 0.018708012998104095
+                    }
+                ]
+            },
+            "quantizer_args": {
+                "activation_bitwidth": 8,
+                "dtype": "int",
+                "is_symmetric": True,
+                "param_bitwidth": 4,
+                "per_channel_quantization": False,
+                "quant_scheme": "post_training_tf_enhanced"
+            },
+            "version": "0.6.1"
+        }
+
+        qsim = QuantizationSimModel(model=model, dummy_input=dummy_input, quant_scheme=QuantScheme.post_training_tf_enhanced,
+                                    rounding_mode='nearest', default_output_bw=16, default_param_bw=8, in_place=False,
+                                    config_file=None)
+        def forward_pass(model, dummy_input):
+            model.eval()
+            with torch.no_grad():
+                _ = model(dummy_input)
+
+        with open("./temp_partial_torch_encodings.encodings", 'w') as fp:
+            json.dump(partial_torch_encodings, fp)
+
+        qsim.load_and_freeze_encodings("./temp_partial_torch_encodings.encodings")
+
+        qsim.compute_encodings(forward_pass, dummy_input)
+        decimal_point_check = 6
+
+        def assert_input_output_quantizers(quantizers, quant_type):
+            for idx, io_quant in enumerate(quantizers):
+                assert io_quant.is_encoding_frozen == True
+                qsim_encodings = io_quant.encoding
+                actual_encodings = partial_torch_encodings['activation_encodings'][name][quant_type][str(idx)]
+                assert qsim_encodings.bw == actual_encodings['bitwidth']
+                np.testing.assert_almost_equal(qsim_encodings.delta, actual_encodings['scale'], decimal_point_check)
+                np.testing.assert_almost_equal(qsim_encodings.max, actual_encodings['max'], decimal_point_check)
+                np.testing.assert_almost_equal(qsim_encodings.min, actual_encodings['min'], decimal_point_check)
+                assert qsim_encodings.offset == actual_encodings['offset']
+
+        def assert_param_quantizers(param_quantizer, module_name, param_name):
+            qsim_computed_encodings = param_quantizer[param_name].encoding
+            qsim_computed_encodings = [qsim_computed_encodings] if not isinstance(qsim_computed_encodings, list) \
+                else qsim_computed_encodings
+            for idx, qsim_encoding in enumerate(qsim_computed_encodings):
+                actual_encodings = partial_torch_encodings['param_encodings'][module_name+"."+param_name][idx]
+                assert param_quantizer[param_name].is_encoding_frozen == True
+                assert qsim_encoding.bw == actual_encodings['bitwidth']
+                np.testing.assert_almost_equal(qsim_encoding.delta, actual_encodings['scale'], decimal_point_check)
+                np.testing.assert_almost_equal(qsim_encoding.max, actual_encodings['max'], decimal_point_check)
+                np.testing.assert_almost_equal(qsim_encoding.min, actual_encodings['min'], decimal_point_check)
+                assert qsim_encoding.offset == actual_encodings['offset']
+
+        input_quant_checked = []
+        output_quant_checked = []
+        param_quant_checked = []
+        for name, quant in qsim.quant_wrappers():
+            if name in partial_torch_encodings['activation_encodings']:
+                if 'input' in partial_torch_encodings['activation_encodings'][name]:
+                    assert_input_output_quantizers(quant.input_quantizers, 'input')
+                    input_quant_checked.append(name)
+                if 'output' in partial_torch_encodings['activation_encodings'][name]:
+                    assert_input_output_quantizers(quant.output_quantizers, 'output')
+                    output_quant_checked.append(name)
+
+            param_types = quant.param_quantizers.keys()
+            for param_type in param_types:
+                module_param_name = name + "." + param_type
+                if module_param_name in partial_torch_encodings['param_encodings']:
+                    assert_param_quantizers(quant.param_quantizers, name, param_type)
+                    param_quant_checked.append(module_param_name)
+
+        actual_input_quant = {k for k, v in partial_torch_encodings['activation_encodings'].items() if 'input' in v}
+        actual_output_quant = {k for k, v in partial_torch_encodings['activation_encodings'].items() if 'output' in v}
+        actual_param_quant = set(partial_torch_encodings['param_encodings'].keys())
+
+        assert actual_input_quant == set(input_quant_checked)
+        assert actual_output_quant == set(output_quant_checked)
+        assert actual_param_quant == set(param_quant_checked)
+
+        os.remove("./temp_partial_torch_encodings.encodings")
+
+
 class TestQuantizationSimLearnedGrid:
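
Usage note (not part of the patch): the sketch below shows how the new QuantizationSimModel.load_and_freeze_encodings API added above could be exercised, mirroring the added test. TinyModel and the "./partial.encodings" path are illustrative placeholders; only the QuantizationSimModel calls that appear in the diff are assumed.

# Minimal sketch, assuming a previously exported encodings JSON exists at
# "./partial.encodings" (hypothetical path) whose layer names match the model.
import torch
from aimet_common.defs import QuantScheme
from aimet_torch.quantsim import QuantizationSimModel


class TinyModel(torch.nn.Module):
    """Stand-in for the user's network (hypothetical)."""
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 4, kernel_size=3)
        self.relu1 = torch.nn.ReLU()

    def forward(self, x):
        return self.relu1(self.conv1(x))


dummy_input = torch.rand(1, 1, 28, 28)
sim = QuantizationSimModel(model=TinyModel().eval(), dummy_input=dummy_input,
                           quant_scheme=QuantScheme.post_training_tf_enhanced,
                           default_output_bw=16, default_param_bw=8)

# Set and freeze whatever encodings the exported JSON provides ...
sim.load_and_freeze_encodings("./partial.encodings")  # hypothetical file


# ... then calibrate only the quantizers that were not frozen.
def forward_pass(model, args):
    with torch.no_grad():
        model(args)


sim.compute_encodings(forward_pass, dummy_input)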