Rename ExportableQuantModule to QuantizedModuleProtocol
Signed-off-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
quic-kyunggeu committed Oct 31, 2024
1 parent f1707c0 commit aac3703
Showing 6 changed files with 39 additions and 39 deletions.
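
Background note: the renamed class is a @runtime_checkable typing.Protocol (see the v1/quantsim.py hunks below), so membership is decided structurally by the methods a module provides, not by inheritance from the protocol. That is why this commit can be a pure search-and-replace of the type name in annotations and isinstance checks without touching any quantized-module implementation. A minimal, self-contained sketch of that mechanism, using toy class names rather than AIMET's API:

from typing import Protocol, runtime_checkable

@runtime_checkable
class ToyQuantizedModuleProtocol(Protocol):
    # Stand-in for the protocol in aimet_torch/v1/quantsim.py; not the real interface.
    def get_original_module(self):
        ...

class ToyQuantWrapper:  # hypothetical class; never subclasses the protocol
    def __init__(self, original):
        self._original = original

    def get_original_module(self):
        return self._original

# The structural isinstance check succeeds because the required method exists,
# regardless of what the protocol happens to be named.
assert isinstance(ToyQuantWrapper(object()), ToyQuantizedModuleProtocol)
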
@@ -52,7 +52,7 @@
from aimet_torch.gptvq.defs import GPTVQSupportedModules, GPTVQParameters
from aimet_torch.gptvq.gptvq_optimizer import GPTVQOptimizer
from aimet_torch.gptvq.utils import get_module_name_to_hessian_tensor
-from aimet_torch.v1.quantsim import ExportableQuantModule
+from aimet_torch.v1.quantsim import QuantizedModuleProtocol
from aimet_torch.save_utils import SaveUtils
from aimet_torch.utils import get_named_module
from aimet_torch.v2.nn import BaseQuantizationMixin
@@ -391,7 +391,7 @@ def _export_encodings_to_json(cls,
json.dump(encoding, encoding_fp, sort_keys=True, indent=4)

@staticmethod
-def _update_param_encodings_dict(quant_module: ExportableQuantModule,
+def _update_param_encodings_dict(quant_module: QuantizedModuleProtocol,
name: str,
param_encodings: Dict,
rows_per_block: int):
@@ -50,7 +50,7 @@
from aimet_common.utils import AimetLogger
from aimet_common.layer_output_utils import SaveInputOutput, save_layer_output_names

-from aimet_torch.v1.quantsim import ExportableQuantModule, QuantizationSimModel
+from aimet_torch.v1.quantsim import QuantizedModuleProtocol, QuantizationSimModel
from aimet_torch import utils
from aimet_torch import torchscript_utils
from aimet_torch.onnx_utils import OnnxSaver, OnnxExportApiArgs
@@ -171,7 +171,7 @@ def __init__(self, model: torch.nn.Module, dir_path: str, naming_scheme: NamingS
self.module_to_name_dict = utils.get_module_to_name_dict(model=model, prefix='')

# Check whether the given model is quantsim model
-self.is_quantsim_model = any(isinstance(module, (ExportableQuantModule, QcQuantizeRecurrent)) for module in model.modules())
+self.is_quantsim_model = any(isinstance(module, (QuantizedModuleProtocol, QcQuantizeRecurrent)) for module in model.modules())

# Obtain layer-name to layer-output name mapping
self.layer_name_to_layer_output_dict = {}
@@ -206,7 +206,7 @@ def get_outputs(self, input_batch: Union[torch.Tensor, List[torch.Tensor], Tuple
if self.is_quantsim_model:
# Apply record-output hook to QuantizeWrapper modules (one node above leaf node in model graph)
utils.run_hook_for_layers_with_given_input(self.model, input_batch, self.record_outputs,
-module_type_for_attaching_hook=(ExportableQuantModule, QcQuantizeRecurrent),
+module_type_for_attaching_hook=(QuantizedModuleProtocol, QcQuantizeRecurrent),
leaf_node_only=False)
else:
# Apply record-output hook to Original modules (leaf node in model graph)
4 changes: 2 additions & 2 deletions TrainingExtensions/torch/src/python/aimet_torch/peft.py
@@ -54,7 +54,7 @@
from aimet_torch.v1.nn.modules.custom import Add, Multiply
from aimet_torch.v2.quantsim import QuantizationSimModel
from aimet_torch.v2.quantization.affine import QuantizeDequantize
-from aimet_torch.v1.quantsim import ExportableQuantModule
+from aimet_torch.v1.quantsim import QuantizedModuleProtocol
from aimet_torch.v2.nn import BaseQuantizationMixin


@@ -396,7 +396,7 @@ def export_adapter_weights(self, sim: QuantizationSimModel, path: str, filename_
tensors = {}

for module_name, module in sim.model.named_modules():
-if not isinstance(module, ExportableQuantModule):
+if not isinstance(module, QuantizedModuleProtocol):
continue
org_name = module_name
pt_name = self._get_module_name(module_name)
4 changes: 2 additions & 2 deletions TrainingExtensions/torch/src/python/aimet_torch/save_utils.py
@@ -37,7 +37,7 @@

""" Utilities to save a models and related parameters """

-from aimet_torch.v1.quantsim import ExportableQuantModule
+from aimet_torch.v1.quantsim import QuantizedModuleProtocol


class SaveUtils:
@@ -50,7 +50,7 @@ def remove_quantization_wrappers(module):
:param module: Model
"""
for module_name, module_ref in module.named_children():
-if isinstance(module_ref, ExportableQuantModule):
+if isinstance(module_ref, QuantizedModuleProtocol):
setattr(module, module_name, module_ref.get_original_module())
# recursively call children modules
else:
@@ -54,7 +54,7 @@
from aimet_torch import utils
from aimet_torch.save_utils import SaveUtils
from aimet_torch.meta import connectedgraph_utils
-from aimet_torch.v1.quantsim import QuantizationSimModel, QcQuantizeWrapper, ExportableQuantModule
+from aimet_torch.v1.quantsim import QuantizationSimModel, QcQuantizeWrapper, QuantizedModuleProtocol
from aimet_torch.v1.qc_quantize_op import StaticGridQuantWrapper, QcQuantizeOpMode
from aimet_torch.v1.tensor_quantizer import TensorQuantizer
from aimet_torch.v1.adaround.adaround_wrapper import AdaroundWrapper
@@ -490,7 +490,7 @@ def _export_encodings_to_json(cls, path: str, filename_prefix: str, quant_sim: Q
param_encodings = {}

for name, quant_module in quant_sim.model.named_modules():
-if isinstance(quant_module, ExportableQuantModule) and \
+if isinstance(quant_module, QuantizedModuleProtocol) and \
isinstance(quant_module.get_original_module(), AdaroundSupportedModules):

if 'weight' in quant_module.param_quantizers:
@@ -505,7 +505,7 @@ def _export_encodings_to_json(cls, path: str, filename_prefix: str, quant_sim: Q
json.dump(encoding, encoding_fp, sort_keys=True, indent=4)

@classmethod
-def _update_param_encodings_dict(cls, quant_module: ExportableQuantModule, name: str, param_encodings: Dict):
+def _update_param_encodings_dict(cls, quant_module: QuantizedModuleProtocol, name: str, param_encodings: Dict):
"""
Add module's weight parameter encodings to dictionary to be used for exporting encodings
:param quant_module: quant module
@@ -560,7 +560,7 @@ def _override_param_bitwidth(model: torch.nn.Module, quant_sim: QuantizationSimM
# Create a mapping of QuantSim model's AdaRoundable module name and their module
name_to_module = {}
for q_name, q_module in quant_sim.model.named_modules():
-if isinstance(q_module, ExportableQuantModule):
+if isinstance(q_module, QuantizedModuleProtocol):
if isinstance(q_module.get_original_module(), AdaroundSupportedModules): # pylint: disable=protected-access
name_to_module[q_name] = q_module

52 changes: 26 additions & 26 deletions TrainingExtensions/torch/src/python/aimet_torch/v1/quantsim.py
@@ -130,7 +130,7 @@ def __init__(self,


@runtime_checkable
-class ExportableQuantModule(Protocol):
+class QuantizedModuleProtocol(Protocol):
"""
Defines the minimum interface requirements for exporting encodings from a module.
"""
@@ -210,7 +210,7 @@ def get_original_module(self) -> torch.nn.Module:
QcQuantizeWrapper,
QcQuantizeStandAloneBase,
QcQuantizeRecurrent,
-ExportableQuantModule,
+QuantizedModuleProtocol,
LazyQuantizeWrapper,
)

@@ -699,7 +699,7 @@ def get_activation_param_encodings(self):
param_encodings = OrderedDict()

for module_name, module in self.model.named_modules():
-if not isinstance(module, ExportableQuantModule):
+if not isinstance(module, QuantizedModuleProtocol):
continue

activation_encodings[module_name] = defaultdict(OrderedDict)
@@ -740,7 +740,7 @@ def exclude_layers_from_quantization(self, layers_to_exclude: List[torch.nn.Modu
quant_layers_to_exclude = []
quant_cls = (QcQuantizeRecurrent,
LazyQuantizeWrapper,
-ExportableQuantModule)
+QuantizedModuleProtocol)
for layer in layers_to_exclude:
for module in layer.modules():
if isinstance(module, quant_cls):
@@ -881,7 +881,7 @@ def _find_next_downstream_modules(op):


@staticmethod
-def _get_torch_encodings_for_missing_layers(layer: ExportableQuantModule, layer_name: str,# pylint: disable=too-many-branches
+def _get_torch_encodings_for_missing_layers(layer: QuantizedModuleProtocol, layer_name: str,# pylint: disable=too-many-branches
missing_activation_encodings_torch: Dict,
missing_param_encodings: Dict,
valid_param_set: set):
@@ -893,7 +893,7 @@ def _get_torch_encodings_for_missing_layers(layer: ExportableQuantModule, layer_
:param missing_param_encodings: dictionary of param encodings
:param valid_param_set: a set of valid param input names in model
"""
-if isinstance(layer, ExportableQuantModule):
+if isinstance(layer, QuantizedModuleProtocol):
# --------------------------------------
# Update encodings for Input activations
# --------------------------------------
@@ -969,14 +969,14 @@ def _export_encodings_to_files(sim_model: torch.nn.Module, path: str, filename_p
tensor_to_quantizer_map = {}

for layer_name, layer in sim_model.named_modules():
-if not isinstance(layer, (ExportableQuantModule, QcQuantizeRecurrent)):
+if not isinstance(layer, (QuantizedModuleProtocol, QcQuantizeRecurrent)):
continue
if not has_valid_encodings(layer):
continue
# TODO: specifically call out dropout layers here since they are specifically switched out during export.
# These ops should eventually be reworked as part of math invariant ops to ignore quantization altogether.
# pylint: disable=protected-access
-if isinstance(layer, ExportableQuantModule) and isinstance(layer.get_original_module(), utils.DROPOUT_TYPES):
+if isinstance(layer, QuantizedModuleProtocol) and isinstance(layer.get_original_module(), utils.DROPOUT_TYPES):
continue

if layer_name not in layers_to_onnx_op_names.keys():
@@ -1042,7 +1042,7 @@ def _export_encodings_to_files(sim_model: torch.nn.Module, path: str, filename_p
save_json_yaml(encoding_file_path_pytorch, encodings_dict_pytorch)

@staticmethod
-def _update_param_encodings_dict_for_layer(layer: ExportableQuantModule, layer_name: str, param_encodings: Dict,
+def _update_param_encodings_dict_for_layer(layer: QuantizedModuleProtocol, layer_name: str, param_encodings: Dict,
valid_param_set: set, tensor_to_quantizer_map: Dict):
"""
:param layer: layer as torch.nn.Module
@@ -1062,7 +1062,7 @@ def _update_param_encodings_dict_for_layer(layer: ExportableQuantModule, layer_n
tensor_to_quantizer_map[param_name] = layer.param_quantizers[orig_param_name]

@staticmethod
-def _update_encoding_dicts_for_layer(layer: ExportableQuantModule, layer_name: str, activation_encodings_onnx: Dict,
+def _update_encoding_dicts_for_layer(layer: QuantizedModuleProtocol, layer_name: str, activation_encodings_onnx: Dict,
activation_encodings_torch: Dict, param_encodings: Dict,
op_to_io_tensor_map: Dict, valid_param_set: set, propagate_encodings: bool,
tensor_to_consumer_map: Dict[str, str],
@@ -1084,7 +1084,7 @@ def _update_encoding_dicts_for_layer(layer: ExportableQuantModule, layer_name: s
:param layers_to_onnx_op_names: Dictionary mapping PyTorch layer names to names of corresponding ONNX ops
"""

-if isinstance(layer, ExportableQuantModule):
+if isinstance(layer, QuantizedModuleProtocol):

# --------------------------------------
# Update encodings for Input activations
@@ -1167,7 +1167,7 @@ def find_op_names_for_layer(layer_name: str, op_to_io_tensor_map: Dict,
return end_op_names, op_names

@staticmethod
-def _update_encoding_dict_for_output_activations(layer: ExportableQuantModule, layer_name: str, op_to_io_tensor_map: Dict,
+def _update_encoding_dict_for_output_activations(layer: QuantizedModuleProtocol, layer_name: str, op_to_io_tensor_map: Dict,
activation_encodings_onnx: Dict, activation_encodings_torch: Dict,
propagate_encodings: bool, tensor_to_consumer_map: Dict[str, str],
layers_to_onnx_op_names: Dict[str, str],
@@ -1204,7 +1204,7 @@ def _update_encoding_dict_for_output_activations(layer: ExportableQuantModule, l


@staticmethod
-def _update_encoding_dict_for_input_activations(layer: ExportableQuantModule, layer_name: str, op_to_io_tensor_map: Dict,
+def _update_encoding_dict_for_input_activations(layer: QuantizedModuleProtocol, layer_name: str, op_to_io_tensor_map: Dict,
activation_encodings_onnx: Dict, activation_encodings_torch: Dict,
layers_to_onnx_op_names: Dict[str, str],
tensor_to_quantizer_map: Dict):
@@ -1393,7 +1393,7 @@ def _update_encoding_dict_for_recurrent_layers(layer: torch.nn.Module, layer_nam
def _get_qc_quantized_layers(model) -> List[Tuple[str, QcQuantizeWrapper]]:
quantized_layers = []
for name, module in model.named_modules():
-if isinstance(module, (QcQuantizeRecurrent, LazyQuantizeWrapper, ExportableQuantModule)):
+if isinstance(module, (QcQuantizeRecurrent, LazyQuantizeWrapper, QuantizedModuleProtocol)):
quantized_layers.append((name, module))
return quantized_layers

@@ -1500,7 +1500,7 @@ def _remove_quantization_wrappers(cls, starting_module, list_of_modules_to_exclu
# If modules is in the exclude list, remove the wrapper
if module_ref in list_of_modules_to_exclude:

-if isinstance(module_ref, ExportableQuantModule):
+if isinstance(module_ref, QuantizedModuleProtocol):
# Remove the wrapper, gets auto-deleted
# pylint: disable=protected-access
setattr(starting_module, module_name, module_ref.get_original_module())
@@ -1568,12 +1568,12 @@ def _update_parameters_by_attr(module: torch.nn.Module):

def _get_leaf_module_to_name_map(self):
"""
-Returns a mapping from leaf modules to module name, where any ExportableQuantModule is considered a leaf module,
+Returns a mapping from leaf modules to module name, where any QuantizedModuleProtocol is considered a leaf module,
and is therefore not further recursed (since we do not want to retrieve all internal quantizers/modules).
"""
def recursively_populate_map(starting_module, module_map, start_str):
for name, module in starting_module.named_children():
-if isinstance(module, ExportableQuantModule) or utils.is_leaf_module(module):
+if isinstance(module, QuantizedModuleProtocol) or utils.is_leaf_module(module):
module_map[module] = start_str + name
else:
recursively_populate_map(module, module_map, start_str + name + ".")
Expand All @@ -1590,7 +1590,7 @@ def inputs_hook(module_ref, inputs, _):
hooks[module_ref].remove()
del hooks[module_ref]
module_name = module_to_name_map[module_ref]
-if isinstance(module_ref, ExportableQuantModule):
+if isinstance(module_ref, QuantizedModuleProtocol):
module_ref = module_ref.get_original_module()
marker_layer = torch.jit.trace(CustomMarker(module_ref, module_name, 'True'),
inputs)
@@ -1780,7 +1780,7 @@ def _set_param_encodings(self,
requires_grad: Optional[bool],
allow_overwrite: bool):
for name, quant_module in self.model.named_modules():
-if isinstance(quant_module, ExportableQuantModule):
+if isinstance(quant_module, QuantizedModuleProtocol):
param_encoding = {
param_name: encoding_dict[f'{name}.{param_name}']
for param_name, _ in quant_module.param_quantizers.items()
@@ -1803,7 +1803,7 @@ def _set_activation_encodings(self,
requires_grad: Optional[bool],
allow_overwrite: bool):
for module_name, module in self.model.named_modules():
-if not isinstance(module, ExportableQuantModule):
+if not isinstance(module, QuantizedModuleProtocol):
continue

try:
@@ -1858,7 +1858,7 @@ def named_qmodules(self):
"""Generator that yields all quantized modules in the model and their names
"""
for name, module in self.model.named_modules():
-if isinstance(module, (QcQuantizeRecurrent, LazyQuantizeWrapper, ExportableQuantModule)):
+if isinstance(module, (QcQuantizeRecurrent, LazyQuantizeWrapper, QuantizedModuleProtocol)):
yield name, module

def qmodules(self):
@@ -1883,7 +1883,7 @@ def run_modules_for_traced_custom_marker(self, module_list: List[torch.nn.Module
# Only perform init and trace if the given module is a leaf module, and we have not recorded it before
if module in module_to_name_map and module_to_name_map[module] not in self._module_marker_map:
name = module_to_name_map[module]
-module = module.get_original_module() if isinstance(module, ExportableQuantModule) else module
+module = module.get_original_module() if isinstance(module, QuantizedModuleProtocol) else module
with utils.in_eval_mode(module), torch.no_grad():
marker_layer = torch.jit.trace(CustomMarker(module, name, True), dummy_input)
self._module_marker_map[name] = marker_layer
@@ -2299,18 +2299,18 @@ def load_encodings_to_sim(quant_sim_model: QuantizationSimModel, pytorch_encodin
quant_sim_model.replace_wrappers_for_quantize_dequantize()


-def has_valid_encodings(qc_quantize_op: ExportableQuantModule) -> bool:
+def has_valid_encodings(qc_quantize_op: QuantizedModuleProtocol) -> bool:
"""
Utility for determining whether a given qc_quantize_op has any valid encodings.
:param qc_quantize_op: Qc quantize op to evaluate
:return: True if any input, param, or output quantizers have valid encodings, False otherwise
"""
-if not isinstance(qc_quantize_op, (ExportableQuantModule, QcQuantizeRecurrent)):
+if not isinstance(qc_quantize_op, (QuantizedModuleProtocol, QcQuantizeRecurrent)):
logger.error("has_valid_encodings only supported for QcQuantizeWrapper and QcQuantizeRecurrent "
"modules")
-assert isinstance(qc_quantize_op, (ExportableQuantModule, QcQuantizeRecurrent))
-if isinstance(qc_quantize_op, ExportableQuantModule):
+assert isinstance(qc_quantize_op, (QuantizedModuleProtocol, QcQuantizeRecurrent))
+if isinstance(qc_quantize_op, QuantizedModuleProtocol):
all_encodings = qc_quantize_op.export_output_encodings() + qc_quantize_op.export_input_encodings() + \
list(qc_quantize_op.export_param_encodings().values())
return any([encoding is not None for encoding in all_encodings]) # pylint: disable=consider-using-generator,use-a-generator
