API to load and freeze given encodings (partial or complete) - PyTorch (#2621)

* Added an API to set and freeze partial encodings provided by the user in the encodings file

* Changed the method name to load_and_freeze_encodings, added a test case, and fixed pylint issues.

Signed-off-by: Sayanta Mukherjee <quic_ssayanta@quicinc.com>
quic-ssayanta authored Dec 26, 2023
1 parent 8b646fa commit df6d822
Showing 3 changed files with 222 additions and 2 deletions.
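
For context, here is a minimal usage sketch of the new load_and_freeze_encodings API (not part of this commit's diff). It mirrors the test added below; the model definition, dummy-input shape, and encodings file path are illustrative placeholders, and the encodings JSON file is assumed to already exist (e.g., the partial file written by the test).

import torch
from aimet_common.defs import QuantScheme
from aimet_torch.quantsim import QuantizationSimModel

class TinyModel(torch.nn.Module):
    """Hypothetical stand-in model; any torch.nn.Module works."""
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 4, kernel_size=3)

    def forward(self, x):
        return self.conv1(x)

model = TinyModel()
dummy_input = torch.rand(1, 1, 28, 28)

sim = QuantizationSimModel(model=model, dummy_input=dummy_input,
                           quant_scheme=QuantScheme.post_training_tf_enhanced,
                           default_output_bw=16, default_param_bw=8)

# Load a (possibly partial) encodings JSON file and freeze the quantizers it covers.
sim.load_and_freeze_encodings("./partial_torch_encodings.encodings")

# compute_encodings then fills in encodings only for the quantizers that were not frozen above.
def forward_pass(model, inp):
    model.eval()
    with torch.no_grad():
        _ = model(inp)

sim.compute_encodings(forward_pass, dummy_input)
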
26 changes: 24 additions & 2 deletions TrainingExtensions/torch/src/python/aimet_torch/qc_quantize_op.py
@@ -373,7 +373,7 @@ def _set_quantizer_encodings(type_of_quantizer: str, quantizers: List[TensorQuan
:param type_of_quantizer: input or output
:param quantizers: input or output quantizers
"""
if type_of_quantizer in activation_encodings[module_name]:
if module_name in activation_encodings and type_of_quantizer in activation_encodings[module_name]:
encodings = activation_encodings[module_name][type_of_quantizer]
# The number of quantizers and encodings might not be same. For example, suppose the 1st output
# quantizer is disabled out of 4. The number of encodings will be 3, but number of output quantizers
@@ -390,6 +390,11 @@ def _set_quantizer_encodings(type_of_quantizer: str, quantizers: List[TensorQuan
raise RuntimeError("The quantsim passed for loading encodings does not have the same "
"configuration as the quantsim which was used to export the encodings")

if quantizer._is_encoding_frozen: # pylint: disable=protected-access
_logger.debug("Encodings are frozen for module %s and quantizer type %s", module_name,
type_of_quantizer)
continue

if encodings[ind]['dtype'] == 'int':
encoding, is_symmetric = utils.create_encoding_from_dict(encodings[ind])
quantizer.bitwidth = encoding.bw
@@ -414,7 +419,8 @@ def set_param_encoding(self, module_name: str, param_encodings: Dict):
"""
for orig_param_name, param_quantizer in self.param_quantizers.items():
param_name = module_name + '.' + orig_param_name
if param_name in param_encodings:
# pylint: disable=protected-access
if param_name in param_encodings and param_quantizer.enabled and not param_quantizer._is_encoding_frozen:
encodings = []
if param_encodings[param_name][0]['dtype'] == 'int':
is_symmetric = False
@@ -445,6 +451,22 @@ def freeze_param_encoding(self, module_name: str, param_encodings: Dict):
param_quantizer.freeze_encoding()
_logger.info("Freezing quantization encodings for parameter: %s", param_name)

def freeze_activation_encoding(self, name: str, activation_encoding: Dict):
"""
Freeze encodings for activation
:param name: name of the module
:param activation_encoding: activation encodings dictionary
"""
for input_quantizer, output_quantizer in zip(self.input_quantizers, self.output_quantizers):
if name in activation_encoding:
if QUANTIZER_TYPE_INPUT in activation_encoding[name]:
input_quantizer.freeze_encoding()
_logger.info("Freezing quantization encodings for input activation: %s", name)
if QUANTIZER_TYPE_OUTPUT in activation_encoding[name]:
output_quantizer.freeze_encoding()
_logger.info("Freezing quantization encodings for output activation: %s", name)

@staticmethod
def should_perform_quant_dequant(tensor: torch.Tensor, tensor_quantizer: TensorQuantizer) -> bool:
"""
21 changes: 21 additions & 0 deletions TrainingExtensions/torch/src/python/aimet_torch/quantsim.py
@@ -1599,6 +1599,27 @@ def configure_quantization_ops(self, config_file: str, default_output_bw: int, d
return QuantSimConfigurator(self.model, self.connected_graph, config_file, default_output_bw,
default_param_bw, default_data_type)

def load_and_freeze_encodings(self, encoding_path: str):
"""
Load encodings (both activation and parameter) from the given encodings JSON file and freeze them.
The file may be partial or complete; only the quantizers it covers are set and frozen.
:param encoding_path: Path to the JSON file from which to load the encodings.
"""
with open(encoding_path, mode='r') as json_file:
encodings_dict = json.load(json_file)

params_encoding = encodings_dict['param_encodings']
activation_encoding = encodings_dict['activation_encodings']

for name, module in self.model.named_modules():
if isinstance(module, QcQuantizeWrapper):
module.set_param_encoding(name, params_encoding)
module.freeze_param_encoding(name, params_encoding)

module.set_activation_encoding(name, activation_encoding)
module.freeze_activation_encoding(name, activation_encoding)

def set_and_freeze_param_encodings(self, encoding_path: str):
"""
Set and freeze parameter encodings from encodings JSON file
177 changes: 177 additions & 0 deletions TrainingExtensions/torch/test/python/test_quantizer.py
@@ -2834,6 +2834,183 @@ def forward_pass_callback(model_list, _):
assert sim2.model.add.input_quantizers[0].encoding is not None
assert sim2.model.add.input_quantizers[1].encoding is not None

def test_load_and_freeze_encodings(self):
model = SmallMnist()
dummy_input = torch.rand(1, 1, 28, 28)

partial_torch_encodings = {
"activation_encodings": {
"conv1": {
"input": {
"0": {
"bitwidth": 8,
"dtype": "int",
"is_symmetric": "False",
"max": 0.9978924989700317,
"min": 0.0,
"offset": 0,
"scale": 0.003913303837180138
}
}
},
"conv2": {
"output": {
"0": {
"bitwidth": 8,
"dtype": "int",
"is_symmetric": "False",
"max": 0.4923851788043976,
"min": -0.43767568469047546,
"offset": -120,
"scale": 0.0036472973879426718
}
}
},
"fc2": {
"output": {
"0": {
"bitwidth": 8,
"dtype": "int",
"is_symmetric": "False",
"max": 0.1948324590921402,
"min": -0.15752412378787994,
"offset": -114,
"scale": 0.0013817904982715845
}
}
},
"relu1": {
"output": {
"0": {
"bitwidth": 8,
"dtype": "int",
"is_symmetric": "False",
"max": 1.0608084201812744,
"min": 0.0,
"offset": 0,
"scale": 0.004160033073276281
}
}
},
"relu3": {
"output": {
"0": {
"bitwidth": 8,
"dtype": "int",
"is_symmetric": "False",
"max": 0.5247029066085815,
"min": 0.0,
"offset": 0,
"scale": 0.0020576585084199905
}
}
}
},
"excluded_layers": [],
"param_encodings": {
"conv1.weight": [
{
"bitwidth": 4,
"dtype": "int",
"is_symmetric": "True",
"max": 0.18757757544517517,
"min": -0.2143743634223938,
"offset": -8,
"scale": 0.026796795427799225
}
],
"fc2.weight": [
{
"bitwidth": 4,
"dtype": "int",
"is_symmetric": "True",
"max": 0.13095608353614807,
"min": -0.14966410398483276,
"offset": -8,
"scale": 0.018708012998104095
}
]
},
"quantizer_args": {
"activation_bitwidth": 8,
"dtype": "int",
"is_symmetric": True,
"param_bitwidth": 4,
"per_channel_quantization": False,
"quant_scheme": "post_training_tf_enhanced"
},
"version": "0.6.1"
}

qsim = QuantizationSimModel(model=model, dummy_input=dummy_input, quant_scheme=QuantScheme.post_training_tf_enhanced,
rounding_mode='nearest', default_output_bw=16, default_param_bw=8, in_place=False,
config_file=None)
def forward_pass(model, dummy_input):
model.eval()
with torch.no_grad():
_ = model(dummy_input)

with open("./temp_partial_torch_encodings.encodings", 'w') as fp:
json.dump(partial_torch_encodings, fp)

qsim.load_and_freeze_encodings("./temp_partial_torch_encodings.encodings")

qsim.compute_encodings(forward_pass, dummy_input)
decimal_point_check = 6

def assert_input_output_quantizers(quantizers, quant_type):
for idx, io_quant in enumerate(quantizers):
assert io_quant.is_encoding_frozen == True
qsim_encodings = io_quant.encoding
actual_encodings = partial_torch_encodings['activation_encodings'][name][quant_type][str(idx)]
assert qsim_encodings.bw == actual_encodings['bitwidth']
np.testing.assert_almost_equal(qsim_encodings.delta, actual_encodings['scale'], decimal_point_check)
np.testing.assert_almost_equal(qsim_encodings.max, actual_encodings['max'], decimal_point_check)
np.testing.assert_almost_equal(qsim_encodings.min, actual_encodings['min'], decimal_point_check)
assert qsim_encodings.offset == actual_encodings['offset']

def assert_param_quantizers(param_quantizer, module_name, param_name):
qsim_computed_encodings = param_quantizer[param_name].encoding
qsim_computed_encodings = [qsim_computed_encodings] if not isinstance(qsim_computed_encodings, list) \
else qsim_computed_encodings
for idx, qsim_encoding in enumerate(qsim_computed_encodings):
actual_encodings = partial_torch_encodings['param_encodings'][module_name+"."+param_name][idx]
assert param_quantizer[param_name].is_encoding_frozen == True
assert qsim_encoding.bw == actual_encodings['bitwidth']
np.testing.assert_almost_equal(qsim_encoding.delta, actual_encodings['scale'], decimal_point_check)
np.testing.assert_almost_equal(qsim_encoding.max, actual_encodings['max'], decimal_point_check)
np.testing.assert_almost_equal(qsim_encoding.min, actual_encodings['min'], decimal_point_check)
assert qsim_encoding.offset == actual_encodings['offset']

input_quant_checked = []
output_quant_checked = []
param_quant_checked = []
for name, quant in qsim.quant_wrappers():
if name in partial_torch_encodings['activation_encodings']:
if 'input' in partial_torch_encodings['activation_encodings'][name]:
assert_input_output_quantizers(quant.input_quantizers, 'input')
input_quant_checked.append(name)
if 'output' in partial_torch_encodings['activation_encodings'][name]:
assert_input_output_quantizers(quant.output_quantizers, 'output')
output_quant_checked.append(name)

param_types = quant.param_quantizers.keys()
for param_type in param_types:
module_param_name = name + "." + param_type
if module_param_name in partial_torch_encodings['param_encodings']:
assert_param_quantizers(quant.param_quantizers, name, param_type)
param_quant_checked.append(module_param_name)

actual_input_quant = {k for k, v in partial_torch_encodings['activation_encodings'].items() if 'input' in v}
actual_output_quant = {k for k, v in partial_torch_encodings['activation_encodings'].items() if 'output' in v}
actual_param_quant = set(partial_torch_encodings['param_encodings'].keys())

assert actual_input_quant == set(input_quant_checked)
assert actual_output_quant == set(output_quant_checked)
assert actual_param_quant == set(param_quant_checked)

os.remove("./temp_partial_torch_encodings.encodings")


class TestQuantizationSimLearnedGrid:

