Rename ExportableQuantModule to QuantizedModuleProtocol #3455

Merged: 1 commit, Nov 1, 2024
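The change is a pure rename: every import and `isinstance` check against `ExportableQuantModule` in the diff below becomes `QuantizedModuleProtocol`, with no behavioral difference. As a minimal sketch of what this means for downstream code (the import path is taken from the diff; `sim` is assumed to be an existing `QuantizationSimModel`):

```python
# Before this PR:
# from aimet_torch.v1.quantsim import ExportableQuantModule
# After this PR:
from aimet_torch.v1.quantsim import QuantizedModuleProtocol

def quantized_module_names(sim):
    """List the names of quantized wrapper modules in a QuantizationSimModel."""
    # QuantizedModuleProtocol is a runtime-checkable Protocol, so isinstance() works.
    return [name for name, module in sim.model.named_modules()
            if isinstance(module, QuantizedModuleProtocol)]
```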
@@ -52,7 +52,7 @@
from aimet_torch.gptvq.defs import GPTVQSupportedModules, GPTVQParameters
from aimet_torch.gptvq.gptvq_optimizer import GPTVQOptimizer
from aimet_torch.gptvq.utils import get_module_name_to_hessian_tensor
-from aimet_torch.v1.quantsim import ExportableQuantModule
+from aimet_torch.v1.quantsim import QuantizedModuleProtocol
from aimet_torch.save_utils import SaveUtils
from aimet_torch.utils import get_named_module
from aimet_torch.v2.nn import BaseQuantizationMixin
@@ -391,7 +391,7 @@ def _export_encodings_to_json(cls,
json.dump(encoding, encoding_fp, sort_keys=True, indent=4)

@staticmethod
-def _update_param_encodings_dict(quant_module: ExportableQuantModule,
+def _update_param_encodings_dict(quant_module: QuantizedModuleProtocol,
name: str,
param_encodings: Dict,
rows_per_block: int):
@@ -50,7 +50,7 @@
from aimet_common.utils import AimetLogger
from aimet_common.layer_output_utils import SaveInputOutput, save_layer_output_names

-from aimet_torch.v1.quantsim import ExportableQuantModule, QuantizationSimModel
+from aimet_torch.v1.quantsim import QuantizedModuleProtocol, QuantizationSimModel
from aimet_torch import utils
from aimet_torch import torchscript_utils
from aimet_torch.onnx_utils import OnnxSaver, OnnxExportApiArgs
@@ -171,7 +171,7 @@ def __init__(self, model: torch.nn.Module, dir_path: str, naming_scheme: NamingS
self.module_to_name_dict = utils.get_module_to_name_dict(model=model, prefix='')

# Check whether the given model is quantsim model
-self.is_quantsim_model = any(isinstance(module, (ExportableQuantModule, QcQuantizeRecurrent)) for module in model.modules())
+self.is_quantsim_model = any(isinstance(module, (QuantizedModuleProtocol, QcQuantizeRecurrent)) for module in model.modules())

# Obtain layer-name to layer-output name mapping
self.layer_name_to_layer_output_dict = {}
@@ -206,7 +206,7 @@ def get_outputs(self, input_batch: Union[torch.Tensor, List[torch.Tensor], Tuple
if self.is_quantsim_model:
# Apply record-output hook to QuantizeWrapper modules (one node above leaf node in model graph)
utils.run_hook_for_layers_with_given_input(self.model, input_batch, self.record_outputs,
-module_type_for_attaching_hook=(ExportableQuantModule, QcQuantizeRecurrent),
+module_type_for_attaching_hook=(QuantizedModuleProtocol, QcQuantizeRecurrent),
leaf_node_only=False)
else:
# Apply record-output hook to Original modules (leaf node in model graph)
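For context, the record-output mechanism above rests on ordinary PyTorch forward hooks. The snippet below is a minimal, generic sketch of that idea using `register_forward_hook` directly rather than the `aimet_torch.utils` helper; the function and variable names are illustrative only.

```python
import torch

def capture_outputs(model: torch.nn.Module, inputs, hook_types) -> dict:
    """Run one forward pass and record the output of every module of the given types."""
    name_of = {module: name for name, module in model.named_modules()}
    outputs, handles = {}, []

    def make_hook(module):
        def hook(_module, _inputs, output):
            outputs[name_of[module]] = output
        return hook

    for module in model.modules():
        if isinstance(module, hook_types):
            handles.append(module.register_forward_hook(make_hook(module)))
    try:
        with torch.no_grad():
            model(inputs)
    finally:
        for handle in handles:
            handle.remove()
    return outputs

# Example: record outputs of all Linear layers in a toy model.
toy = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))
captured = capture_outputs(toy, torch.randn(1, 4), (torch.nn.Linear,))
```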
4 changes: 2 additions & 2 deletions TrainingExtensions/torch/src/python/aimet_torch/peft.py
@@ -54,7 +54,7 @@
from aimet_torch.v1.nn.modules.custom import Add, Multiply
from aimet_torch.v2.quantsim import QuantizationSimModel
from aimet_torch.v2.quantization.affine import QuantizeDequantize
-from aimet_torch.v1.quantsim import ExportableQuantModule
+from aimet_torch.v1.quantsim import QuantizedModuleProtocol
from aimet_torch.v2.nn import BaseQuantizationMixin


@@ -396,7 +396,7 @@ def export_adapter_weights(self, sim: QuantizationSimModel, path: str, filename_
tensors = {}

for module_name, module in sim.model.named_modules():
-if not isinstance(module, ExportableQuantModule):
+if not isinstance(module, QuantizedModuleProtocol):
continue
org_name = module_name
pt_name = self._get_module_name(module_name)
4 changes: 2 additions & 2 deletions TrainingExtensions/torch/src/python/aimet_torch/save_utils.py
@@ -37,7 +37,7 @@

""" Utilities to save a models and related parameters """

-from aimet_torch.v1.quantsim import ExportableQuantModule
+from aimet_torch.v1.quantsim import QuantizedModuleProtocol


class SaveUtils:
@@ -50,7 +50,7 @@ def remove_quantization_wrappers(module):
:param module: Model
"""
for module_name, module_ref in module.named_children():
-if isinstance(module_ref, ExportableQuantModule):
+if isinstance(module_ref, QuantizedModuleProtocol):
setattr(module, module_name, module_ref.get_original_module())
# recursively call children modules
else:
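The branch elided above recurses into child modules, per the inline comment. A self-contained sketch of the same wrapper-removal pattern, with illustrative names rather than the library source:

```python
import torch

def remove_wrappers(parent: torch.nn.Module, is_wrapper) -> None:
    """Replace each wrapper child with the module it wraps, depth-first."""
    for name, child in parent.named_children():
        if is_wrapper(child):
            # Swap the wrapper out for the original module it holds.
            setattr(parent, name, child.get_original_module())
        else:
            # Recurse into plain containers to reach nested wrappers.
            remove_wrappers(child, is_wrapper)
```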
@@ -54,7 +54,7 @@
from aimet_torch import utils
from aimet_torch.save_utils import SaveUtils
from aimet_torch.meta import connectedgraph_utils
-from aimet_torch.v1.quantsim import QuantizationSimModel, QcQuantizeWrapper, ExportableQuantModule
+from aimet_torch.v1.quantsim import QuantizationSimModel, QcQuantizeWrapper, QuantizedModuleProtocol
from aimet_torch.v1.qc_quantize_op import StaticGridQuantWrapper, QcQuantizeOpMode
from aimet_torch.v1.tensor_quantizer import TensorQuantizer
from aimet_torch.v1.adaround.adaround_wrapper import AdaroundWrapper
@@ -490,7 +490,7 @@ def _export_encodings_to_json(cls, path: str, filename_prefix: str, quant_sim: Q
param_encodings = {}

for name, quant_module in quant_sim.model.named_modules():
-if isinstance(quant_module, ExportableQuantModule) and \
+if isinstance(quant_module, QuantizedModuleProtocol) and \
isinstance(quant_module.get_original_module(), AdaroundSupportedModules):

if 'weight' in quant_module.param_quantizers:
@@ -505,7 +505,7 @@ def _export_encodings_to_json(cls, path: str, filename_prefix: str, quant_sim: Q
json.dump(encoding, encoding_fp, sort_keys=True, indent=4)

@classmethod
-def _update_param_encodings_dict(cls, quant_module: ExportableQuantModule, name: str, param_encodings: Dict):
+def _update_param_encodings_dict(cls, quant_module: QuantizedModuleProtocol, name: str, param_encodings: Dict):
"""
Add module's weight parameter encodings to dictionary to be used for exporting encodings
:param quant_module: quant module
@@ -560,7 +560,7 @@ def _override_param_bitwidth(model: torch.nn.Module, quant_sim: QuantizationSimM
# Create a mapping of QuantSim model's AdaRoundable module name and their module
name_to_module = {}
for q_name, q_module in quant_sim.model.named_modules():
-if isinstance(q_module, ExportableQuantModule):
+if isinstance(q_module, QuantizedModuleProtocol):
if isinstance(q_module.get_original_module(), AdaroundSupportedModules): # pylint: disable=protected-access
name_to_module[q_name] = q_module

52 changes: 26 additions & 26 deletions TrainingExtensions/torch/src/python/aimet_torch/v1/quantsim.py
@@ -130,7 +130,7 @@ def __init__(self,


@runtime_checkable
-class ExportableQuantModule(Protocol):
+class QuantizedModuleProtocol(Protocol):
"""
Defines the minimum interface requirements for exporting encodings from a module.
"""
@@ -210,7 +210,7 @@ def get_original_module(self) -> torch.nn.Module:
QcQuantizeWrapper,
QcQuantizeStandAloneBase,
QcQuantizeRecurrent,
-ExportableQuantModule,
+QuantizedModuleProtocol,
LazyQuantizeWrapper,
)

@@ -699,7 +699,7 @@ def get_activation_param_encodings(self):
param_encodings = OrderedDict()

for module_name, module in self.model.named_modules():
-if not isinstance(module, ExportableQuantModule):
+if not isinstance(module, QuantizedModuleProtocol):
continue

activation_encodings[module_name] = defaultdict(OrderedDict)
@@ -740,7 +740,7 @@ def exclude_layers_from_quantization(self, layers_to_exclude: List[torch.nn.Modu
quant_layers_to_exclude = []
quant_cls = (QcQuantizeRecurrent,
LazyQuantizeWrapper,
-ExportableQuantModule)
+QuantizedModuleProtocol)
for layer in layers_to_exclude:
for module in layer.modules():
if isinstance(module, quant_cls):
@@ -881,7 +881,7 @@ def _find_next_downstream_modules(op):


@staticmethod
-def _get_torch_encodings_for_missing_layers(layer: ExportableQuantModule, layer_name: str,# pylint: disable=too-many-branches
+def _get_torch_encodings_for_missing_layers(layer: QuantizedModuleProtocol, layer_name: str,# pylint: disable=too-many-branches
missing_activation_encodings_torch: Dict,
missing_param_encodings: Dict,
valid_param_set: set):
@@ -893,7 +893,7 @@ def _get_torch_encodings_for_missing_layers(layer: ExportableQuantModule, layer_
:param missing_param_encodings: dictionary of param encodings
:param valid_param_set: a set of valid param input names in model
"""
-if isinstance(layer, ExportableQuantModule):
+if isinstance(layer, QuantizedModuleProtocol):
# --------------------------------------
# Update encodings for Input activations
# --------------------------------------
@@ -969,14 +969,14 @@ def _export_encodings_to_files(sim_model: torch.nn.Module, path: str, filename_p
tensor_to_quantizer_map = {}

for layer_name, layer in sim_model.named_modules():
-if not isinstance(layer, (ExportableQuantModule, QcQuantizeRecurrent)):
+if not isinstance(layer, (QuantizedModuleProtocol, QcQuantizeRecurrent)):
continue
if not has_valid_encodings(layer):
continue
# TODO: specifically call out dropout layers here since they are specifically switched out during export.
# These ops should eventually be reworked as part of math invariant ops to ignore quantization altogether.
# pylint: disable=protected-access
-if isinstance(layer, ExportableQuantModule) and isinstance(layer.get_original_module(), utils.DROPOUT_TYPES):
+if isinstance(layer, QuantizedModuleProtocol) and isinstance(layer.get_original_module(), utils.DROPOUT_TYPES):
continue

if layer_name not in layers_to_onnx_op_names.keys():
@@ -1042,7 +1042,7 @@ def _export_encodings_to_files(sim_model: torch.nn.Module, path: str, filename_p
save_json_yaml(encoding_file_path_pytorch, encodings_dict_pytorch)

@staticmethod
-def _update_param_encodings_dict_for_layer(layer: ExportableQuantModule, layer_name: str, param_encodings: Dict,
+def _update_param_encodings_dict_for_layer(layer: QuantizedModuleProtocol, layer_name: str, param_encodings: Dict,
valid_param_set: set, tensor_to_quantizer_map: Dict):
"""
:param layer: layer as torch.nn.Module
@@ -1062,7 +1062,7 @@ def _update_param_encodings_dict_for_layer(layer: ExportableQuantModule, layer_n
tensor_to_quantizer_map[param_name] = layer.param_quantizers[orig_param_name]

@staticmethod
-def _update_encoding_dicts_for_layer(layer: ExportableQuantModule, layer_name: str, activation_encodings_onnx: Dict,
+def _update_encoding_dicts_for_layer(layer: QuantizedModuleProtocol, layer_name: str, activation_encodings_onnx: Dict,
activation_encodings_torch: Dict, param_encodings: Dict,
op_to_io_tensor_map: Dict, valid_param_set: set, propagate_encodings: bool,
tensor_to_consumer_map: Dict[str, str],
@@ -1084,7 +1084,7 @@ def _update_encoding_dicts_for_layer(layer: ExportableQuantModule, layer_name: s
:param layers_to_onnx_op_names: Dictionary mapping PyTorch layer names to names of corresponding ONNX ops
"""

-if isinstance(layer, ExportableQuantModule):
+if isinstance(layer, QuantizedModuleProtocol):

# --------------------------------------
# Update encodings for Input activations
@@ -1167,7 +1167,7 @@ def find_op_names_for_layer(layer_name: str, op_to_io_tensor_map: Dict,
return end_op_names, op_names

@staticmethod
-def _update_encoding_dict_for_output_activations(layer: ExportableQuantModule, layer_name: str, op_to_io_tensor_map: Dict,
+def _update_encoding_dict_for_output_activations(layer: QuantizedModuleProtocol, layer_name: str, op_to_io_tensor_map: Dict,
activation_encodings_onnx: Dict, activation_encodings_torch: Dict,
propagate_encodings: bool, tensor_to_consumer_map: Dict[str, str],
layers_to_onnx_op_names: Dict[str, str],
@@ -1204,7 +1204,7 @@ def _update_encoding_dict_for_input_activations(layer: ExportableQuantModule, l


@staticmethod
-def _update_encoding_dict_for_input_activations(layer: ExportableQuantModule, layer_name: str, op_to_io_tensor_map: Dict,
+def _update_encoding_dict_for_input_activations(layer: QuantizedModuleProtocol, layer_name: str, op_to_io_tensor_map: Dict,
activation_encodings_onnx: Dict, activation_encodings_torch: Dict,
layers_to_onnx_op_names: Dict[str, str],
tensor_to_quantizer_map: Dict):
@@ -1393,7 +1393,7 @@ def _update_encoding_dict_for_recurrent_layers(layer: torch.nn.Module, layer_nam
def _get_qc_quantized_layers(model) -> List[Tuple[str, QcQuantizeWrapper]]:
quantized_layers = []
for name, module in model.named_modules():
-if isinstance(module, (QcQuantizeRecurrent, LazyQuantizeWrapper, ExportableQuantModule)):
+if isinstance(module, (QcQuantizeRecurrent, LazyQuantizeWrapper, QuantizedModuleProtocol)):
quantized_layers.append((name, module))
return quantized_layers

@@ -1500,7 +1500,7 @@ def _remove_quantization_wrappers(cls, starting_module, list_of_modules_to_exclu
# If modules is in the exclude list, remove the wrapper
if module_ref in list_of_modules_to_exclude:

-if isinstance(module_ref, ExportableQuantModule):
+if isinstance(module_ref, QuantizedModuleProtocol):
# Remove the wrapper, gets auto-deleted
# pylint: disable=protected-access
setattr(starting_module, module_name, module_ref.get_original_module())
@@ -1568,12 +1568,12 @@ def _update_parameters_by_attr(module: torch.nn.Module):

def _get_leaf_module_to_name_map(self):
"""
-Returns a mapping from leaf modules to module name, where any ExportableQuantModule is considered a leaf module,
+Returns a mapping from leaf modules to module name, where any QuantizedModuleProtocol is considered a leaf module,
and is therefore not further recursed (since we do not want to retrieve all internal quantizers/modules).
"""
def recursively_populate_map(starting_module, module_map, start_str):
for name, module in starting_module.named_children():
-if isinstance(module, ExportableQuantModule) or utils.is_leaf_module(module):
+if isinstance(module, QuantizedModuleProtocol) or utils.is_leaf_module(module):
module_map[module] = start_str + name
else:
recursively_populate_map(module, module_map, start_str + name + ".")
@@ -1590,7 +1590,7 @@ def inputs_hook(module_ref, inputs, _):
hooks[module_ref].remove()
del hooks[module_ref]
module_name = module_to_name_map[module_ref]
-if isinstance(module_ref, ExportableQuantModule):
+if isinstance(module_ref, QuantizedModuleProtocol):
module_ref = module_ref.get_original_module()
marker_layer = torch.jit.trace(CustomMarker(module_ref, module_name, 'True'),
inputs)
@@ -1780,7 +1780,7 @@ def _set_param_encodings(self,
requires_grad: Optional[bool],
allow_overwrite: bool):
for name, quant_module in self.model.named_modules():
-if isinstance(quant_module, ExportableQuantModule):
+if isinstance(quant_module, QuantizedModuleProtocol):
param_encoding = {
param_name: encoding_dict[f'{name}.{param_name}']
for param_name, _ in quant_module.param_quantizers.items()
@@ -1803,7 +1803,7 @@ def _set_activation_encodings(self,
requires_grad: Optional[bool],
allow_overwrite: bool):
for module_name, module in self.model.named_modules():
-if not isinstance(module, ExportableQuantModule):
+if not isinstance(module, QuantizedModuleProtocol):
continue

try:
@@ -1858,7 +1858,7 @@ def named_qmodules(self):
"""Generator that yields all quantized modules in the model and their names
"""
for name, module in self.model.named_modules():
-if isinstance(module, (QcQuantizeRecurrent, LazyQuantizeWrapper, ExportableQuantModule)):
+if isinstance(module, (QcQuantizeRecurrent, LazyQuantizeWrapper, QuantizedModuleProtocol)):
yield name, module

def qmodules(self):
@@ -1883,7 +1883,7 @@ def run_modules_for_traced_custom_marker(self, module_list: List[torch.nn.Module
# Only perform init and trace if the given module is a leaf module, and we have not recorded it before
if module in module_to_name_map and module_to_name_map[module] not in self._module_marker_map:
name = module_to_name_map[module]
-module = module.get_original_module() if isinstance(module, ExportableQuantModule) else module
+module = module.get_original_module() if isinstance(module, QuantizedModuleProtocol) else module
with utils.in_eval_mode(module), torch.no_grad():
marker_layer = torch.jit.trace(CustomMarker(module, name, True), dummy_input)
self._module_marker_map[name] = marker_layer
@@ -2299,18 +2299,18 @@ def load_encodings_to_sim(quant_sim_model: QuantizationSimModel, pytorch_encodin
quant_sim_model.replace_wrappers_for_quantize_dequantize()


-def has_valid_encodings(qc_quantize_op: ExportableQuantModule) -> bool:
+def has_valid_encodings(qc_quantize_op: QuantizedModuleProtocol) -> bool:
"""
Utility for determining whether a given qc_quantize_op has any valid encodings.

:param qc_quantize_op: Qc quantize op to evaluate
:return: True if any input, param, or output quantizers have valid encodings, False otherwise
"""
-if not isinstance(qc_quantize_op, (ExportableQuantModule, QcQuantizeRecurrent)):
+if not isinstance(qc_quantize_op, (QuantizedModuleProtocol, QcQuantizeRecurrent)):
logger.error("has_valid_encodings only supported for QcQuantizeWrapper and QcQuantizeRecurrent "
"modules")
-assert isinstance(qc_quantize_op, (ExportableQuantModule, QcQuantizeRecurrent))
-if isinstance(qc_quantize_op, ExportableQuantModule):
+assert isinstance(qc_quantize_op, (QuantizedModuleProtocol, QcQuantizeRecurrent))
+if isinstance(qc_quantize_op, QuantizedModuleProtocol):
all_encodings = qc_quantize_op.export_output_encodings() + qc_quantize_op.export_input_encodings() + \
list(qc_quantize_op.export_param_encodings().values())
return any([encoding is not None for encoding in all_encodings]) # pylint: disable=consider-using-generator,use-a-generator
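A hypothetical usage sketch, mirroring the export loop earlier in this file: skip any module that either does not satisfy the protocol or has no calibrated encodings (`sim` is assumed to be an existing `QuantizationSimModel`).

```python
def modules_with_encodings(sim):
    """Yield (name, module) pairs that would contribute encodings at export time."""
    for name, module in sim.model.named_modules():
        if not isinstance(module, QuantizedModuleProtocol):
            continue
        if not has_valid_encodings(module):
            continue  # all quantizers disabled or never calibrated
        yield name, module
```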