From ef829a95730d696b60ebef41212385b238df7f7e Mon Sep 17 00:00:00 2001
From: Kyunggeun Lee
Date: Mon, 11 Nov 2024 21:45:56 -0800
Subject: [PATCH] Fix pylint error (#3473)

Signed-off-by: Kyunggeun Lee
---
 .../mixed_precision/manual_mixed_precision.py |  2 +-
 .../aimet_torch/v2/mixed_precision/utils.py   | 18 ++++++++++--------
 .../affine/backends/torch_builtins.py         |  6 +++---
 3 files changed, 14 insertions(+), 12 deletions(-)

diff --git a/TrainingExtensions/torch/src/python/aimet_torch/v2/mixed_precision/manual_mixed_precision.py b/TrainingExtensions/torch/src/python/aimet_torch/v2/mixed_precision/manual_mixed_precision.py
index b8fd748e8f2..f89bb3ec215 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/v2/mixed_precision/manual_mixed_precision.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/v2/mixed_precision/manual_mixed_precision.py
@@ -109,7 +109,7 @@ def set_precision(self, arg: Union[torch.nn.Module, Type[torch.nn.Module]],
 
         Examples: TODO
         """
-
+        # pylint: disable=too-many-branches
         if activation:
             if isinstance(activation, List):
                 for act in activation:
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/v2/mixed_precision/utils.py b/TrainingExtensions/torch/src/python/aimet_torch/v2/mixed_precision/utils.py
index aeff5bb8692..280fe6449ef 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/v2/mixed_precision/utils.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/v2/mixed_precision/utils.py
@@ -37,9 +37,11 @@
 # =============================================================================
 """Utilities to achieve mixed precision"""
 
-from dataclasses import dataclass, field
+# pylint: disable=logging-fstring-interpolation
+
+from dataclasses import dataclass
 from enum import Enum
-from typing import Dict, Type, List, TypeAlias, Literal, Tuple, Optional, Union, Generator
+from typing import Dict, Type, List, TypeAlias, Literal, Optional, Union, Generator
 import functools
 
 import torch
@@ -66,10 +68,9 @@ class Precision:
     def __lt__(self, other):
         if self == other:
             return False
-        elif self.bitwidth != other.bitwidth:
+        if self.bitwidth != other.bitwidth:
             return self.bitwidth < other.bitwidth
-        else:
-            return self.data_type == QuantizationDataType.int and other.data_type != QuantizationDataType.int
+        return self.data_type == QuantizationDataType.int and other.data_type != QuantizationDataType.int
 
 
 TranslateUserDtypes = {
@@ -226,6 +227,7 @@ def create_mp_request(torch_module: BaseQuantizationMixin, module_name: str, idx
                 raise RuntimeError(f"Unsupported request type {user_request.request_type} encountered")
         return mp_requests
 
+    # pylint: disable=unused-argument, no-self-use
     def _apply_backend_awareness(self, mp_requests: Dict, config: str = "", strict: bool = True) -> Dict:
         """
         Apply backend awareness to the requests from the user
@@ -287,7 +289,7 @@ def _get_module_from_cg_op(self, cg_op: CG_Op) -> Optional[torch.nn.Module]:
         if module is None:
             return None
 
-        fully_qualified_name = self._sim.connected_graph._module_to_name[module]
+        fully_qualified_name = self._sim.connected_graph._module_to_name[module] # pylint: disable=protected-access
         _, name = fully_qualified_name.split('.', maxsplit=1)
         quant_module = _rgetattr(self._sim.model, name)
         return quant_module
@@ -451,7 +453,7 @@ def _propagate_request_upstream_helper(module):
             parent_module = self._get_parent_module_at_input_idx(module, in_idx)
             if parent_module is None:
                 logger.warning(f"Warning: unable to propagate request at {module} upward. "
-                               f"Parent module could not be found.")
+                               "Parent module could not be found.")
                 continue
 
             # TODO: remove this once ops with multiple outputs are supported
@@ -547,7 +549,7 @@ def _apply_requests_to_sim(self, mp_requests: Dict):
                                           request.output_candidates[idx])
 
     def apply(self, user_requests: Dict[int, UserRequest], config: str = "", strict: bool = True,
-              log_file: str = './mmp_log.txt'):
+              log_file: str = './mmp_log.txt'): # pylint: disable=unused-argument
         """
         Apply the mp settings specified through the set_precision/set_model_input_precision/set_model_output_precision
         calls to the QuantSim object
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/v2/quantization/affine/backends/torch_builtins.py b/TrainingExtensions/torch/src/python/aimet_torch/v2/quantization/affine/backends/torch_builtins.py
index 503a04c4350..1fb9aa61505 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/v2/quantization/affine/backends/torch_builtins.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/v2/quantization/affine/backends/torch_builtins.py
@@ -108,9 +108,9 @@ def _validate_arguments(tensor: torch.Tensor, scale: torch.Tensor,
         msg = f"Scale of shape {scale.shape} cannot be expanded like input tensor of shape {tensor.shape}. "
         # Additional message if the tensor is empty
         if tensor.numel() == 0:
-            msg += (f"Detected that the tensor is empty, which may be caused by the following reasons: "
-                    f"1. The input tensor is incorrect. "
-                    f"2. Improper use of model inference without initializing DeepSpeed after offloading parameters.")
+            msg += ("Detected that the tensor is empty, which may be caused by the following reasons: "
+                    "1. The input tensor is incorrect. "
+                    "2. Improper use of model inference without initializing DeepSpeed after offloading parameters.")
         raise RuntimeError(msg)
 
     if qmin is not None and qmax is not None: