diff --git a/direct/nn/transformers/transformers.py b/direct/nn/transformers/transformers.py index c42af1ad..4db948d8 100644 --- a/direct/nn/transformers/transformers.py +++ b/direct/nn/transformers/transformers.py @@ -80,7 +80,6 @@ class ImageDomainMRIUFormer(nn.Module): Whether to apply normalization before and denormalization after the forward pass. Default: True. """ - # pylint: disable=too-many-arguments def __init__( self, forward_operator: Callable[[tuple[Any, ...]], torch.Tensor], @@ -163,8 +162,9 @@ def __init__( Whether to apply normalization before and denormalization after the forward pass. Default: True. **kwargs: Other keyword arguments to pass to the parent constructor. """ + # pylint: disable=too-many-arguments super().__init__() - for extra_key in kwargs.keys(): + for extra_key in kwargs: if extra_key not in [ "model_name", ]: @@ -332,7 +332,7 @@ def __init__( Whether to normalize the input tensor. Default: True. """ super().__init__() - for extra_key in kwargs.keys(): + for extra_key in kwargs: if extra_key not in [ "model_name", ]: @@ -490,7 +490,7 @@ def __init__( Whether to normalize the input tensor. Default: True. """ super().__init__() - for extra_key in kwargs.keys(): + for extra_key in kwargs: if extra_key not in [ "model_name", ]: @@ -654,7 +654,7 @@ def __init__( Whether to compute the output per coil. """ super().__init__() - for extra_key in kwargs.keys(): + for extra_key in kwargs: if extra_key not in [ "model_name", ]: @@ -849,7 +849,7 @@ def __init__( Whether to compute the output per coil. """ super().__init__() - for extra_key in kwargs.keys(): + for extra_key in kwargs: if extra_key not in [ "model_name", ]: diff --git a/direct/nn/transformers/uformer.py b/direct/nn/transformers/uformer.py index 2693544f..473f6599 100644 --- a/direct/nn/transformers/uformer.py +++ b/direct/nn/transformers/uformer.py @@ -420,6 +420,7 @@ def __init__( proj_drop : float Dropout rate for the output of the last linear projection layer. """ + # pylint: disable=too-many-locals super().__init__() self.dim = dim self.win_size = win_size # Wh, Ww @@ -1277,6 +1278,7 @@ def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch ------- torch.Tensor """ + # pylint: disable=too-many-locals B, L, C = x.shape H = int(math.sqrt(L)) W = int(math.sqrt(L)) diff --git a/direct/nn/transformers/vit.py b/direct/nn/transformers/vit.py index ebbc2ce7..dd3daff1 100644 --- a/direct/nn/transformers/vit.py +++ b/direct/nn/transformers/vit.py @@ -827,6 +827,7 @@ def __init__( normalized : bool Whether to normalize the input tensor. Default: True. """ + # pylint: disable=too-many-locals super().__init__() self.dimensionality = dimensionality