diff --git a/direct/nn/transformers/transformers.py b/direct/nn/transformers/transformers.py index 8faec118..c42af1ad 100644 --- a/direct/nn/transformers/transformers.py +++ b/direct/nn/transformers/transformers.py @@ -164,6 +164,11 @@ def __init__( **kwargs: Other keyword arguments to pass to the parent constructor. """ super().__init__() + for extra_key in kwargs.keys(): + if extra_key not in [ + "model_name", + ]: + raise ValueError(f"{type(self).__name__} got key `{extra_key}` which is not supported.") self.uformer = UFormerModel( patch_size=patch_size, in_channels=COMPLEX_SIZE, @@ -327,6 +332,11 @@ def __init__( Whether to normalize the input tensor. Default: True. """ super().__init__() + for extra_key in kwargs.keys(): + if extra_key not in [ + "model_name", + ]: + raise ValueError(f"{type(self).__name__} got key `{extra_key}` which is not supported.") self.transformer = VisionTransformer2D( average_img_size=average_size, patch_size=patch_size, @@ -374,7 +384,7 @@ def forward(self, masked_kspace: torch.Tensor, sensitivity_map: torch.Tensor) -> return out -class ImageDomainMRIViT3D(VisionTransformer3D): +class ImageDomainMRIViT3D(nn.Module): """Vision Transformer for MRI reconstruction in 3D. Parameters @@ -480,6 +490,11 @@ def __init__( Whether to normalize the input tensor. Default: True. """ super().__init__() + for extra_key in kwargs.keys(): + if extra_key not in [ + "model_name", + ]: + raise ValueError(f"{type(self).__name__} got key `{extra_key}` which is not supported.") self.transformer = VisionTransformer3D( average_img_size=average_size, patch_size=patch_size, @@ -639,6 +654,11 @@ def __init__( Whether to compute the output per coil. """ super().__init__() + for extra_key in kwargs.keys(): + if extra_key not in [ + "model_name", + ]: + raise ValueError(f"{type(self).__name__} got key `{extra_key}` which is not supported.") self.transformer = VisionTransformer2D( average_img_size=average_size, patch_size=patch_size, @@ -829,6 +849,11 @@ def __init__( Whether to compute the output per coil. """ super().__init__() + for extra_key in kwargs.keys(): + if extra_key not in [ + "model_name", + ]: + raise ValueError(f"{type(self).__name__} got key `{extra_key}` which is not supported.") self.transformer = VisionTransformer3D( average_img_size=average_size, patch_size=patch_size, @@ -891,18 +916,18 @@ def forward( dim=self._coil_dim, ) return out - else: - # Create a single image from the coil data - sense_image = reduce_operator( - coil_data=self.backward_operator(masked_kspace, dim=self._spatial_dims), - sensitivity_map=sensitivity_map, - dim=self._coil_dim, - ) - # Trasnform the image to the k-space domain - inp = self.forward_operator(sense_image, dim=[d - 1 for d in self._spatial_dims]) - # Pass to the transformer - out = self.transformer(inp.permute(0, 4, 1, 2, 3)).permute(0, 2, 3, 4, 1).contiguous() + # Create a single image from the coil data + sense_image = reduce_operator( + coil_data=self.backward_operator(masked_kspace, dim=self._spatial_dims), + sensitivity_map=sensitivity_map, + dim=self._coil_dim, + ) + # Trasnform the image to the k-space domain + inp = self.forward_operator(sense_image, dim=[d - 1 for d in self._spatial_dims]) - out = self.backward_operator(out, dim=[d - 1 for d in self._spatial_dims]) - return out + # Pass to the transformer + out = self.transformer(inp.permute(0, 4, 1, 2, 3)).permute(0, 2, 3, 4, 1).contiguous() + + out = self.backward_operator(out, dim=[d - 1 for d in self._spatial_dims]) + return out diff --git a/direct/nn/transformers/uformer.py b/direct/nn/transformers/uformer.py index d24f3bb0..2693544f 100644 --- a/direct/nn/transformers/uformer.py +++ b/direct/nn/transformers/uformer.py @@ -1804,20 +1804,22 @@ def extra_repr(self) -> str: + f"token_mlp={self.mlp},win_size={self.win_size}" ) - def forward(self, input: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: """Performs forward pass of :class:`UFormer`. Parameters ---------- - input : torch.Tensor + x : torch.Tensor + Input tensor. mask : torch.Tensor, optional + Mask tensor. Default: None. Returns ------- torch.Tensor """ # Input Projection - output = self.input_proj(input) + output = self.input_proj(x) output = self.pos_drop(output) # Encoder @@ -1840,8 +1842,8 @@ def forward(self, input: torch.Tensor, mask: Optional[torch.Tensor] = None) -> t # Output Projection output = self.output_proj(output) if self.in_channels != self.out_channels: - input = self.conv_out(input) - return input + output + x = self.conv_out(input) + return x + output class UFormerModel(nn.Module):