Skip to content

Commit

Permalink
Codacy/pylint fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
georgeyiasemis committed Jan 7, 2025
1 parent 3d6272a commit 5fe696b
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 19 deletions.
53 changes: 39 additions & 14 deletions direct/nn/transformers/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,11 @@ def __init__(
**kwargs: Other keyword arguments to pass to the parent constructor.
"""
super().__init__()
for extra_key in kwargs.keys():
if extra_key not in [
"model_name",
]:
raise ValueError(f"{type(self).__name__} got key `{extra_key}` which is not supported.")
self.uformer = UFormerModel(
patch_size=patch_size,
in_channels=COMPLEX_SIZE,
Expand Down Expand Up @@ -327,6 +332,11 @@ def __init__(
Whether to normalize the input tensor. Default: True.
"""
super().__init__()
for extra_key in kwargs.keys():
if extra_key not in [
"model_name",
]:
raise ValueError(f"{type(self).__name__} got key `{extra_key}` which is not supported.")
self.transformer = VisionTransformer2D(
average_img_size=average_size,
patch_size=patch_size,
Expand Down Expand Up @@ -374,7 +384,7 @@ def forward(self, masked_kspace: torch.Tensor, sensitivity_map: torch.Tensor) ->
return out


class ImageDomainMRIViT3D(VisionTransformer3D):
class ImageDomainMRIViT3D(nn.Module):
"""Vision Transformer for MRI reconstruction in 3D.
Parameters
Expand Down Expand Up @@ -480,6 +490,11 @@ def __init__(
Whether to normalize the input tensor. Default: True.
"""
super().__init__()
for extra_key in kwargs.keys():
if extra_key not in [
"model_name",
]:
raise ValueError(f"{type(self).__name__} got key `{extra_key}` which is not supported.")
self.transformer = VisionTransformer3D(
average_img_size=average_size,
patch_size=patch_size,
Expand Down Expand Up @@ -639,6 +654,11 @@ def __init__(
Whether to compute the output per coil.
"""
super().__init__()
for extra_key in kwargs.keys():
if extra_key not in [
"model_name",
]:
raise ValueError(f"{type(self).__name__} got key `{extra_key}` which is not supported.")
self.transformer = VisionTransformer2D(
average_img_size=average_size,
patch_size=patch_size,
Expand Down Expand Up @@ -829,6 +849,11 @@ def __init__(
Whether to compute the output per coil.
"""
super().__init__()
for extra_key in kwargs.keys():
if extra_key not in [
"model_name",
]:
raise ValueError(f"{type(self).__name__} got key `{extra_key}` which is not supported.")
self.transformer = VisionTransformer3D(
average_img_size=average_size,
patch_size=patch_size,
Expand Down Expand Up @@ -891,18 +916,18 @@ def forward(
dim=self._coil_dim,
)
return out
else:
# Create a single image from the coil data
sense_image = reduce_operator(
coil_data=self.backward_operator(masked_kspace, dim=self._spatial_dims),
sensitivity_map=sensitivity_map,
dim=self._coil_dim,
)
# Trasnform the image to the k-space domain
inp = self.forward_operator(sense_image, dim=[d - 1 for d in self._spatial_dims])

# Pass to the transformer
out = self.transformer(inp.permute(0, 4, 1, 2, 3)).permute(0, 2, 3, 4, 1).contiguous()
# Create a single image from the coil data
sense_image = reduce_operator(
coil_data=self.backward_operator(masked_kspace, dim=self._spatial_dims),
sensitivity_map=sensitivity_map,
dim=self._coil_dim,
)
# Trasnform the image to the k-space domain
inp = self.forward_operator(sense_image, dim=[d - 1 for d in self._spatial_dims])

out = self.backward_operator(out, dim=[d - 1 for d in self._spatial_dims])
return out
# Pass to the transformer
out = self.transformer(inp.permute(0, 4, 1, 2, 3)).permute(0, 2, 3, 4, 1).contiguous()

out = self.backward_operator(out, dim=[d - 1 for d in self._spatial_dims])
return out
12 changes: 7 additions & 5 deletions direct/nn/transformers/uformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1804,20 +1804,22 @@ def extra_repr(self) -> str:
+ f"token_mlp={self.mlp},win_size={self.win_size}"
)

def forward(self, input: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Performs forward pass of :class:`UFormer`.
Parameters
----------
input : torch.Tensor
x : torch.Tensor
Input tensor.
mask : torch.Tensor, optional
Mask tensor. Default: None.
Returns
-------
torch.Tensor
"""
# Input Projection
output = self.input_proj(input)
output = self.input_proj(x)
output = self.pos_drop(output)

# Encoder
Expand All @@ -1840,8 +1842,8 @@ def forward(self, input: torch.Tensor, mask: Optional[torch.Tensor] = None) -> t
# Output Projection
output = self.output_proj(output)
if self.in_channels != self.out_channels:
input = self.conv_out(input)
return input + output
x = self.conv_out(input)
return x + output


class UFormerModel(nn.Module):
Expand Down

0 comments on commit 5fe696b

Please sign in to comment.