diff --git a/TrainingExtensions/torch/src/python/aimet_torch/gptvq/gptvq_weight.py b/TrainingExtensions/torch/src/python/aimet_torch/gptvq/gptvq_weight.py
index c64475b808e..07f58ab9b65 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/gptvq/gptvq_weight.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/gptvq/gptvq_weight.py
@@ -52,7 +52,7 @@
 from aimet_torch.gptvq.defs import GPTVQSupportedModules, GPTVQParameters
 from aimet_torch.gptvq.gptvq_optimizer import GPTVQOptimizer
 from aimet_torch.gptvq.utils import get_module_name_to_hessian_tensor
-from aimet_torch.v1.quantsim import ExportableQuantModule
+from aimet_torch.v1.quantsim import QuantizedModuleProtocol
 from aimet_torch.save_utils import SaveUtils
 from aimet_torch.utils import get_named_module
 from aimet_torch.v2.nn import BaseQuantizationMixin
@@ -391,7 +391,7 @@ def _export_encodings_to_json(cls,
             json.dump(encoding, encoding_fp, sort_keys=True, indent=4)
 
     @staticmethod
-    def _update_param_encodings_dict(quant_module: ExportableQuantModule,
+    def _update_param_encodings_dict(quant_module: QuantizedModuleProtocol,
                                      name: str,
                                      param_encodings: Dict,
                                      rows_per_block: int):
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/layer_output_utils.py b/TrainingExtensions/torch/src/python/aimet_torch/layer_output_utils.py
index 857f5f2c20b..535f6b95e43 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/layer_output_utils.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/layer_output_utils.py
@@ -50,7 +50,7 @@
 from aimet_common.utils import AimetLogger
 from aimet_common.layer_output_utils import SaveInputOutput, save_layer_output_names
 
-from aimet_torch.v1.quantsim import ExportableQuantModule, QuantizationSimModel
+from aimet_torch.v1.quantsim import QuantizedModuleProtocol, QuantizationSimModel
 from aimet_torch import utils
 from aimet_torch import torchscript_utils
 from aimet_torch.onnx_utils import OnnxSaver, OnnxExportApiArgs
@@ -171,7 +171,7 @@ def __init__(self, model: torch.nn.Module, dir_path: str, naming_scheme: NamingS
         self.module_to_name_dict = utils.get_module_to_name_dict(model=model, prefix='')
 
         # Check whether the given model is quantsim model
-        self.is_quantsim_model = any(isinstance(module, (ExportableQuantModule, QcQuantizeRecurrent)) for module in model.modules())
+        self.is_quantsim_model = any(isinstance(module, (QuantizedModuleProtocol, QcQuantizeRecurrent)) for module in model.modules())
 
         # Obtain layer-name to layer-output name mapping
         self.layer_name_to_layer_output_dict = {}
@@ -206,7 +206,7 @@ def get_outputs(self, input_batch: Union[torch.Tensor, List[torch.Tensor], Tuple
         if self.is_quantsim_model:
             # Apply record-output hook to QuantizeWrapper modules (one node above leaf node in model graph)
             utils.run_hook_for_layers_with_given_input(self.model, input_batch, self.record_outputs,
-                                                       module_type_for_attaching_hook=(ExportableQuantModule, QcQuantizeRecurrent),
+                                                       module_type_for_attaching_hook=(QuantizedModuleProtocol, QcQuantizeRecurrent),
                                                        leaf_node_only=False)
         else:
             # Apply record-output hook to Original modules (leaf node in model graph)
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/peft.py b/TrainingExtensions/torch/src/python/aimet_torch/peft.py
index 150413c15e5..a1e5130cad3 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/peft.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/peft.py
@@ -54,7 +54,7 @@
 from aimet_torch.v1.nn.modules.custom import Add, Multiply
 from aimet_torch.v2.quantsim import QuantizationSimModel
 from aimet_torch.v2.quantization.affine import QuantizeDequantize
-from aimet_torch.v1.quantsim import ExportableQuantModule
+from aimet_torch.v1.quantsim import QuantizedModuleProtocol
 from aimet_torch.v2.nn import BaseQuantizationMixin
@@ -396,7 +396,7 @@ def export_adapter_weights(self, sim: QuantizationSimModel, path: str, filename_
         tensors = {}
 
         for module_name, module in sim.model.named_modules():
-            if not isinstance(module, ExportableQuantModule):
+            if not isinstance(module, QuantizedModuleProtocol):
                 continue
             org_name = module_name
             pt_name = self._get_module_name(module_name)
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/save_utils.py b/TrainingExtensions/torch/src/python/aimet_torch/save_utils.py
index b170c3ce346..0cb761b8694 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/save_utils.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/save_utils.py
@@ -37,7 +37,7 @@
 """ Utilities to save a models and related parameters """
 
-from aimet_torch.v1.quantsim import ExportableQuantModule
+from aimet_torch.v1.quantsim import QuantizedModuleProtocol
 
 
 class SaveUtils:
@@ -50,7 +50,7 @@ def remove_quantization_wrappers(module):
         :param module: Model
         """
         for module_name, module_ref in module.named_children():
-            if isinstance(module_ref, ExportableQuantModule):
+            if isinstance(module_ref, QuantizedModuleProtocol):
                 setattr(module, module_name, module_ref.get_original_module())
             # recursively call children modules
             else:
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/v1/adaround/adaround_weight.py b/TrainingExtensions/torch/src/python/aimet_torch/v1/adaround/adaround_weight.py
index b5c617f2c0e..99c6b9b1275 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/v1/adaround/adaround_weight.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/v1/adaround/adaround_weight.py
@@ -54,7 +54,7 @@
 from aimet_torch import utils
 from aimet_torch.save_utils import SaveUtils
 from aimet_torch.meta import connectedgraph_utils
-from aimet_torch.v1.quantsim import QuantizationSimModel, QcQuantizeWrapper, ExportableQuantModule
+from aimet_torch.v1.quantsim import QuantizationSimModel, QcQuantizeWrapper, QuantizedModuleProtocol
 from aimet_torch.v1.qc_quantize_op import StaticGridQuantWrapper, QcQuantizeOpMode
 from aimet_torch.v1.tensor_quantizer import TensorQuantizer
 from aimet_torch.v1.adaround.adaround_wrapper import AdaroundWrapper
@@ -490,7 +490,7 @@ def _export_encodings_to_json(cls, path: str, filename_prefix: str, quant_sim: Q
         param_encodings = {}
 
         for name, quant_module in quant_sim.model.named_modules():
-            if isinstance(quant_module, ExportableQuantModule) and \
+            if isinstance(quant_module, QuantizedModuleProtocol) and \
                     isinstance(quant_module.get_original_module(), AdaroundSupportedModules):
 
                 if 'weight' in quant_module.param_quantizers:
@@ -505,7 +505,7 @@ def _export_encodings_to_json(cls, path: str, filename_prefix: str, quant_sim: Q
             json.dump(encoding, encoding_fp, sort_keys=True, indent=4)
 
     @classmethod
-    def _update_param_encodings_dict(cls, quant_module: ExportableQuantModule, name: str, param_encodings: Dict):
+    def _update_param_encodings_dict(cls, quant_module: QuantizedModuleProtocol, name: str, param_encodings: Dict):
         """
         Add module's weight parameter encodings to dictionary to be used for exporting encodings
         :param quant_module: quant module
@@ -560,7 +560,7 @@ def _override_param_bitwidth(model: torch.nn.Module, quant_sim: QuantizationSimM
         # Create a mapping of QuantSim model's AdaRoundable module name and their module
         name_to_module = {}
         for q_name, q_module in quant_sim.model.named_modules():
-            if isinstance(q_module, ExportableQuantModule):
+            if isinstance(q_module, QuantizedModuleProtocol):
                 if isinstance(q_module.get_original_module(), AdaroundSupportedModules):
                     # pylint: disable=protected-access
                     name_to_module[q_name] = q_module
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/v1/quantsim.py b/TrainingExtensions/torch/src/python/aimet_torch/v1/quantsim.py
index 0bfbdb5796c..bf002d51823 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/v1/quantsim.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/v1/quantsim.py
@@ -130,7 +130,7 @@ def __init__(self,
 
 
 @runtime_checkable
-class ExportableQuantModule(Protocol):
+class QuantizedModuleProtocol(Protocol):
     """
     Defines the minimum interface requirements for exporting encodings from a module.
     """
@@ -210,7 +210,7 @@ def get_original_module(self) -> torch.nn.Module:
     QcQuantizeWrapper,
     QcQuantizeStandAloneBase,
     QcQuantizeRecurrent,
-    ExportableQuantModule,
+    QuantizedModuleProtocol,
     LazyQuantizeWrapper,
 )
@@ -699,7 +699,7 @@ def get_activation_param_encodings(self):
         param_encodings = OrderedDict()
 
         for module_name, module in self.model.named_modules():
-            if not isinstance(module, ExportableQuantModule):
+            if not isinstance(module, QuantizedModuleProtocol):
                 continue
 
             activation_encodings[module_name] = defaultdict(OrderedDict)
@@ -740,7 +740,7 @@ def exclude_layers_from_quantization(self, layers_to_exclude: List[torch.nn.Modu
         quant_layers_to_exclude = []
         quant_cls = (QcQuantizeRecurrent,
                      LazyQuantizeWrapper,
-                     ExportableQuantModule)
+                     QuantizedModuleProtocol)
         for layer in layers_to_exclude:
             for module in layer.modules():
                 if isinstance(module, quant_cls):
@@ -881,7 +881,7 @@ def _find_next_downstream_modules(op):
 
     @staticmethod
-    def _get_torch_encodings_for_missing_layers(layer: ExportableQuantModule, layer_name: str,# pylint: disable=too-many-branches
+    def _get_torch_encodings_for_missing_layers(layer: QuantizedModuleProtocol, layer_name: str,# pylint: disable=too-many-branches
                                                 missing_activation_encodings_torch: Dict,
                                                 missing_param_encodings: Dict,
                                                 valid_param_set: set):
@@ -893,7 +893,7 @@ def _get_torch_encodings_for_missing_layers(layer: ExportableQuantModule, layer_
        :param missing_param_encodings: dictionary of param encodings
        :param valid_param_set: a set of valid param input names in model
        """
-        if isinstance(layer, ExportableQuantModule):
+        if isinstance(layer, QuantizedModuleProtocol):
            # --------------------------------------
            # Update encodings for Input activations
            # --------------------------------------
@@ -969,14 +969,14 @@ def _export_encodings_to_files(sim_model: torch.nn.Module, path: str, filename_p
         tensor_to_quantizer_map = {}
         for layer_name, layer in sim_model.named_modules():
-            if not isinstance(layer, (ExportableQuantModule, QcQuantizeRecurrent)):
+            if not isinstance(layer, (QuantizedModuleProtocol, QcQuantizeRecurrent)):
                 continue
 
             if not has_valid_encodings(layer):
                 continue
 
             # TODO: specifically call out dropout layers here since they are specifically switched out during export.
             # These ops should eventually be reworked as part of math invariant ops to ignore quantization altogether.
             # pylint: disable=protected-access
-            if isinstance(layer, ExportableQuantModule) and isinstance(layer.get_original_module(), utils.DROPOUT_TYPES):
+            if isinstance(layer, QuantizedModuleProtocol) and isinstance(layer.get_original_module(), utils.DROPOUT_TYPES):
                 continue
             if layer_name not in layers_to_onnx_op_names.keys():
@@ -1042,7 +1042,7 @@ def _export_encodings_to_files(sim_model: torch.nn.Module, path: str, filename_p
         save_json_yaml(encoding_file_path_pytorch, encodings_dict_pytorch)
 
     @staticmethod
-    def _update_param_encodings_dict_for_layer(layer: ExportableQuantModule, layer_name: str, param_encodings: Dict,
+    def _update_param_encodings_dict_for_layer(layer: QuantizedModuleProtocol, layer_name: str, param_encodings: Dict,
                                                valid_param_set: set, tensor_to_quantizer_map: Dict):
         """
         :param layer: layer as torch.nn.Module
@@ -1062,7 +1062,7 @@ def _update_param_encodings_dict_for_layer(layer: ExportableQuantModule, layer_n
                 tensor_to_quantizer_map[param_name] = layer.param_quantizers[orig_param_name]
 
     @staticmethod
-    def _update_encoding_dicts_for_layer(layer: ExportableQuantModule, layer_name: str, activation_encodings_onnx: Dict,
+    def _update_encoding_dicts_for_layer(layer: QuantizedModuleProtocol, layer_name: str, activation_encodings_onnx: Dict,
                                          activation_encodings_torch: Dict, param_encodings: Dict,
                                          op_to_io_tensor_map: Dict, valid_param_set: set, propagate_encodings: bool,
                                          tensor_to_consumer_map: Dict[str, str],
@@ -1084,7 +1084,7 @@ def _update_encoding_dicts_for_layer(layer: ExportableQuantModule, layer_name: s
         :param layers_to_onnx_op_names: Dictionary mapping PyTorch layer names to names of corresponding ONNX ops
         """
 
-        if isinstance(layer, ExportableQuantModule):
+        if isinstance(layer, QuantizedModuleProtocol):
 
             # --------------------------------------
             # Update encodings for Input activations
@@ -1167,7 +1167,7 @@ find_op_names_for_layer(layer_name: str, op_to_io_tensor_map: Dict,
         return end_op_names, op_names
 
     @staticmethod
-    def _update_encoding_dict_for_output_activations(layer: ExportableQuantModule, layer_name: str, op_to_io_tensor_map: Dict,
+    def _update_encoding_dict_for_output_activations(layer: QuantizedModuleProtocol, layer_name: str, op_to_io_tensor_map: Dict,
                                                      activation_encodings_onnx: Dict, activation_encodings_torch: Dict,
                                                      propagate_encodings: bool, tensor_to_consumer_map: Dict[str, str],
                                                      layers_to_onnx_op_names: Dict[str, str],
@@ -1204,7 +1204,7 @@ def _update_encoding_dict_for_output_activations(layer: ExportableQuantModule, l
 
     @staticmethod
-    def _update_encoding_dict_for_input_activations(layer: ExportableQuantModule, layer_name: str, op_to_io_tensor_map: Dict,
+    def _update_encoding_dict_for_input_activations(layer: QuantizedModuleProtocol, layer_name: str, op_to_io_tensor_map: Dict,
                                                     activation_encodings_onnx: Dict, activation_encodings_torch: Dict,
                                                     layers_to_onnx_op_names: Dict[str, str],
                                                     tensor_to_quantizer_map: Dict):
@@ -1393,7 +1393,7 @@ def _update_encoding_dict_for_recurrent_layers(layer: torch.nn.Module, layer_nam
 
     def _get_qc_quantized_layers(model) -> List[Tuple[str, QcQuantizeWrapper]]:
         quantized_layers = []
         for name, module in model.named_modules():
-            if isinstance(module, (QcQuantizeRecurrent, LazyQuantizeWrapper, ExportableQuantModule)):
+            if isinstance(module, (QcQuantizeRecurrent, LazyQuantizeWrapper, QuantizedModuleProtocol)):
                 quantized_layers.append((name, module))
         return quantized_layers
@@ -1500,7 +1500,7 @@ def _remove_quantization_wrappers(cls, starting_module, list_of_modules_to_exclu
 
             # If modules is in the exclude list, remove the wrapper
             if module_ref in list_of_modules_to_exclude:
-                if isinstance(module_ref, ExportableQuantModule):
+                if isinstance(module_ref, QuantizedModuleProtocol):
                     # Remove the wrapper, gets auto-deleted
                     # pylint: disable=protected-access
                     setattr(starting_module, module_name, module_ref.get_original_module())
@@ -1568,12 +1568,12 @@ def _update_parameters_by_attr(module: torch.nn.Module):
     def _get_leaf_module_to_name_map(self):
         """
-        Returns a mapping from leaf modules to module name, where any ExportableQuantModule is considered a leaf module,
+        Returns a mapping from leaf modules to module name, where any QuantizedModuleProtocol is considered a leaf module,
         and is therefore not further recursed (since we do not want to retrieve all internal quantizers/modules).
         """
 
         def recursively_populate_map(starting_module, module_map, start_str):
             for name, module in starting_module.named_children():
-                if isinstance(module, ExportableQuantModule) or utils.is_leaf_module(module):
+                if isinstance(module, QuantizedModuleProtocol) or utils.is_leaf_module(module):
                     module_map[module] = start_str + name
                 else:
                     recursively_populate_map(module, module_map, start_str + name + ".")
@@ -1590,7 +1590,7 @@ def inputs_hook(module_ref, inputs, _):
             hooks[module_ref].remove()
             del hooks[module_ref]
             module_name = module_to_name_map[module_ref]
-            if isinstance(module_ref, ExportableQuantModule):
+            if isinstance(module_ref, QuantizedModuleProtocol):
                 module_ref = module_ref.get_original_module()
 
             marker_layer = torch.jit.trace(CustomMarker(module_ref, module_name, 'True'), inputs)
@@ -1780,7 +1780,7 @@ def _set_param_encodings(self,
                              requires_grad: Optional[bool],
                              allow_overwrite: bool):
         for name, quant_module in self.model.named_modules():
-            if isinstance(quant_module, ExportableQuantModule):
+            if isinstance(quant_module, QuantizedModuleProtocol):
                 param_encoding = {
                     param_name: encoding_dict[f'{name}.{param_name}']
                     for param_name, _ in quant_module.param_quantizers.items()
@@ -1803,7 +1803,7 @@ def _set_activation_encodings(self,
                                   requires_grad: Optional[bool],
                                   allow_overwrite: bool):
         for module_name, module in self.model.named_modules():
-            if not isinstance(module, ExportableQuantModule):
+            if not isinstance(module, QuantizedModuleProtocol):
                 continue
 
             try:
@@ -1858,7 +1858,7 @@ def named_qmodules(self):
         """Generator that yields all quantized modules in the model and their names
         """
         for name, module in self.model.named_modules():
-            if isinstance(module, (QcQuantizeRecurrent, LazyQuantizeWrapper, ExportableQuantModule)):
+            if isinstance(module, (QcQuantizeRecurrent, LazyQuantizeWrapper, QuantizedModuleProtocol)):
                 yield name, module
 
     def qmodules(self):
@@ -1883,7 +1883,7 @@ def run_modules_for_traced_custom_marker(self, module_list: List[torch.nn.Module
             # Only perform init and trace if the given module is a leaf module, and we have not recorded it before
             if module in module_to_name_map and module_to_name_map[module] not in self._module_marker_map:
                 name = module_to_name_map[module]
-                module = module.get_original_module() if isinstance(module, ExportableQuantModule) else module
+                module = module.get_original_module() if isinstance(module, QuantizedModuleProtocol) else module
                 with utils.in_eval_mode(module), torch.no_grad():
                     marker_layer = torch.jit.trace(CustomMarker(module, name, True), dummy_input)
                     self._module_marker_map[name] = marker_layer
@@ -2299,18 +2299,18 @@ def load_encodings_to_sim(quant_sim_model: QuantizationSimModel, pytorch_encodin
     quant_sim_model.replace_wrappers_for_quantize_dequantize()
 
 
-def has_valid_encodings(qc_quantize_op: ExportableQuantModule) -> bool:
+def has_valid_encodings(qc_quantize_op: QuantizedModuleProtocol) -> bool:
     """
     Utility for determining whether a given qc_quantize_op has any valid encodings.
 
     :param qc_quantize_op: Qc quantize op to evaluate
    :return: True if any input, param, or output quantizers have valid encodings, False otherwise
     """
-    if not isinstance(qc_quantize_op, (ExportableQuantModule, QcQuantizeRecurrent)):
+    if not isinstance(qc_quantize_op, (QuantizedModuleProtocol, QcQuantizeRecurrent)):
         logger.error("has_valid_encodings only supported for QcQuantizeWrapper and QcQuantizeRecurrent "
                      "modules")
-    assert isinstance(qc_quantize_op, (ExportableQuantModule, QcQuantizeRecurrent))
-    if isinstance(qc_quantize_op, ExportableQuantModule):
+    assert isinstance(qc_quantize_op, (QuantizedModuleProtocol, QcQuantizeRecurrent))
+    if isinstance(qc_quantize_op, QuantizedModuleProtocol):
         all_encodings = qc_quantize_op.export_output_encodings() + qc_quantize_op.export_input_encodings() + \
                         list(qc_quantize_op.export_param_encodings().values())
         return any([encoding is not None for encoding in all_encodings])  # pylint: disable=consider-using-generator,use-a-generator
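
Note on the isinstance checks in this diff: QuantizedModuleProtocol is a typing.Protocol decorated with @runtime_checkable, so isinstance() matches any wrapper that structurally exposes the protocol's members, without requiring inheritance. That is why the rename only touches the name at call sites and never the wrapper classes themselves. Below is a minimal sketch of that mechanic, assuming a reduced two-method interface; ToyQuantWrapper is a hypothetical example class, and the real protocol in v1/quantsim.py declares additional members (e.g. export_input_encodings, export_output_encodings).

    from typing import Dict, Protocol, runtime_checkable

    import torch


    @runtime_checkable
    class QuantizedModuleProtocol(Protocol):
        """Structural interface: any object exposing these members matches."""
        def export_param_encodings(self) -> Dict:
            ...

        def get_original_module(self) -> torch.nn.Module:
            ...


    class ToyQuantWrapper:
        """Hypothetical wrapper; not a subclass, it matches purely by member names."""
        def __init__(self, module: torch.nn.Module):
            self._module = module

        def export_param_encodings(self) -> Dict:
            return {}

        def get_original_module(self) -> torch.nn.Module:
            return self._module


    # With @runtime_checkable, isinstance() only verifies that the protocol's
    # members exist on the object; it does not check signatures or return types.
    assert isinstance(ToyQuantWrapper(torch.nn.Linear(4, 4)), QuantizedModuleProtocol)
    assert not isinstance(torch.nn.Linear(4, 4), QuantizedModuleProtocol)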