Commit

Codacy/pylint fixes
georgeyiasemis committed Jan 7, 2025
1 parent e93f3df commit 3d6272a
Showing 4 changed files with 32 additions and 23 deletions.
27 changes: 14 additions & 13 deletions direct/nn/transformers/transformers.py
@@ -80,6 +80,7 @@ class ImageDomainMRIUFormer(nn.Module):
         Whether to apply normalization before and denormalization after the forward pass. Default: True.
     """
 
+    # pylint: disable=too-many-arguments
     def __init__(
         self,
         forward_operator: Callable[[tuple[Any, ...]], torch.Tensor],
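Note: the added pragma suppresses pylint's R0913 (too-many-arguments) for this constructor only, the usual remedy when a wide signature is intentional, rather than relaxing the limit repository-wide.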
@@ -700,21 +701,21 @@ def forward(
                 dim=self._coil_dim,
             )
             return out
-        else:
-            # Create a single image from the coil data
-            sense_image = reduce_operator(
-                coil_data=self.backward_operator(masked_kspace, dim=self._spatial_dims),
-                sensitivity_map=sensitivity_map,
-                dim=self._coil_dim,
-            )
-            # Transform the image to the k-space domain
-            inp = self.forward_operator(sense_image, dim=[d - 1 for d in self._spatial_dims])
-
-            # Pass to the transformer
-            out = self.transformer(inp.permute(0, 3, 1, 2)).permute(0, 2, 3, 1).contiguous()
-
-            out = self.backward_operator(out, dim=[d - 1 for d in self._spatial_dims])
-            return out
+        # Otherwise, create a single image from the coil data
+        sense_image = reduce_operator(
+            coil_data=self.backward_operator(masked_kspace, dim=self._spatial_dims),
+            sensitivity_map=sensitivity_map,
+            dim=self._coil_dim,
+        )
+        # Transform the image to the k-space domain
+        inp = self.forward_operator(sense_image, dim=[d - 1 for d in self._spatial_dims])
+
+        # Pass to the transformer
+        out = self.transformer(inp.permute(0, 3, 1, 2)).permute(0, 2, 3, 1).contiguous()
+
+        out = self.backward_operator(out, dim=[d - 1 for d in self._spatial_dims])
+        return out
 
 
 class KSpaceDomainMRIViT3D(nn.Module):
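The second hunk applies pylint's no-else-return (R1705) fix: the if branch already ends in return, so the else: wrapper is redundant and its body is de-indented one level. A minimal, hypothetical sketch of the pattern (illustrative names only, not repository code):

def predict(per_coil: bool) -> str:
    if per_coil:
        # This branch returns, so no `else:` is required afterwards.
        return "per-coil k-space prediction"
    # Formerly wrapped in `else:`; de-indented exactly as in the diff above.
    return "SENSE-combined prediction"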
17 changes: 12 additions & 5 deletions direct/nn/transformers/uformer.py
@@ -103,6 +103,8 @@ class SepConv2d(torch.nn.Module):
         Spacing between kernel elements. Default: 1.
     act_layer : torch.nn.Module
         Activation layer applied after depthwise convolution. Default: nn.ReLU.
+    bias : bool
+        Whether to include a bias term. Default: False.
     """
 
     def __init__(
@@ -114,6 +116,7 @@ def __init__(
         padding: int | tuple[int, int] = 0,
         dilation: int | tuple[int, int] = 1,
         act_layer: nn.Module = nn.ReLU,
+        bias: bool = False,
     ) -> None:
         """Inits :class:`SepConv2d`.
@@ -133,6 +136,8 @@ def __init__(
             Spacing between kernel elements. Default: 1.
         act_layer : torch.nn.Module
             Activation layer applied after depthwise convolution. Default: nn.ReLU.
+        bias : bool
+            Whether to include a bias term. Default: False.
         """
         super().__init__()
         self.depthwise = torch.nn.Conv2d(
@@ -143,8 +148,9 @@ def __init__(
             padding=padding,
             dilation=dilation,
             groups=in_channels,
+            bias=bias,
         )
-        self.pointwise = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1)
+        self.pointwise = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias)
         self.act_layer = act_layer() if act_layer is not None else nn.Identity()
         self.in_channels = in_channels
         self.out_channels = out_channels
@@ -233,13 +239,13 @@ def __init__(
         self.heads = heads
         pad = (kernel_size - q_stride) // 2
         self.to_q = SepConv2d(
-            in_channels=dim, out_channels=inner_dim, kernel_size=kernel_size, stride=q_stride, padding=pad
+            in_channels=dim, out_channels=inner_dim, kernel_size=kernel_size, stride=q_stride, padding=pad, bias=bias
         )
         self.to_k = SepConv2d(
-            in_channels=dim, out_channels=inner_dim, kernel_size=kernel_size, stride=k_stride, padding=pad
+            in_channels=dim, out_channels=inner_dim, kernel_size=kernel_size, stride=k_stride, padding=pad, bias=bias
         )
         self.to_v = SepConv2d(
-            in_channels=dim, out_channels=inner_dim, kernel_size=kernel_size, stride=v_stride, padding=pad
+            in_channels=dim, out_channels=inner_dim, kernel_size=kernel_size, stride=v_stride, padding=pad, bias=bias
         )
 
     def forward(
@@ -1253,7 +1259,8 @@ def with_pos_embed(self, tensor: torch.Tensor, pos: Optional[torch.Tensor] = None
     def extra_repr(self) -> str:
         return (
             f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
-            f"win_size={self.win_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio},modulator={self.modulator}"
+            f"win_size={self.win_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}, "
+            f"modulator={self.modulator}"
         )
 
     def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
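For context, SepConv2d is a depthwise-separable convolution: a per-channel (depthwise) convolution followed by a 1x1 (pointwise) convolution, and the commit threads the new bias flag through both stages. A self-contained sketch of that structure, written to mirror the hunks above but not copied from the repository:

import torch
from torch import nn


class SeparableConv2d(nn.Module):
    """Depthwise conv (groups=in_channels) followed by a 1x1 pointwise conv."""

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, bias: bool = False) -> None:
        super().__init__()
        # groups=in_channels gives each input channel its own spatial kernel.
        self.depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size,
            padding=kernel_size // 2,
            groups=in_channels,
            bias=bias,
        )
        # The 1x1 convolution mixes channels; the same bias flag is forwarded here.
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.pointwise(self.depthwise(x))


x = torch.randn(1, 8, 32, 32)
print(SeparableConv2d(8, 16)(x).shape)  # torch.Size([1, 16, 32, 32])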
7 changes: 4 additions & 3 deletions direct/nn/transformers/utils.py
@@ -8,8 +8,8 @@
 
 import torch
 import torch.nn.functional as F
-import torch.nn.init as init
 from torch import nn
+from torch.nn import init
 
 __all__ = ["init_weights", "norm", "pad_to_divisible", "pad_to_square", "unnorm", "unpad_to_original", "DropoutPath"]

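The import reshuffle above is pylint's consider-using-from-import (R0402) fix: from torch.nn import init binds the same submodule as import torch.nn.init as init, just in the preferred form.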
@@ -27,7 +27,8 @@ def pad_to_divisible(x: torch.Tensor, pad_size: tuple[int, ...]) -> tuple[torch.
     Returns
     -------
     tuple
-        Containing the padded tensor and a tuple of tuples indicating the number of pixels padded in each spatial dimension.
+        Containing the padded tensor and a tuple of tuples indicating the number of pixels padded in
+        each spatial dimension.
     """
     pads = []
     for dim, p_dim in zip(x.shape[-len(pad_size) :], pad_size):
@@ -200,7 +201,7 @@ def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
             Whether to scale the remaining activations by 1 / (1 - drop_prob) to maintain the expected value of
             the activations. Default: True.
         """
-        super(DropoutPath, self).__init__()
+        super().__init__()
         self.drop_prob = drop_prob
         self.scale_by_keep = scale_by_keep
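The super(DropoutPath, self).__init__() to super().__init__() change is pylint's super-with-arguments (R1725) fix; in Python 3 the zero-argument form is equivalent. DropoutPath itself implements stochastic depth (randomly dropping a whole residual branch per sample). Its forward pass is not shown in this diff, but assuming it follows the common timm-style drop_path implementation, it looks roughly like this sketch; the repository's version may differ in details:

import torch
from torch import nn


class DropPathSketch(nn.Module):
    """Per-sample stochastic depth: zero an entire residual branch with probability drop_prob."""

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True) -> None:
        super().__init__()  # zero-argument super(), as the commit enforces
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.drop_prob == 0.0 or not self.training:
            return x
        keep_prob = 1.0 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over all remaining dimensions.
        mask = x.new_empty((x.shape[0],) + (1,) * (x.ndim - 1)).bernoulli_(keep_prob)
        if self.scale_by_keep:
            mask = mask / keep_prob  # keeps the expected activation value unchanged
        return x * mask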
4 changes: 2 additions & 2 deletions direct/nn/transformers/vit.py
@@ -25,8 +25,8 @@
 import numpy as np
 import torch
 import torch.nn.functional as F
-import torch.nn.init as init
 from torch import nn
+from torch.nn import init
 
 from direct.constants import COMPLEX_SIZE
 from direct.nn.transformers.utils import DropoutPath, init_weights, norm, pad_to_divisible, unnorm, unpad_to_original
@@ -906,7 +906,7 @@ def __init__(
 
         self.norm = nn.LayerNorm(embedding_dim)
         # head
-        self.feature_info = [dict(num_chs=embedding_dim, reduction=0, module="head")]
+        self.feature_info = [{"num_chs": embedding_dim, "reduction": 0, "module": "head"}]
         self.head = nn.Linear(self.num_features, self.out_channels * np.prod(self.patch_size))
 
         self.head.apply(init_weights)
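The feature_info change addresses pylint's use-dict-literal (R1735): a {...} literal builds the same mapping as dict(...) without the global name lookup and call overhead. A quick illustration:

import timeit

as_call = dict(num_chs=64, reduction=0, module="head")
as_literal = {"num_chs": 64, "reduction": 0, "module": "head"}
assert as_call == as_literal  # identical contents

# The literal is typically measurably faster to construct:
print(timeit.timeit("dict(a=1, b=2)", number=1_000_000))
print(timeit.timeit("{'a': 1, 'b': 2}", number=1_000_000))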
