diff --git a/direct/nn/transformers/__init__.py b/direct/nn/transformers/__init__.py new file mode 100644 index 00000000..c36ca0ba --- /dev/null +++ b/direct/nn/transformers/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) DIRECT Contributors + +"""DIRECT transformers models.""" diff --git a/direct/nn/transformers/config.py b/direct/nn/transformers/config.py new file mode 100644 index 00000000..521c70f7 --- /dev/null +++ b/direct/nn/transformers/config.py @@ -0,0 +1,121 @@ +# Copyright (c) DIRECT Contributors + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional + +from omegaconf import MISSING + +from direct.config.defaults import ModelConfig +from direct.constants import COMPLEX_SIZE +from direct.nn.transformers.uformer import AttentionTokenProjectionType, LeWinTransformerMLPTokenType + + +@dataclass +class UFormerModelConfig(ModelConfig): + in_channels: int = COMPLEX_SIZE + out_channels: Optional[int] = None + patch_size: int = 256 + embedding_dim: int = 32 + encoder_depths: tuple[int, ...] = (2, 2, 2, 2) + encoder_num_heads: tuple[int, ...] = (1, 2, 4, 8) + bottleneck_depth: int = 2 + bottleneck_num_heads: int = 16 + win_size: int = 8 + mlp_ratio: float = 4.0 + qkv_bias: bool = True + qk_scale: Optional[float] = None + drop_rate: float = 0.0 + attn_drop_rate: float = 0.0 + drop_path_rate: float = 0.1 + patch_norm: bool = True + token_projection: AttentionTokenProjectionType = AttentionTokenProjectionType.LINEAR + token_mlp: LeWinTransformerMLPTokenType = LeWinTransformerMLPTokenType.LEFF + shift_flag: bool = True + modulator: bool = False + cross_modulator: bool = False + normalized: bool = True + + +@dataclass +class ImageDomainMRIUFormerConfig(ModelConfig): + patch_size: int = 256 + embedding_dim: int = 32 + encoder_depths: tuple[int, ...] = (2, 2, 2, 2) + encoder_num_heads: tuple[int, ...] 
= (1, 2, 4, 8) + bottleneck_depth: int = 2 + bottleneck_num_heads: int = 16 + win_size: int = 8 + mlp_ratio: float = 4.0 + qkv_bias: bool = True + qk_scale: Optional[float] = None + drop_rate: float = 0.0 + attn_drop_rate: float = 0.0 + drop_path_rate: float = 0.1 + patch_norm: bool = True + token_projection: AttentionTokenProjectionType = AttentionTokenProjectionType.LINEAR + token_mlp: LeWinTransformerMLPTokenType = LeWinTransformerMLPTokenType.LEFF + shift_flag: bool = True + modulator: bool = False + cross_modulator: bool = False + normalized: bool = True + + +@dataclass +class MRIViTConfig(ModelConfig): + embedding_dim: int = 64 + depth: int = 8 + num_heads: int = 9 + mlp_ratio: float = 4.0 + qkv_bias: bool = False + qk_scale: Optional[float] = None + drop_rate: float = 0.0 + attn_drop_rate: float = 0.0 + dropout_path_rate: float = 0.0 + use_gpsa: bool = True + locality_strength: float = 1.0 + use_pos_embedding: bool = True + normalized: bool = True + + +@dataclass +class VisionTransformer2DConfig(MRIViTConfig): + in_channels: int = COMPLEX_SIZE + out_channels: Optional[int] = None + average_img_size: tuple[int, int] = MISSING + patch_size: tuple[int, int] = (16, 16) + + +@dataclass +class VisionTransformer3DConfig(MRIViTConfig): + in_channels: int = COMPLEX_SIZE + out_channels: Optional[int] = None + average_img_size: tuple[int, int, int] = MISSING + patch_size: tuple[int, int, int] = (16, 16, 16) + + +@dataclass +class ImageDomainMRIViT2DConfig(MRIViTConfig): + average_size: tuple[int, int] = (320, 320) + patch_size: tuple[int, int] = (16, 16) + + +@dataclass +class ImageDomainMRIViT3DConfig(MRIViTConfig): + average_size: tuple[int, int, int] = (320, 320, 320) + patch_size: tuple[int, int, int] = (16, 16, 16) + + +@dataclass +class KSpaceDomainMRIViT2DConfig(MRIViTConfig): + average_size: tuple[int, int] = (320, 320) + patch_size: tuple[int, int] = (16, 16) + compute_per_coil: bool = True + + +@dataclass +class KSpaceDomainMRIViT3DConfig(MRIViTConfig): + average_size: tuple[int, int, int] = (320, 320, 320) + patch_size: tuple[int, int, int] = (16, 16, 16) + compute_per_coil: bool = True diff --git a/direct/nn/transformers/transformers.py b/direct/nn/transformers/transformers.py new file mode 100644 index 00000000..7ea0d7a5 --- /dev/null +++ b/direct/nn/transformers/transformers.py @@ -0,0 +1,934 @@ +# Copyright (c) DIRECT Contributors + +# pylint: disable=too-many-arguments + +"""DIRECT Vision Transformer models for MRI reconstruction.""" + +from __future__ import annotations + +from typing import Any, Callable, Optional + +import torch +from torch import nn + +from direct.constants import COMPLEX_SIZE +from direct.data.transforms import reduce_operator +from direct.nn.transformers.uformer import AttentionTokenProjectionType, LeWinTransformerMLPTokenType, UFormerModel +from direct.nn.transformers.vit import VisionTransformer2D, VisionTransformer3D + +__all__ = [ + "ImageDomainMRIUFormer", + "ImageDomainMRIViT2D", + "ImageDomainMRIViT3D", + "KSpaceDomainMRIViT2D", + "KSpaceDomainMRIViT3D", +] + + +class ImageDomainMRIUFormer(nn.Module): + """U-Former model for MRI reconstruction in the image domain. + + Parameters + ---------- + forward_operator : Callable[[tuple[Any, ...]], torch.Tensor] + Forward operator function. + backward_operator : Callable[[tuple[Any, ...]], torch.Tensor] + Backward operator function. + patch_size : int + Size of the patch. Default: 256. + in_channels : int + Number of input channels. Default: 2. + out_channels : int, optional + Number of output channels. Default: None.
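As a side note on the configs in config.py above: they are plain dataclasses extending ModelConfig, so they can be materialized and overridden as OmegaConf structured configs. A minimal sketch follows (illustrative only; how DIRECT resolves model configs from a training YAML is not part of this diff):

from omegaconf import OmegaConf
from direct.nn.transformers.config import ImageDomainMRIViT2DConfig

# Build a typed config and override a couple of fields; OmegaConf checks the
# overrides against the dataclass annotations (e.g. tuple[int, int] for average_size).
cfg = OmegaConf.structured(ImageDomainMRIViT2DConfig)
cfg = OmegaConf.merge(cfg, OmegaConf.create({"average_size": [256, 256], "depth": 10}))
print(OmegaConf.to_yaml(cfg))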
+ embedding_dim : int + Size of the feature embedding. Default: 32. + encoder_depths : tuple + Number of layers for each stage of the encoder of the U-former, from top to bottom. Default: (2, 2, 2, 2). + encoder_num_heads : tuple + Number of attention heads for each layer of the encoder of the U-former, from top to bottom. + Default: (1, 2, 4, 8). + bottleneck_depth : int + Number of layers in the bottleneck stage of the U-former. Default: 2. + bottleneck_num_heads : int + Number of attention heads in the bottleneck stage. Default: 16. + win_size : int + Window size for the attention mechanism. Default: 8. + mlp_ratio : float + Ratio of the hidden dimension size to the embedding dimension size in the MLP layers. Default: 4.0. + qkv_bias : bool + Whether to use bias in the query, key, and value projections of the attention mechanism. Default: True. + qk_scale : float + Scale factor for the query and key projection vectors. + If set to None, will use the default value of 1 / sqrt(embedding_dim). Default: None. + drop_rate : float + Dropout rate for the token-level dropout layer. Default: 0.0. + attn_drop_rate : float + Dropout rate for the attention score matrix. Default: 0.0. + drop_path_rate : float + Dropout rate for the stochastic depth regularization. Default: 0.1. + patch_norm : bool + Whether to use normalization for the patch embeddings. Default: True. + token_projection : AttentionTokenProjectionType + Type of token projection. Must be one of ["linear", "conv"]. Default: AttentionTokenProjectionType.LINEAR. + token_mlp : LeWinTransformerMLPTokenType + Type of token-level MLP. Must be one of ["leff", "mlp", "ffn"]. Default: LeWinTransformerMLPTokenType.LEFF. + shift_flag : bool + Whether to use shift operation in the local attention mechanism. Default: True. + modulator : bool + Whether to use a modulator in the attention mechanism. Default: False. + cross_modulator : bool + Whether to use cross-modulation in the attention mechanism. Default: False. + normalized : bool + Whether to apply normalization before and denormalization after the forward pass. Default: True. + """ + + def __init__( + self, + forward_operator: Callable[[tuple[Any, ...]], torch.Tensor], + backward_operator: Callable[[tuple[Any, ...]], torch.Tensor], + patch_size: int = 256, + embedding_dim: int = 32, + encoder_depths: tuple[int, ...] = (2, 2, 2, 2), + encoder_num_heads: tuple[int, ...] = (1, 2, 4, 8), + bottleneck_depth: int = 2, + bottleneck_num_heads: int = 16, + win_size: int = 8, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + drop_path_rate: float = 0.1, + patch_norm: bool = True, + token_projection: AttentionTokenProjectionType = AttentionTokenProjectionType.LINEAR, + token_mlp: LeWinTransformerMLPTokenType = LeWinTransformerMLPTokenType.LEFF, + shift_flag: bool = True, + modulator: bool = False, + cross_modulator: bool = False, + normalized: bool = True, + **kwargs, + ) -> None: + """Inits :class:`ImageDomainMRIUFormer`. + + Parameters + ---------- + forward_operator : Callable[[tuple[Any, ...]], torch.Tensor] + Forward operator function. + backward_operator : Callable[[tuple[Any, ...]], torch.Tensor] + Backward operator function. + patch_size : int + Size of the patch. Default: 256. + in_channels : int + Number of input channels. Default: 2. + out_channels : int, optional + Number of output channels. Default: None. + embedding_dim : int + Size of the feature embedding. Default: 32. + encoder_depths : tuple + Number of layers for each stage of the encoder of the U-former, from top to bottom. Default: (2, 2, 2, 2).
+ encoder_num_heads : tuple + Number of attention heads for each layer of the encoder of the U-former, from top to bottom. + Default: (1, 2, 4, 8). + bottleneck_depth : int + Number of layers in the bottleneck stage of the U-former. Default: 2. + bottleneck_num_heads : int + Number of attention heads in the bottleneck stage. Default: 16. + win_size : int + Window size for the attention mechanism. Default: 8. + mlp_ratio : float + Ratio of the hidden dimension size to the embedding dimension size in the MLP layers. Default: 4.0. + qkv_bias : bool + Whether to use bias in the query, key, and value projections of the attention mechanism. Default: True. + qk_scale : float + Scale factor for the query and key projection vectors. + If set to None, will use the default value of 1 / sqrt(embedding_dim). Default: None. + drop_rate : float + Dropout rate for the token-level dropout layer. Default: 0.0. + attn_drop_rate : float + Dropout rate for the attention score matrix. Default: 0.0. + drop_path_rate : float + Dropout rate for the stochastic depth regularization. Default: 0.1. + patch_norm : bool + Whether to use normalization for the patch embeddings. Default: True. + token_projection : AttentionTokenProjectionType + Type of token projection. Must be one of ["linear", "conv"]. Default: AttentionTokenProjectionType.LINEAR. + token_mlp : LeWinTransformerMLPTokenType + Type of token-level MLP. Must be one of ["leff", "mlp", "ffn"]. Default: LeWinTransformerMLPTokenType.LEFF. + shift_flag : bool + Whether to use shift operation in the local attention mechanism. Default: True. + modulator : bool + Whether to use a modulator in the attention mechanism. Default: False. + cross_modulator : bool + Whether to use cross-modulation in the attention mechanism. Default: False. + normalized : bool + Whether to apply normalization before and denormalization after the forward pass. Default: True. + **kwargs: Additional keyword arguments. Only `model_name` is accepted; any other key raises a ValueError. + """ + super().__init__() + for extra_key in kwargs: + if extra_key not in [ + "model_name", + ]: + raise ValueError(f"{type(self).__name__} got key `{extra_key}` which is not supported.") + self.uformer = UFormerModel( + patch_size=patch_size, + in_channels=COMPLEX_SIZE, + embedding_dim=embedding_dim, + encoder_depths=encoder_depths, + encoder_num_heads=encoder_num_heads, + bottleneck_depth=bottleneck_depth, + bottleneck_num_heads=bottleneck_num_heads, + win_size=win_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + patch_norm=patch_norm, + token_projection=token_projection, + token_mlp=token_mlp, + shift_flag=shift_flag, + modulator=modulator, + cross_modulator=cross_modulator, + normalized=normalized, + ) + self.forward_operator = forward_operator + self.backward_operator = backward_operator + + self._coil_dim = 1 + self._complex_dim = -1 + self._spatial_dims = (2, 3) + + def forward(self, masked_kspace: torch.Tensor, sensitivity_map: torch.Tensor) -> torch.Tensor: + """Forward pass of :class:`ImageDomainMRIUFormer`. + + masked_kspace: torch.Tensor + Masked k-space of shape (N, coil, height, width, complex=2). + sensitivity_map: torch.Tensor + Sensitivity map of shape (N, coil, height, width, complex=2) + + Returns + ------- + out : torch.Tensor + The output tensor of shape (N, height, width, complex=2).
+ """ + + image = reduce_operator( + coil_data=self.backward_operator(masked_kspace, dim=self._spatial_dims), + sensitivity_map=sensitivity_map, + dim=self._coil_dim, + ).permute(0, 3, 1, 2) + + out = self.uformer(image).permute(0, 2, 3, 1) + + return out + + +class ImageDomainMRIViT2D(nn.Module): + """Vision Transformer for MRI reconstruction in 2D. + + Parameters + ---------- + forward_operator : Callable[[tuple[Any, ...]], torch.Tensor] + Forward operator function. + backward_operator : Callable[[tuple[Any, ...]], torch.Tensor] + Backward operator function. + average_size : int or tuple[int, int] + The average size of the input image. If an int is provided, this will be determined by the + `dimensionality`, i.e., (average_size, average_size) for 2D and + (average_size, average_size, average_size) for 3D. Default: 320. + patch_size : int or tuple[int, int] + The size of the patch. If an int is provided, this will be determined by the `dimensionality`, i.e., + (patch_size, patch_size) for 2D and (patch_size, patch_size, patch_size) for 3D. Default: 16. + embedding_dim : int + Dimension of the output embedding. + depth : int + Number of transformer blocks. + num_heads : int + Number of attention heads. + mlp_ratio : float + The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0. + qkv_bias : bool + Whether to add bias to the query, key, and value projections. Default: False. + qk_scale : float + The scale factor for the query-key dot product. Default: None. + drop_rate : float + The dropout probability for all dropout layers except dropout_path. Default: 0.0. + attn_drop_rate : float + The dropout probability for the attention layer. Default: 0.0. + dropout_path_rate : float + The dropout probability for the dropout path. Default: 0.0. + use_gpsa : bool, optional + Whether to use the GPSA attention layer. If set to False, the MHSA layer will be used. Default: True. + locality_strength : float + The strength of the locality assumption in initialization. Default: 1.0. + use_pos_embedding : bool + Whether to use positional embeddings. Default: True. + normalized : bool + Whether to normalize the input tensor. Default: True. + """ + + def __init__( + self, + forward_operator: Callable[[tuple[Any, ...]], torch.Tensor], + backward_operator: Callable[[tuple[Any, ...]], torch.Tensor], + average_size: int | tuple[int, int] = 320, + patch_size: int | tuple[int, int] = 16, + embedding_dim: int = 64, + depth: int = 8, + num_heads: int = 9, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + qk_scale: float = None, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + dropout_path_rate: float = 0.0, + use_gpsa: tuple[int, int] = (-1, -1), + locality_strength: float = 1.0, + use_pos_embedding: bool = True, + normalized: bool = True, + **kwargs, + ) -> None: + """Inits :class:`ImageDomainMRIViT2D`. + + Parameters + ---------- + forward_operator : Callable[[tuple[Any, ...]], torch.Tensor] + Forward operator function. + backward_operator : Callable[[tuple[Any, ...]], torch.Tensor] + Backward operator function. + average_size : int or tuple[int, int] + The average size of the input image. If an int is provided, this will be determined by the + `dimensionality`, i.e., (average_size, average_size) for 2D and + (average_size, average_size, average_size) for 3D. Default: 320. + patch_size : int or tuple[int, int] + The size of the patch. 
If an int is provided, this will be determined by the `dimensionality`, i.e., + (patch_size, patch_size) for 2D and (patch_size, patch_size, patch_size) for 3D. Default: 16. + embedding_dim : int + Dimension of the output embedding. + depth : int + Number of transformer blocks. + num_heads : int + Number of attention heads. + mlp_ratio : float + The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0. + qkv_bias : bool + Whether to add bias to the query, key, and value projections. Default: False. + qk_scale : float + The scale factor for the query-key dot product. Default: None. + drop_rate : float + The dropout probability for all dropout layers except dropout_path. Default: 0.0. + attn_drop_rate : float + The dropout probability for the attention layer. Default: 0.0. + dropout_path_rate : float + The dropout probability for the dropout path. Default: 0.0. + use_gpsa : bool, optional + Whether to use the GPSA attention layer. If set to False, the MHSA layer will be used. Default: True. + locality_strength : float + The strength of the locality assumption in initialization. Default: 1.0. + use_pos_embedding : bool + Whether to use positional embeddings. Default: True. + normalized : bool + Whether to normalize the input tensor. Default: True. + """ + super().__init__() + for extra_key in kwargs: + if extra_key not in [ + "model_name", + ]: + raise ValueError(f"{type(self).__name__} got key `{extra_key}` which is not supported.") + self.transformer = VisionTransformer2D( + average_img_size=average_size, + patch_size=patch_size, + in_channels=COMPLEX_SIZE, + embedding_dim=embedding_dim, + depth=depth, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + dropout_path_rate=dropout_path_rate, + use_gpsa=use_gpsa, + locality_strength=locality_strength, + use_pos_embedding=use_pos_embedding, + normalized=normalized, + ) + self.forward_operator = forward_operator + self.backward_operator = backward_operator + + self._coil_dim = 1 + self._complex_dim = -1 + self._spatial_dims = (2, 3) + + def forward(self, masked_kspace: torch.Tensor, sensitivity_map: torch.Tensor) -> torch.Tensor: + """Forward pass of :class:`ImageDomainMRIViT2D`. + + masked_kspace: torch.Tensor + Masked k-space of shape (N, coil, height, width, complex=2). + sensitivity_map: torch.Tensor + Sensitivity map of shape (N, coil, height, width, complex=2) + + Returns + ------- + out : torch.Tensor + The output tensor of shape (N, height, width, complex=2). + """ + image = reduce_operator( + coil_data=self.backward_operator(masked_kspace, dim=self._spatial_dims), + sensitivity_map=sensitivity_map, + dim=self._coil_dim, + ).permute(0, 3, 1, 2) + out = self.transformer(image).permute(0, 2, 3, 1) + return out + + +class ImageDomainMRIViT3D(nn.Module): + """Vision Transformer for MRI reconstruction in 3D. + + Parameters + ---------- + forward_operator : Callable[[tuple[Any, ...]], torch.Tensor] + Forward operator function. + backward_operator : Callable[[tuple[Any, ...]], torch.Tensor] + Backward operator function. + average_size : int or tuple[int, int, int] + The average size of the input image. If an int is provided, this will be defined as + (average_size, average_size, average_size). Default: 320. + patch_size : int or tuple[int, int, int] + The size of the patch. If an int is provided, this will be defined as (patch_size, patch_size, patch_size). + Default: 16. 
+ embedding_dim : int + Dimension of the output embedding. + depth : int + Number of transformer blocks. + num_heads : int + Number of attention heads. + mlp_ratio : float + The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0. + qkv_bias : bool + Whether to add bias to the query, key, and value projections. Default: False. + qk_scale : float + The scale factor for the query-key dot product. Default: None. + drop_rate : float + The dropout probability for all dropout layers except dropout_path. Default: 0.0. + attn_drop_rate : float + The dropout probability for the attention layer. Default: 0.0. + dropout_path_rate : float + The dropout probability for the dropout path. Default: 0.0. + use_gpsa : bool, optional + Whether to use the GPSA attention layer. If set to False, the MHSA layer will be used. Default: True. + locality_strength : float + The strength of the locality assumption in initialization. Default: 1.0. + use_pos_embedding : bool + Whether to use positional embeddings. Default: True. + normalized : bool + Whether to normalize the input tensor. Default: True. + """ + + def __init__( + self, + forward_operator: Callable[[tuple[Any, ...]], torch.Tensor], + backward_operator: Callable[[tuple[Any, ...]], torch.Tensor], + average_size: int | tuple[int, int, int] = 320, + patch_size: int | tuple[int, int, int] = 16, + embedding_dim: int = 64, + depth: int = 8, + num_heads: int = 9, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + qk_scale: float = None, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + dropout_path_rate: float = 0.0, + use_gpsa: tuple[int, int] = (-1, -1), + locality_strength: float = 1.0, + use_pos_embedding: bool = True, + normalized: bool = True, + **kwargs, + ) -> None: + """Inits :class:`ImageDomainMRIViT3D`. + + Parameters + ---------- + forward_operator : Callable[[tuple[Any, ...]], torch.Tensor] + Forward operator function. + backward_operator : Callable[[tuple[Any, ...]], torch.Tensor] + Backward operator function. + average_size : int or tuple[int, int, int] + The average size of the input image. If an int is provided, this will be defined as + (average_size, average_size, average_size). Default: 320. + patch_size : int or tuple[int, int, int] + The size of the patch. If an int is provided, this will be defined as (patch_size, patch_size, patch_size). + Default: 16. + embedding_dim : int + Dimension of the output embedding. + depth : int + Number of transformer blocks. + num_heads : int + Number of attention heads. + mlp_ratio : float + The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0. + qkv_bias : bool + Whether to add bias to the query, key, and value projections. Default: False. + qk_scale : float + The scale factor for the query-key dot product. Default: None. + drop_rate : float + The dropout probability for all dropout layers except dropout_path. Default: 0.0. + attn_drop_rate : float + The dropout probability for the attention layer. Default: 0.0. + dropout_path_rate : float + The dropout probability for the dropout path. Default: 0.0. + use_gpsa : bool, optional + Whether to use the GPSA attention layer. If set to False, the MHSA layer will be used. Default: True. + locality_strength : float + The strength of the locality assumption in initialization. Default: 1.0. + use_pos_embedding : bool + Whether to use positional embeddings. Default: True. + normalized : bool + Whether to normalize the input tensor. Default: True. 
+ """ + super().__init__() + for extra_key in kwargs: + if extra_key not in [ + "model_name", + ]: + raise ValueError(f"{type(self).__name__} got key `{extra_key}` which is not supported.") + self.transformer = VisionTransformer3D( + average_img_size=average_size, + patch_size=patch_size, + in_channels=COMPLEX_SIZE, + embedding_dim=embedding_dim, + depth=depth, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + dropout_path_rate=dropout_path_rate, + use_gpsa=use_gpsa, + locality_strength=locality_strength, + use_pos_embedding=use_pos_embedding, + normalized=normalized, + ) + self.forward_operator = forward_operator + self.backward_operator = backward_operator + + self._coil_dim = 1 + self._complex_dim = -1 + self._spatial_dims = (3, 4) + + def forward(self, masked_kspace: torch.Tensor, sensitivity_map: torch.Tensor) -> torch.Tensor: + """Forward pass of :class:`ImageDomainMRIViT3D`. + + masked_kspace: torch.Tensor + Masked k-space of shape (N, coil, slice/time, height, width, complex=2). + sensitivity_map: torch.Tensor + Sensitivity map of shape (N, coil, slice/time, height, width, complex=2) + + Returns + ------- + out : torch.Tensor + The output tensor of shape (N, slice/time, height, width, complex=2). + """ + + image = reduce_operator( + coil_data=self.backward_operator(masked_kspace, dim=self._spatial_dims), + sensitivity_map=sensitivity_map, + dim=self._coil_dim, + ).permute(0, 4, 1, 2, 3) + out = self.transformer(image).permute(0, 2, 3, 4, 1) + return out + + +class KSpaceDomainMRIViT2D(nn.Module): + """Vision Transformer for MRI reconstruction in 2D in k-space. + + Parameters + ---------- + forward_operator : Callable[[tuple[Any, ...]], torch.Tensor] + Forward operator function. + backward_operator : Callable[[tuple[Any, ...]], torch.Tensor] + Backward operator function. + average_size : int or tuple[int, int] + The average size of the input image. If an int is provided, this will be determined by the + `dimensionality`, i.e., (average_size, average_size) for 2D and + (average_size, average_size, average_size) for 3D. Default: 320. + patch_size : int or tuple[int, int] + The size of the patch. If an int is provided, this will be determined by the `dimensionality`, i.e., + (patch_size, patch_size) for 2D and (patch_size, patch_size, patch_size) for 3D. Default: 16. + embedding_dim : int + Dimension of the output embedding. + depth : int + Number of transformer blocks. + num_heads : int + Number of attention heads. + mlp_ratio : float + The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0. + qkv_bias : bool + Whether to add bias to the query, key, and value projections. Default: False. + qk_scale : float + The scale factor for the query-key dot product. Default: None. + drop_rate : float + The dropout probability for all dropout layers except dropout_path. Default: 0.0. + attn_drop_rate : float + The dropout probability for the attention layer. Default: 0.0. + dropout_path_rate : float + The dropout probability for the dropout path. Default: 0.0. + use_gpsa : bool, optional + Whether to use the GPSA attention layer. If set to False, the MHSA layer will be used. Default: True. + locality_strength : float + The strength of the locality assumption in initialization. Default: 1.0. + use_pos_embedding : bool + Whether to use positional embeddings. Default: True. + normalized : bool + Whether to normalize the input tensor. Default: True. 
+ """ + + def __init__( + self, + forward_operator: Callable[[tuple[Any, ...]], torch.Tensor], + backward_operator: Callable[[tuple[Any, ...]], torch.Tensor], + average_size: int | tuple[int, int] = 320, + patch_size: int | tuple[int, int] = 16, + embedding_dim: int = 64, + depth: int = 8, + num_heads: int = 9, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + qk_scale: float = None, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + dropout_path_rate: float = 0.0, + use_gpsa: tuple[int, int] = (-1, -1), + locality_strength: float = 1.0, + use_pos_embedding: bool = True, + normalized: bool = True, + compute_per_coil: bool = True, + **kwargs, + ) -> None: + """Inits :class:`KSpaceDomainMRIViT2D`. + + Parameters + ---------- + forward_operator : Callable[[tuple[Any, ...]], torch.Tensor] + Forward operator function. + backward_operator : Callable[[tuple[Any, ...]], torch.Tensor] + Backward operator function. + average_size : int or tuple[int, int] + The average size of the input image. If an int is provided, this will be determined by the + `dimensionality`, i.e., (average_size, average_size) for 2D and + (average_size, average_size, average_size) for 3D. Default: 320. + patch_size : int or tuple[int, int] + The size of the patch. If an int is provided, this will be determined by the `dimensionality`, i.e., + (patch_size, patch_size) for 2D and (patch_size, patch_size, patch_size) for 3D. Default: 16. + embedding_dim : int + Dimension of the output embedding. + depth : int + Number of transformer blocks. + num_heads : int + Number of attention heads. + mlp_ratio : float + The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0. + qkv_bias : bool + Whether to add bias to the query, key, and value projections. Default: False. + qk_scale : float + The scale factor for the query-key dot product. Default: None. + drop_rate : float + The dropout probability for all dropout layers except dropout_path. Default: 0.0. + attn_drop_rate : float + The dropout probability for the attention layer. Default: 0.0. + dropout_path_rate : float + The dropout probability for the dropout path. Default: 0.0. + use_gpsa : bool, optional + Whether to use the GPSA attention layer. If set to False, the MHSA layer will be used. Default: True. + locality_strength : float + The strength of the locality assumption in initialization. Default: 1.0. + use_pos_embedding : bool + Whether to use positional embeddings. Default: True. + normalized : bool + Whether to normalize the input tensor. Default: True. + compute_per_coil : bool + Whether to compute the output per coil. 
+ """ + super().__init__() + for extra_key in kwargs: + if extra_key not in [ + "model_name", + ]: + raise ValueError(f"{type(self).__name__} got key `{extra_key}` which is not supported.") + self.transformer = VisionTransformer2D( + average_img_size=average_size, + patch_size=patch_size, + in_channels=COMPLEX_SIZE, + embedding_dim=embedding_dim, + depth=depth, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + dropout_path_rate=dropout_path_rate, + use_gpsa=use_gpsa, + locality_strength=locality_strength, + use_pos_embedding=use_pos_embedding, + normalized=normalized, + ) + self.forward_operator = forward_operator + self.backward_operator = backward_operator + + self.compute_per_coil = compute_per_coil + + self._coil_dim = 1 + self._complex_dim = -1 + self._spatial_dims = (2, 3) + + def forward( + self, masked_kspace: torch.Tensor, sensitivity_map: torch.Tensor, sampling_mask: torch.Tensor + ) -> torch.Tensor: + """Forward pass of :class:`KSpaceDomainMRIViT2D`. + + masked_kspace: torch.Tensor + Masked k-space of shape (N, coil, height, width, complex=2). + sensitivity_map: torch.Tensor + Sensitivity map of shape (N, coil, height, width, complex=2) + sampling_mask: torch.Tensor + Sampling mask of shape (N, 1, height, width, 1). + + Returns + ------- + out : torch.Tensor + The output tensor of shape (N, height, width, complex=2). + """ + if self.compute_per_coil: + out = torch.stack( + [ + self.transformer(masked_kspace[:, i].permute(0, 3, 1, 2)) + for i in range(masked_kspace.shape[self._coil_dim]) + ], + dim=self._coil_dim, + ).permute(0, 1, 3, 4, 2) + + out = torch.where(sampling_mask, masked_kspace, out) # data consistency + + # Create a single image from the coil data and return it + out = reduce_operator( + coil_data=self.backward_operator(out, dim=self._spatial_dims), + sensitivity_map=sensitivity_map, + dim=self._coil_dim, + ) + return out + + # Otherwise, create a single image from the coil data + sense_image = reduce_operator( + coil_data=self.backward_operator(masked_kspace, dim=self._spatial_dims), + sensitivity_map=sensitivity_map, + dim=self._coil_dim, + ) + # Trasnform the image to the k-space domain + inp = self.forward_operator(sense_image, dim=[d - 1 for d in self._spatial_dims]) + + # Pass to the transformer + out = self.transformer(inp.permute(0, 3, 1, 2)).permute(0, 2, 3, 1).contiguous() + + out = self.backward_operator(out, dim=[d - 1 for d in self._spatial_dims]) + return out + + +class KSpaceDomainMRIViT3D(nn.Module): + """Vision Transformer for MRI reconstruction in 3D in k-space. + + Parameters + ---------- + forward_operator : Callable[[tuple[Any, ...]], torch.Tensor] + Forward operator function. + backward_operator : Callable[[tuple[Any, ...]], torch.Tensor] + Backward operator function. + average_size : int or tuple[int, int] + The average size of the input image. If an int is provided, this will be determined by the + `dimensionality`, i.e., (average_size, average_size) for 2D and + (average_size, average_size, average_size) for 3D. Default: 320. + patch_size : int or tuple[int, int] + The size of the patch. If an int is provided, this will be determined by the `dimensionality`, i.e., + (patch_size, patch_size) for 2D and (patch_size, patch_size, patch_size) for 3D. Default: 16. + embedding_dim : int + Dimension of the output embedding. + depth : int + Number of transformer blocks. + num_heads : int + Number of attention heads. 
+ mlp_ratio : float + The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0. + qkv_bias : bool + Whether to add bias to the query, key, and value projections. Default: False. + qk_scale : float + The scale factor for the query-key dot product. Default: None. + drop_rate : float + The dropout probability for all dropout layers except dropout_path. Default: 0.0. + attn_drop_rate : float + The dropout probability for the attention layer. Default: 0.0. + dropout_path_rate : float + The dropout probability for the dropout path. Default: 0.0. + use_gpsa : bool, optional + Whether to use the GPSA attention layer. If set to False, the MHSA layer will be used. Default: True. + locality_strength : float + The strength of the locality assumption in initialization. Default: 1.0. + use_pos_embedding : bool + Whether to use positional embeddings. Default: True. + normalized : bool + Whether to normalize the input tensor. Default: True. + """ + + def __init__( + self, + forward_operator: Callable[[tuple[Any, ...]], torch.Tensor], + backward_operator: Callable[[tuple[Any, ...]], torch.Tensor], + average_size: int | tuple[int, int] = 320, + patch_size: int | tuple[int, int] = 16, + embedding_dim: int = 64, + depth: int = 8, + num_heads: int = 9, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + qk_scale: float = None, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + dropout_path_rate: float = 0.0, + use_gpsa: tuple[int, int] = (-1, -1), + locality_strength: float = 1.0, + use_pos_embedding: bool = True, + normalized: bool = True, + compute_per_coil: bool = True, + **kwargs, + ) -> None: + """Inits :class:`KSpaceDomainMRIViT3D`. + + Parameters + ---------- + forward_operator : Callable[[tuple[Any, ...]], torch.Tensor] + Forward operator function. + backward_operator : Callable[[tuple[Any, ...]], torch.Tensor] + Backward operator function. + average_size : int or tuple[int, int] + The average size of the input image. If an int is provided, this will be determined by the + `dimensionality`, i.e., (average_size, average_size) for 2D and + (average_size, average_size, average_size) for 3D. Default: 320. + patch_size : int or tuple[int, int] + The size of the patch. If an int is provided, this will be determined by the `dimensionality`, i.e., + (patch_size, patch_size) for 2D and (patch_size, patch_size, patch_size) for 3D. Default: 16. + embedding_dim : int + Dimension of the output embedding. + depth : int + Number of transformer blocks. + num_heads : int + Number of attention heads. + mlp_ratio : float + The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0. + qkv_bias : bool + Whether to add bias to the query, key, and value projections. Default: False. + qk_scale : float + The scale factor for the query-key dot product. Default: None. + drop_rate : float + The dropout probability for all dropout layers except dropout_path. Default: 0.0. + attn_drop_rate : float + The dropout probability for the attention layer. Default: 0.0. + dropout_path_rate : float + The dropout probability for the dropout path. Default: 0.0. + use_gpsa : bool, optional + Whether to use the GPSA attention layer. If set to False, the MHSA layer will be used. Default: True. + locality_strength : float + The strength of the locality assumption in initialization. Default: 1.0. + use_pos_embedding : bool + Whether to use positional embeddings. Default: True. + normalized : bool + Whether to normalize the input tensor. Default: True. 
+ compute_per_coil : bool + Whether to compute the output per coil. + """ + super().__init__() + for extra_key in kwargs: + if extra_key not in [ + "model_name", + ]: + raise ValueError(f"{type(self).__name__} got key `{extra_key}` which is not supported.") + self.transformer = VisionTransformer3D( + average_img_size=average_size, + patch_size=patch_size, + in_channels=COMPLEX_SIZE, + embedding_dim=embedding_dim, + depth=depth, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + dropout_path_rate=dropout_path_rate, + use_gpsa=use_gpsa, + locality_strength=locality_strength, + use_pos_embedding=use_pos_embedding, + normalized=normalized, + ) + self.forward_operator = forward_operator + self.backward_operator = backward_operator + + self.compute_per_coil = compute_per_coil + + self._coil_dim = 1 + self._complex_dim = -1 + self._spatial_dims = (3, 4) + + def forward( + self, masked_kspace: torch.Tensor, sensitivity_map: torch.Tensor, sampling_mask: torch.Tensor + ) -> torch.Tensor: + """Forward pass of :class:`KSpaceDomainMRIViT3D`. + + masked_kspace: torch.Tensor + Masked k-space of shape (N, coil, slice/time, height, width, complex=2). + sensitivity_map: torch.Tensor + Sensitivity map of shape (N, coil, slice/time, height, width, complex=2) + sampling_mask: torch.Tensor + Sampling mask of shape (N, 1, 1 or slice/time, height, width, 1). + + Returns + ------- + out : torch.Tensor + The output tensor of shape (N, slice/time, height, width, complex=2). + """ + if self.compute_per_coil: + out = torch.stack( + [ + self.transformer(masked_kspace[:, i].permute(0, 4, 1, 2, 3)) + for i in range(masked_kspace.shape[self._coil_dim]) + ], + dim=self._coil_dim, + ).permute(0, 1, 3, 4, 5, 2) + + out = torch.where(sampling_mask, masked_kspace, out) # data consistency + + # Create a single image from the coil data and return it + out = reduce_operator( + coil_data=self.backward_operator(out, dim=self._spatial_dims), + sensitivity_map=sensitivity_map, + dim=self._coil_dim, + ) + return out + + # Create a single image from the coil data + sense_image = reduce_operator( + coil_data=self.backward_operator(masked_kspace, dim=self._spatial_dims), + sensitivity_map=sensitivity_map, + dim=self._coil_dim, + ) + # Transform the image to the k-space domain + inp = self.forward_operator(sense_image, dim=[d - 1 for d in self._spatial_dims]) + + # Pass to the transformer + out = self.transformer(inp.permute(0, 4, 1, 2, 3)).permute(0, 2, 3, 4, 1).contiguous() + + out = self.backward_operator(out, dim=[d - 1 for d in self._spatial_dims]) + return out diff --git a/direct/nn/transformers/transformers_engine.py b/direct/nn/transformers/transformers_engine.py new file mode 100644 index 00000000..891b4806 --- /dev/null +++ b/direct/nn/transformers/transformers_engine.py @@ -0,0 +1,518 @@ +# Copyright (c) DIRECT Contributors + +"""DIRECT MRI transformer-based model engines.""" + +from typing import Any, Callable, Optional + +import torch +from torch import nn + +import direct.data.transforms as T +from direct.config import BaseConfig +from direct.nn.mri_models import MRIModelEngine + + +class ImageDomainMRIViTEngine(MRIModelEngine): + """MRI ViT Model Engine for Image Domain. + + Parameters + ---------- + cfg: BaseConfig + Configuration file. + model: nn.Module + Model. + device: str + Device. Can be "cuda:{idx}" or "cpu". + forward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional + The forward operator. Default: None.
+ backward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional + The backward operator. Default: None. + mixed_precision: bool + Use mixed precision. Default: False. + **models: nn.Module + Additional models. + """ + + def __init__( + self, + cfg: BaseConfig, + model: nn.Module, + device: str, + forward_operator: Optional[Callable[[tuple[Any, ...]], torch.Tensor]] = None, + backward_operator: Optional[Callable[[tuple[Any, ...]], torch.Tensor]] = None, + mixed_precision: bool = False, + **models: nn.Module, + ) -> None: + """Inits :class:`ImageDomainMRIViTEngine`. + + Parameters + ---------- + cfg: BaseConfig + Configuration file. + model: nn.Module + Model. + device: str + Device. Can be "cuda:{idx}" or "cpu". + forward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional + The forward operator. Default: None. + backward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional + The backward operator. Default: None. + mixed_precision: bool + Use mixed precision. Default: False. + **models: nn.Module + Additional models. + """ + super().__init__( + cfg, + model, + device, + forward_operator=forward_operator, + backward_operator=backward_operator, + mixed_precision=mixed_precision, + **models, + ) + + def forward_function(self, data: dict[str, Any]) -> tuple[torch.Tensor, torch.Tensor]: + """Forward function for :class:`ImageDomainMRIViTEngine`. + + Parameters + ---------- + data : dict[str, Any] + Input data. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor] + Output image and output k-space. + """ + data["sensitivity_map"] = self.compute_sensitivity_map(data["sensitivity_map"]) + + output_image = self.model( + masked_kspace=data["masked_kspace"], + sensitivity_map=data["sensitivity_map"], + ) # shape (batch, slice/time, height, width, complex[=2]) + + output_kspace = data["masked_kspace"] + T.apply_mask( + T.apply_padding( + self.forward_operator( + T.expand_operator(output_image, data["sensitivity_map"], dim=self._coil_dim), + dim=self._spatial_dims, + ), + padding=data.get("padding", None), + ), + ~data["sampling_mask"], + return_mask=False, + ) + + return output_image, output_kspace + + +class ImageDomainMRIUFormerEngine(ImageDomainMRIViTEngine): + """MRI U-Former Model Engine for Image Domain. + + Parameters + ---------- + cfg: BaseConfig + Configuration file. + model: nn.Module + Model. + device: str + Device. Can be "cuda:{idx}" or "cpu". + forward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional + The forward operator. Default: None. + backward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional + The backward operator. Default: None. + mixed_precision: bool + Use mixed precision. Default: False. + **models: nn.Module + Additional models. + """ + + def __init__( + self, + cfg: BaseConfig, + model: nn.Module, + device: str, + forward_operator: Optional[Callable[[tuple[Any, ...]], torch.Tensor]] = None, + backward_operator: Optional[Callable[[tuple[Any, ...]], torch.Tensor]] = None, + mixed_precision: bool = False, + **models: nn.Module, + ) -> None: + """Inits :class:`ImageDomainMRIUFormerEngine`. + + Parameters + ---------- + cfg: BaseConfig + Configuration file. + model: nn.Module + Model. + device: str + Device. Can be "cuda:{idx}" or "cpu". + forward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional + The forward operator. Default: None. + backward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional + The backward operator. Default: None. + mixed_precision: bool + Use mixed precision. Default: False. 
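forward_function above maps the model's image output back to per-coil k-space with T.expand_operator before re-applying the sampling mask. A shape sketch of the expand/reduce pair, using the signatures as they are called in this diff (illustrative, not a definitive statement of the transforms API):

import torch
from direct.data.transforms import expand_operator, reduce_operator

image = torch.randn(1, 64, 64, 2)                  # (N, height, width, complex=2)
sensitivity_map = torch.randn(1, 8, 64, 64, 2)     # (N, coil, height, width, complex=2)
coil_images = expand_operator(image, sensitivity_map, dim=1)                       # (N, coil, height, width, 2)
combined = reduce_operator(coil_data=coil_images, sensitivity_map=sensitivity_map, dim=1)  # (N, height, width, 2)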
+ **models: nn.Module + Additional models. + """ + super().__init__( + cfg, + model, + device, + forward_operator=forward_operator, + backward_operator=backward_operator, + mixed_precision=mixed_precision, + **models, + ) + + self._spatial_dims = (2, 3) + + +class ImageDomainMRIViT2DEngine(ImageDomainMRIViTEngine): + """MRI ViT Model Engine for Image Domain 2D. + + Parameters + ---------- + cfg: BaseConfig + Configuration file. + model: nn.Module + Model. + device: str + Device. Can be "cuda:{idx}" or "cpu". + forward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional + The forward operator. Default: None. + backward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional + The backward operator. Default: None. + mixed_precision: bool + Use mixed precision. Default: False. + **models: nn.Module + Additional models. + """ + + def __init__( + self, + cfg: BaseConfig, + model: nn.Module, + device: str, + forward_operator: Optional[Callable[[tuple[Any, ...]], torch.Tensor]] = None, + backward_operator: Optional[Callable[[tuple[Any, ...]], torch.Tensor]] = None, + mixed_precision: bool = False, + **models: nn.Module, + ) -> None: + """Inits :class:`ImageDomainMRIViT2DEngine`. + + Parameters + ---------- + cfg: BaseConfig + Configuration file. + model: nn.Module + Model. + device: str + Device. Can be "cuda:{idx}" or "cpu". + forward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional + The forward operator. Default: None. + backward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional + The backward operator. Default: None. + mixed_precision: bool + Use mixed precision. Default: False. + **models: nn.Module + Additional models. + """ + super().__init__( + cfg, + model, + device, + forward_operator=forward_operator, + backward_operator=backward_operator, + mixed_precision=mixed_precision, + **models, + ) + + self._spatial_dims = (2, 3) + + +class ImageDomainMRIViT3DEngine(ImageDomainMRIViTEngine): + """MRI ViT Model Engine for Image Domain 3D. + + Parameters + ---------- + cfg: BaseConfig + Configuration file. + model: nn.Module + Model. + device: str + Device. Can be "cuda:{idx}" or "cpu". + forward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional + The forward operator. Default: None. + backward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional + The backward operator. Default: None. + mixed_precision: bool + Use mixed precision. Default: False. + **models: nn.Module + Additional models. + """ + + def __init__( + self, + cfg: BaseConfig, + model: nn.Module, + device: str, + forward_operator: Optional[Callable[[tuple[Any, ...]], torch.Tensor]] = None, + backward_operator: Optional[Callable[[tuple[Any, ...]], torch.Tensor]] = None, + mixed_precision: bool = False, + **models: nn.Module, + ) -> None: + """Inits :class:`ImageDomainMRIViT3DEngine`. + + Parameters + ---------- + cfg: BaseConfig + Configuration file. + model: nn.Module + Model. + device: str + Device. Can be "cuda:{idx}" or "cpu". + forward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional + The forward operator. Default: None. + backward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional + The backward operator. Default: None. + mixed_precision: bool + Use mixed precision. Default: False. + **models: nn.Module + Additional models. 
+ """ + super().__init__( + cfg, + model, + device, + forward_operator=forward_operator, + backward_operator=backward_operator, + mixed_precision=mixed_precision, + **models, + ) + + self._spatial_dims = (3, 4) + + +class KSpaceDomainMRIViTEngine(MRIModelEngine): + """MRI ViT Model Engine for K-Space Domain. + + Parameters + ---------- + cfg: BaseConfig + Configuration file. + model: nn.Module + Model. + device: str + Device. Can be "cuda:{idx}" or "cpu". + forward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional + The forward operator. Default: None. + backward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional + The backward operator. Default: None. + mixed_precision: bool + Use mixed precision. Default: False. + **models: nn.Module + Additional models. + """ + + def __init__( + self, + cfg: BaseConfig, + model: nn.Module, + device: str, + forward_operator: Optional[Callable[[tuple[Any, ...]], torch.Tensor]] = None, + backward_operator: Optional[Callable[[tuple[Any, ...]], torch.Tensor]] = None, + mixed_precision: bool = False, + **models: nn.Module, + ) -> None: + """Inits :class:`KSpaceDomainMRIViTEngine`. + + Parameters + ---------- + cfg: BaseConfig + Configuration file. + model: nn.Module + Model. + device: str + Device. Can be "cuda:{idx}" or "cpu". + forward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional + The forward operator. Default: None. + backward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional + The backward operator. Default: None. + mixed_precision: bool + Use mixed precision. Default: False. + **models: nn.Module + Additional models. + """ + super().__init__( + cfg, + model, + device, + forward_operator=forward_operator, + backward_operator=backward_operator, + mixed_precision=mixed_precision, + **models, + ) + + def forward_function(self, data: dict[str, Any]) -> tuple[torch.Tensor, torch.Tensor]: + """Forward function for :class:`KSpaceDomainMRIViTEngine`. + + Parameters + ---------- + data : dict[str, Any] + Input data. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor] + Output image and output k-space. + """ + data["sensitivity_map"] = self.compute_sensitivity_map(data["sensitivity_map"]) + + output_image = self.model( + masked_kspace=data["masked_kspace"], + sensitivity_map=data["sensitivity_map"], + sampling_mask=data["sampling_mask"], + ) # shape (batch, slice/time, height, width, complex[=2]) + + output_kspace = data["masked_kspace"] + T.apply_mask( + T.apply_padding( + self.forward_operator( + T.expand_operator(output_image, data["sensitivity_map"], dim=self._coil_dim), + dim=self._spatial_dims, + ), + padding=data.get("padding", None), + ), + ~data["sampling_mask"], + return_mask=False, + ) + + return output_image, output_kspace + + +class KSpaceDomainMRIViT2DEngine(KSpaceDomainMRIViTEngine): + """MRI ViT Model Engine for K-Space Domain 2D. + + Parameters + ---------- + cfg: BaseConfig + Configuration file. + model: nn.Module + Model. + device: str + Device. Can be "cuda:{idx}" or "cpu". + forward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional + The forward operator. Default: None. + backward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional + The backward operator. Default: None. + mixed_precision: bool + Use mixed precision. Default: False. + **models: nn.Module + Additional models. 
+ """ + + def __init__( + self, + cfg: BaseConfig, + model: nn.Module, + device: str, + forward_operator: Optional[Callable[[tuple[Any, ...]], torch.Tensor]] = None, + backward_operator: Optional[Callable[[tuple[Any, ...]], torch.Tensor]] = None, + mixed_precision: bool = False, + **models: nn.Module, + ) -> None: + """Inits :class:`KSpaceDomainMRIViT2DEngine`. + + Parameters + ---------- + cfg: BaseConfig + Configuration file. + model: nn.Module + Model. + device: str + Device. Can be "cuda:{idx}" or "cpu". + forward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional + The forward operator. Default: None. + backward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional + The backward operator. Default: None. + mixed_precision: bool + Use mixed precision. Default: False. + **models: nn.Module + Additional models. + """ + super().__init__( + cfg, + model, + device, + forward_operator=forward_operator, + backward_operator=backward_operator, + mixed_precision=mixed_precision, + **models, + ) + + self._spatial_dims = (2, 3) + + +class KSpaceDomainMRIViT3DEngine(KSpaceDomainMRIViTEngine): + """MRI ViT Model Engine for K-Space Domain 3D. + + Parameters + ---------- + cfg: BaseConfig + Configuration file. + model: nn.Module + Model. + device: str + Device. Can be "cuda:{idx}" or "cpu". + forward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional + The forward operator. Default: None. + backward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional + The backward operator. Default: None. + mixed_precision: bool + Use mixed precision. Default: False. + **models: nn.Module + Additional models. + """ + + def __init__( + self, + cfg: BaseConfig, + model: nn.Module, + device: str, + forward_operator: Optional[Callable[[tuple[Any, ...]], torch.Tensor]] = None, + backward_operator: Optional[Callable[[tuple[Any, ...]], torch.Tensor]] = None, + mixed_precision: bool = False, + **models: nn.Module, + ) -> None: + """Inits :class:`KSpaceDomainMRIViT3DEngine`. + + Parameters + ---------- + cfg: BaseConfig + Configuration file. + model: nn.Module + Model. + device: str + Device. Can be "cuda:{idx}" or "cpu". + forward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional + The forward operator. Default: None. + backward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional + The backward operator. Default: None. + mixed_precision: bool + Use mixed precision. Default: False. + **models: nn.Module + Additional models. + """ + super().__init__( + cfg, + model, + device, + forward_operator=forward_operator, + backward_operator=backward_operator, + mixed_precision=mixed_precision, + **models, + ) + + self._spatial_dims = (3, 4) diff --git a/direct/nn/transformers/uformer.py b/direct/nn/transformers/uformer.py new file mode 100644 index 00000000..650bbffe --- /dev/null +++ b/direct/nn/transformers/uformer.py @@ -0,0 +1,2037 @@ +# Copyright (c) DIRECT Contributors + +"""U-Former model [1]_ implementation. + +Adapted from [2]_. + +References +---------- +.. [1] Wang, Zhendong, et al. "Uformer: A general u-shaped transformer for image restoration." Proceedings of the + IEEE/CVF conference on computer vision and pattern recognition. 2022. +.. 
[2] https://github.com/ZhendongWang6/Uformer + +""" + +from __future__ import annotations + +import math +from typing import List, Optional + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from torch import nn +from torch.nn.init import trunc_normal_ + +from direct.nn.transformers.utils import DropoutPath, init_weights, norm, pad_to_square, unnorm, unpad_to_original +from direct.types import DirectEnum + +__all__ = ["AttentionTokenProjectionType", "LeWinTransformerMLPTokenType", "UFormer", "UFormerModel"] + + +class ECALayer1d(nn.Module): + """Efficient Channel Attention (ECA) module for 1D data. + + Parameters + ---------- + channel : int + Number of channels of the input feature map. + k_size : int + Adaptive selection of kernel size. Default: 3. + """ + + def __init__(self, channel: int, k_size: int = 3) -> None: + """Inits :class:`ECALayer1d`. + + Parameters + ---------- + channel : int + Number of channels of the input feature map. + k_size : int + Adaptive selection of kernel size. Default: 3. + """ + super().__init__() + self.avg_pool = nn.AdaptiveAvgPool1d(1) + self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False) + self.sigmoid = nn.Sigmoid() + self.channel = channel + self.k_size = k_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Computes the output of the ECA layer. + + Parameters + ---------- + x : torch.Tensor + Input feature map. + + Returns + ------- + y : torch.Tensor + Output of the ECA layer. + """ + # feature descriptor on the global spatial information + y = self.avg_pool(x.transpose(-1, -2)) + + # Two different branches of ECA module + y = self.conv(y.transpose(-1, -2)) + + # Multi-scale information fusion + y = self.sigmoid(y) + + return x * y.expand_as(x) + + +class SepConv2d(torch.nn.Module): + """A 2D Separable Convolutional layer. + + Applies a depthwise convolution followed by a pointwise convolution. + + Parameters + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + kernel_size : int or tuple of ints + Size of the convolution kernel. + stride : int or tuple of ints + Stride of the convolution. Default: 1. + padding : int or tuple of ints + Padding added to all four sides of the input. Default: 0. + dilation : int or tuple of ints + Spacing between kernel elements. Default: 1. + act_layer : torch.nn.Module + Activation layer applied after depthwise convolution. Default: nn.ReLU. + bias : bool + Whether to include a bias term. Default: False. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int | tuple[int, int], + stride: int | tuple[int, int] = 1, + padding: int | tuple[int, int] = 0, + dilation: int | tuple[int, int] = 1, + act_layer: nn.Module = nn.ReLU, + bias: bool = False, + ) -> None: + """Inits :class:`SepConv2d`. + + Parameters + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + kernel_size : int or tuple of ints + Size of the convolution kernel. + stride : int or tuple of ints + Stride of the convolution. Default: 1. + padding : int or tuple of ints + Padding added to all four sides of the input. Default: 0. + dilation : int or tuple of ints + Spacing between kernel elements. Default: 1. + act_layer : torch.nn.Module + Activation layer applied after depthwise convolution. Default: nn.ReLU. + bias : bool + Whether to include a bias term. Default: False. 
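A usage sketch for the ECALayer1d defined above: the input is in token format (batch, tokens, channels), and the layer rescales each channel by a gate computed from a global average over the tokens.

import torch
from direct.nn.transformers.uformer import ECALayer1d

eca = ECALayer1d(channel=64, k_size=3)
tokens = torch.randn(2, 196, 64)     # (batch, tokens, channels)
gated = eca(tokens)                  # same shape as the input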
+ """ + super().__init__() + self.depthwise = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=in_channels, + bias=bias, + ) + self.pointwise = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias) + self.act_layer = act_layer() if act_layer is not None else nn.Identity() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass of :class:`SepConv2d`. + + Parameters + ---------- + x : torch.Tensor + Input tensor. + + Returns + ------- + torch.Tensor + Output tensor after applying depthwise and pointwise convolutions with activation. + """ + x = self.depthwise(x) + x = self.act_layer(x) + x = self.pointwise(x) + return x + + +class ConvProjectionModule(nn.Module): + """Convolutional projection layer used in the window attention mechanism. + + The projection layer consists of three convolutional layers for queries, keys, and values. + + Parameters + ---------- + dim : int + Number of channels in the input tensor. + heads : int + Number of heads in multi-head attention. Default: 8. + dim_head : int + Dimension of each head. Default: 64. + kernel_size : int + Size of convolutional kernel. Default: 3. + q_stride : int + Stride of the convolutional kernel for queries. Default: 1. + k_stride : int + Stride of the convolutional kernel for keys. Default: 1. + v_stride : int + Stride of the convolutional kernel for values. Default: 1. + bias : bool + Whether to include a bias term. Default: True. + """ + + def __init__( + self, + dim: int, + heads: int = 8, + dim_head: int = 64, + kernel_size: int = 3, + q_stride: int = 1, + k_stride: int = 1, + v_stride: int = 1, + bias: bool = True, + ): + """Inits :class:`ConvProjectionModule`. + + Parameters + ---------- + dim : int + Number of channels in the input tensor. + heads : int + Number of heads in multi-head attention. Default: 8. + dim_head : int + Dimension of each head. Default: 64. + kernel_size : int + Size of convolutional kernel. Default: 3. + q_stride : int + Stride of the convolutional kernel for queries. Default: 1. + k_stride : int + Stride of the convolutional kernel for keys. Default: 1. + v_stride : int + Stride of the convolutional kernel for values. Default: 1. + bias : bool + Whether to include a bias term. Default: True. + """ + super().__init__() + + inner_dim = dim_head * heads + self.heads = heads + pad = (kernel_size - q_stride) // 2 + self.to_q = SepConv2d( + in_channels=dim, out_channels=inner_dim, kernel_size=kernel_size, stride=q_stride, padding=pad, bias=bias + ) + self.to_k = SepConv2d( + in_channels=dim, out_channels=inner_dim, kernel_size=kernel_size, stride=k_stride, padding=pad, bias=bias + ) + self.to_v = SepConv2d( + in_channels=dim, out_channels=inner_dim, kernel_size=kernel_size, stride=v_stride, padding=pad, bias=bias + ) + + def forward( + self, x: torch.Tensor, attn_kv: Optional[torch.Tensor] = None + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Forward pass of :class:`ConvProjectionModule`. + + Parameters + ---------- + x : torch.Tensor + Input tensor. + attn_kv : torch.Tensor, optional + Attention key/value tensor. Default None. + + Returns + ------- + q : torch.Tensor + Query tensor. + k : torch.Tensor + Key tensor. + v : torch.Tensor + Value tensor. 
+ """ + _, n, _, h = *x.shape, self.heads + f = int(math.sqrt(n)) + w = int(math.sqrt(n)) + + attn_kv = x if attn_kv is None else attn_kv + x = rearrange(x, "b (f w) c -> b c f w", f=f, w=w) + attn_kv = rearrange(attn_kv, "b (f w) c -> b c f w", f=f, w=w) + q = self.to_q(x) + q = rearrange(q, "b (h d) f w -> b h (f w) d", h=h) + + k = self.to_k(attn_kv) + v = self.to_v(attn_kv) + k = rearrange(k, "b (h d) f w -> b h (f w) d", h=h) + v = rearrange(v, "b (h d) f w -> b h (f w) d", h=h) + return q, k, v + + +class LinearProjectionModule(nn.Module): + """Linear projection layer used in the window attention mechanism. + + Parameters + ---------- + dim : int + The input feature dimension. + heads : int + The number of heads in the multi-head attention mechanism. Default: 8. + dim_head : int, optional + The feature dimension of each head. Default: 64. + bias : bool, optional + Whether to use bias in the linear projections. Default: True. + """ + + def __init__(self, dim: int, heads: int = 8, dim_head: int = 64, bias: bool = True) -> None: + """Inits :class:LinearProjectionModule`. + + Parameters + ---------- + dim : int + The input feature dimension. + heads : int + The number of heads in the multi-head attention mechanism. Default: 8. + dim_head : int, optional + The feature dimension of each head. Default: 64. + bias : bool, optional + Whether to use bias in the linear projections. Default: True. + """ + super().__init__() + inner_dim = dim_head * heads + self.heads = heads + self.to_q = nn.Linear(dim, inner_dim, bias=bias) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=bias) + self.dim = dim + self.inner_dim = inner_dim + + def forward( + self, x: torch.Tensor, attn_kv: Optional[torch.Tensor] = None + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Performs forward pass of :class:`LinearProjectionModule`. + + Parameters + ---------- + x : torch.Tensor of shape (batch_size, seq_length, dim) + The input tensor. + attn_kv : torch.Tensor of shape (batch_size, seq_length, dim), optional + The tensor to be used for computing the attention scores. If None, the input tensor is used. Default: None. + + Returns + ------- + q : torch.Tensor of shape (batch_size, seq_length, heads, dim_head) + The tensor resulting from the linear projection of x used for computing the queries. + k : torch.Tensor of shape (batch_size, seq_length, heads, dim_head) + The tensor resulting from the linear projection of attn_kv used for computing the keys. + v : torch.Tensor of shape (batch_size, seq_length, heads, dim_head) + The tensor resulting from the linear projection of attn_kv used for computing the values. + + """ + B_, N, C = x.shape + if attn_kv is not None: + attn_kv = attn_kv.unsqueeze(0).repeat(B_, 1, 1) + else: + attn_kv = x + N_kv = attn_kv.size(1) + q = self.to_q(x).reshape(B_, N, 1, self.heads, C // self.heads).permute(2, 0, 3, 1, 4) + kv = self.to_kv(attn_kv).reshape(B_, N_kv, 2, self.heads, C // self.heads).permute(2, 0, 3, 1, 4) + q = q[0] + k, v = kv[0], kv[1] + return q, k, v + + +class AttentionTokenProjectionType(DirectEnum): + CONV = "conv" + LINEAR = "linear" + + +class WindowAttentionModule(nn.Module): + """A window-based multi-head self-attention module. + + Parameters + ---------- + dim : int + Input feature dimension. + win_size : tuple[int, int] + The window size (height and width). + num_heads : int + Number of heads for multi-head self-attention. + token_projection : AttentionTokenProjectionType + Type of token projection. 
Must be one of AttentionTokenProjectionType.LINEAR + or AttentionTokenProjectionType.CONV. Default: AttentionTokenProjectionType.LINEAR. + qkv_bias : bool + Whether to use bias in the linear projection layer for queries, keys, and values. + qk_scale : float + Scale factor for query and key. + attn_drop : float + Dropout rate for attention weights. + proj_drop : float + Dropout rate for the output of the last linear projection layer. + """ + + def __init__( + self, + dim: int, + win_size: tuple[int, int], + num_heads: int, + token_projection: AttentionTokenProjectionType = AttentionTokenProjectionType.LINEAR, + qkv_bias=True, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + ) -> None: + """Inits :class:`WindowAttentionModule`. + + Parameters + ---------- + dim : int + Input feature dimension. + win_size : tuple[int, int] + The window size (height and width). + num_heads : int + Number of heads for multi-head self-attention. + token_projection : AttentionTokenProjectionType + Type of token projection. Must be one of AttentionTokenProjectionType.LINEAR + or AttentionTokenProjectionType.CONV. Default: AttentionTokenProjectionType.LINEAR. + qkv_bias : bool + Whether to use bias in the linear projection layer for queries, keys, and values. + qk_scale : float + Scale factor for query and key. + attn_drop : float + Dropout rate for attention weights. + proj_drop : float + Dropout rate for the output of the last linear projection layer. + """ + # pylint: disable=too-many-locals + super().__init__() + self.dim = dim + self.win_size = win_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * win_size[0] - 1) * (2 * win_size[1] - 1), num_heads) + ) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.win_size[0]) # [0,...,Wh-1] + coords_w = torch.arange(self.win_size[1]) # [0,...,Ww-1] + coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij")) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.win_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.win_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.win_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + trunc_normal_(self.relative_position_bias_table, std=0.02) + + if token_projection == AttentionTokenProjectionType.CONV: + self.qkv = ConvProjectionModule(dim, num_heads, dim // num_heads, bias=qkv_bias) + else: + self.qkv = LinearProjectionModule(dim, num_heads, dim // num_heads, bias=qkv_bias) + + self.token_projection = token_projection + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward( + self, x: torch.Tensor, attn_kv: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Performs forward pass of :class:`WindowAttentionModule`. 
+ + Parameters + ---------- + x : torch.Tensor + A tensor of shape `(B, N, C)` representing the input features, where `B` is the batch size, `N` is the + sequence length, and `C` is the input feature dimension. + attn_kv : torch.Tensor, optional + An optional tensor of shape `(B, N, C)` representing the key-value pairs used for attention computation. + If `None`, the key-value pairs are computed from `x` itself. Default: None. + mask : torch.Tensor, optional + An optional tensor of shape representing the binary mask for the input sequence. + If `None`, no masking is applied. Default: None. + + Returns + ------- + torch.Tensor + A tensor of shape `(B, N, C)` representing the output features after attention computation. + """ + B_, N, C = x.shape + q, k, v = self.qkv(x, attn_kv) + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.win_size[0] * self.win_size[1], self.win_size[0] * self.win_size[1], -1 + ) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + ratio = attn.size(-1) // relative_position_bias.size(-1) + relative_position_bias = repeat(relative_position_bias, "nH l c -> nH l (c d)", d=ratio) + + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + mask = repeat(mask, "nW m n -> nW m (n d)", d=ratio) + attn = attn.view(B_ // nW, nW, self.num_heads, N, N * ratio) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N * ratio) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, win_size={self.win_size}, num_heads={self.num_heads}" + + +class AttentionModule(nn.Module): + """Self-attention module. + + Parameters + ---------- + dim : int + The input feature dimension. + num_heads : int + The number of attention heads. + qkv_bias : bool + Whether to include biases in the query, key, and value projections. Default: True. + qk_scale : float, optional + Scaling factor for the query and key projections. Default: None. + attn_drop : float + Dropout probability for the attention weights. Default: 0.0. + proj_drop : float + Dropout probability for the output of the attention module. Default: 0.0. + """ + + def __init__( + self, + dim: int, + num_heads: int, + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + """Inits :class:`AttentionModule`. + + Parameters + ---------- + dim : int + The input feature dimension. + num_heads : int + The number of attention heads. + qkv_bias : bool + Whether to include biases in the query, key, and value projections. Default: True. + qk_scale : float, optional + Scaling factor for the query and key projections. Default: None. + attn_drop : float + Dropout probability for the attention weights. Default: 0.0. + proj_drop : float + Dropout probability for the output of the attention module. Default: 0.0. 
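# Editor's note: illustrative sketch, not part of this diff. WindowAttentionModule runs multi-head
# self-attention inside each window and adds a learned relative position bias per head; the token
# count N must equal win_size[0] * win_size[1], and the leading dimension is num_windows * batch,
# as produced by window_partition further below.
import torch
from direct.nn.transformers.uformer import WindowAttentionModule

attn = WindowAttentionModule(dim=32, win_size=(8, 8), num_heads=4)
windows = torch.rand(6, 8 * 8, 32)   # (num_windows * B, N, C)
out = attn(windows)                  # mask=None: plain (non-shifted) W-MSA
assert out.shape == windows.shape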
+ """ + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.qkv = LinearProjectionModule(dim, num_heads, dim // num_heads, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward( + self, x: torch.Tensor, attn_kv: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Performs the forward pass of :class:`AttentionModule`. + + Parameters + ---------- + x : torch.Tensor + The input tensor. + attn_kv : torch.Tensor, optional + The attention key/value tensor. + mask : torch.Tensor, optional + + Returns + ------- + torch.Tensor + """ + B_, N, C = x.shape + q, k, v = self.qkv(x, attn_kv) + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}" + + +class MLP(nn.Module): + """Multi-layer perceptron with optional dropout regularization. + + Parameters + ---------- + in_features : int + Number of input features. + hidden_features : int, optional + Number of output features in the hidden layer. If not specified, `in_features` is used. + out_features : int, optional + Number of output features. If not specified, `in_features` is used. + act_layer : nn.Module + Activation layer. Default: GeLU. + drop : float + Dropout probability. Default: 0.0. + """ + + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: nn.Module = nn.GELU, + drop: float = 0.0, + ) -> None: + """Inits :class:`MLP`. + + Parameters + ---------- + in_features : int + Number of input features. + hidden_features : int, optional + Number of output features in the hidden layer. If not specified, `in_features` is used. + out_features : int, optional + Number of output features. If not specified, `in_features` is used. + act_layer : nn.Module + Activation layer. Default: GeLU. + drop : float + Dropout probability. Default: 0.0. + """ + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + self.in_features = in_features + self.hidden_features = hidden_features + self.out_features = out_features + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass of the :class:`MLP`. + + Parameters + ---------- + x : torch.Tensor + Input tensor. + + Returns + ------- + output : torch.Tensor + """ + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class LeFF(nn.Module): + """Locally-enhanced Feed-Forward Network module. + + Parameters + ---------- + dim : int + Dimension of the input and output features. Default: 32. + hidden_dim : int + Dimension of the hidden features. Default: 128. 
+ act_layer : nn.Module + Activation layer to apply after the first linear layer and the depthwise convolution. Default: GELU. + use_eca : bool + If True, adds a 1D ECA layer after the second linear layer. Default: False. + """ + + def __init__( + self, dim: int = 32, hidden_dim: int = 128, act_layer: nn.Module = nn.GELU, use_eca: bool = False + ) -> None: + """Inits :class:`LeFF`. + + Parameters + ---------- + dim : int + Dimension of the input and output features. Default: 32. + hidden_dim : int + Dimension of the hidden features. Default: 128. + act_layer : nn.Module + Activation layer to apply after the first linear layer and the depthwise convolution. Default: GELU. + use_eca : bool + If True, adds a 1D ECA layer after the second linear layer. Default: False. + """ + super().__init__() + self.linear1 = nn.Sequential(nn.Linear(dim, hidden_dim), act_layer()) + self.dwconv = nn.Sequential( + nn.Conv2d(hidden_dim, hidden_dim, groups=hidden_dim, kernel_size=3, stride=1, padding=1), act_layer() + ) + self.linear2 = nn.Sequential(nn.Linear(hidden_dim, dim)) + self.dim = dim + self.hidden_dim = hidden_dim + self.eca = ECALayer1d(dim) if use_eca else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Performs forward pass of :class:`LeFF`. + + Parameters + ---------- + x : torch.Tensor + Input tensor. + + Returns + ------- + torch.Tensor + """ + # bs x hw x c + _, hw, _ = x.size() + hh = int(math.sqrt(hw)) + + x = self.linear1(x) + + # spatial restore + x = rearrange(x, " b (h w) (c) -> b c h w ", h=hh, w=hh) + # bs,hidden_dim,32x32 + + x = self.dwconv(x) + + # flatten + x = rearrange(x, " b c h w -> b (h w) c", h=hh, w=hh) + + x = self.linear2(x) + x = self.eca(x) + + return x + + +def window_partition(x: torch.Tensor, win_size: int, dilation_rate: int = 1) -> torch.Tensor: + """Partition the input tensor into windows of specified size. + + Parameters + ---------- + x : torch.Tensor + The input tensor to be partitioned into windows. + win_size : int + The size of the square windows to partition the tensor into. + dilation_rate : int + The dilation rate for convolution. Default: 1. + + Returns + ------- + windows : torch.Tensor + The tensor representing windows partitioned from input tensor. + """ + B, H, W, C = x.shape + if dilation_rate != 1: + x = x.permute(0, 3, 1, 2) # B, C, H, W + x = F.unfold( + x, kernel_size=win_size, dilation=dilation_rate, padding=4 * (dilation_rate - 1), stride=win_size + ) # B, C*Wh*Ww, H/Wh*W/Ww + windows = x.permute(0, 2, 1).contiguous().view(-1, C, win_size, win_size) # B' ,C ,Wh ,Ww + windows = windows.permute(0, 2, 3, 1).contiguous() # B' ,Wh ,Ww ,C + else: + x = x.view(B, H // win_size, win_size, W // win_size, win_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, win_size, win_size, C) # B' ,Wh ,Ww ,C + return windows + + +def window_reverse(windows: torch.Tensor, win_size: int, H: int, W: int, dilation_rate: int = 1) -> torch.Tensor: + """Rearrange the partitioned tensor back to the original tensor. + + Parameters + ---------- + windows : torch.Tensor + The tensor representing windows partitioned from input tensor. + win_size : int + The size of the square windows used to partition the tensor. + H : int + The height of the original tensor before partitioning. + W : int + The width of the original tensor before partitioning. + dilation_rate : int + The dilation rate for convolution. Default 1. + + Returns + ------- + x: torch.Tensor + The original tensor rearranged from the partitioned tensor. 
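# Editor's note: illustrative sketch, not part of this diff. LeFF replaces the plain token MLP with
# Linear -> depthwise 3x3 convolution -> Linear, so the token sequence must come from a square map
# (hw a perfect square) in order to be reshaped to (B, C, h, h) for the convolution.
import torch
from direct.nn.transformers.uformer import LeFF

leff = LeFF(dim=32, hidden_dim=128)   # hidden_dim = 4 * dim matches the default mlp_ratio
tokens = torch.rand(2, 16 * 16, 32)   # tokens of a 16x16 map
assert leff(tokens).shape == tokens.shape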
+ + """ + # B' ,Wh ,Ww ,C + B = int(windows.shape[0] / (H * W / win_size / win_size)) + x = windows.view(B, H // win_size, W // win_size, win_size, win_size, -1) + if dilation_rate != 1: + x = windows.permute(0, 5, 3, 4, 1, 2).contiguous() # B, C*Wh*Ww, H/Wh*W/Ww + x = F.fold( + x, (H, W), kernel_size=win_size, dilation=dilation_rate, padding=4 * (dilation_rate - 1), stride=win_size + ) + else: + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class DownSampleBlock(nn.Module): + """Convolution based downsample block. + + Parameters + ---------- + in_channels : int + Number of channels in the input tensor. + out_channels : int + Number of channels produced by the convolution. + """ + + def __init__(self, in_channels: int, out_channels: int) -> None: + """Inits :class:`DownSampleBlock`. + + Parameters + ---------- + in_channels : int + Number of channels in the input tensor. + out_channels : int + Number of channels produced by the convolution. + """ + super().__init__() + self.conv = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1), + ) + self.in_channels = in_channels + self.out_channels = out_channels + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Performs forward pass of :class:`DownSampleBlock`. + + Parameters + ---------- + x : torch.Tensor + Input tensor. + + Returns + ------- + torch.Tensor + Downsampled output. + """ + B, L, C = x.shape + H = int(math.sqrt(L)) + W = int(math.sqrt(L)) + x = x.transpose(1, 2).contiguous().view(B, C, H, W) + out = self.conv(x).flatten(2).transpose(1, 2).contiguous() # B H*W C + return out + + +class UpSampleBlock(nn.Module): + """Convolution based upsample block. + + Parameters + ---------- + in_channels : int + Number of channels in the input tensor. + out_channels : int + Number of channels produced by the convolution. + """ + + def __init__(self, in_channels: int, out_channels: int) -> None: + """Inits :class:`UpSampleBlock`. + + Parameters + ---------- + in_channels : int + Number of channels in the input tensor. + out_channels : int + Number of channels produced by the convolution. + """ + super().__init__() + self.deconv = nn.Sequential( + nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2), + ) + self.in_channels = in_channels + self.out_channels = out_channels + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Performs forward pass of :class:`UpSampleBlock`. + + Parameters + ---------- + x : torch.Tensor + Input tensor. + + Returns + ------- + torch.Tensor + Upsampled output. + """ + B, L, C = x.shape + H = int(math.sqrt(L)) + W = int(math.sqrt(L)) + x = x.transpose(1, 2).contiguous().view(B, C, H, W) + out = self.deconv(x).flatten(2).transpose(1, 2).contiguous() # B H*W C + return out + + +class InputProjection(nn.Module): + """Input convolutional projection used in the U-Former model. + + Parameters + ---------- + in_channels : int + Number of input channels. Default: 3. + out_channels : int + Number of output channels after the projection. Default: 64. + kernel_size : int or tuple of ints + Convolution kernel size. Default: 3. + stride : int or tuple of ints + Stride of the convolution. Default: 1. + norm_layer : nn.Module, optional + Normalization layer to apply after the projection. Default: None. + act_layer : nn.Module + Activation layer to apply after the projection. Default: nn.LeakyReLU. 
+ """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 64, + kernel_size: int | tuple[int, int] = 3, + stride: int | tuple[int, int] = 1, + norm_layer: Optional[nn.Module] = None, + act_layer: nn.Module = nn.LeakyReLU, + ) -> None: + """Inits :class:`InputProjection`. + + Parameters + ---------- + in_channels : int + Number of input channels. Default: 3. + out_channels : int + Number of output channels after the projection. Default: 64. + kernel_size : int or tuple of ints + Convolution kernel size. Default: 3. + stride : int or tuple of ints + Stride of the convolution. Default: 1. + norm_layer : nn.Module, optional + Normalization layer to apply after the projection. Default: None. + act_layer : nn.Module + Activation layer to apply after the projection. Default: nn.LeakyReLU. + """ + super().__init__() + self.proj = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=kernel_size // 2), + act_layer(inplace=True), + ) + if norm_layer is not None: + self.norm = norm_layer(out_channels) + else: + self.norm = None + self.in_channels = in_channels + self.out_channels = out_channels + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Performs forward pass of :class:`InputProjection`. + + Parameters + ---------- + x : torch.Tensor + + Returns + ------- + torch.Tensor + """ + x = self.proj(x).flatten(2).transpose(1, 2).contiguous() # B H*W C + if self.norm is not None: + x = self.norm(x) + return x + + +class OutputProjection(nn.Module): + """Output convolutional projection used in the U-Former model. + + Parameters + ---------- + in_channels : int + Number of input channels. Default: 64. + out_channels : int + Number of output channels after the projection. Default: 3. + kernel_size : int or tuple of ints + Convolution kernel size. Default: 3. + stride : int or tuple of ints + Stride of the convolution. Default: 1. + norm_layer : nn.Module, optional + Normalization layer to apply after the projection. Default: None. + act_layer : nn.Module, optional + Activation layer to apply after the projection. Default: None. + """ + + def __init__( + self, + in_channels: int = 64, + out_channels: int = 3, + kernel_size: int | tuple[int, int] = 3, + stride: int | tuple[int, int] = 1, + norm_layer: Optional[nn.Module] = None, + act_layer: Optional[nn.Module] = None, + ): + """Inits :class:`InputProjection`. + + Parameters + ---------- + in_channels : int + Number of input channels. Default: 64. + out_channels : int + Number of output channels after the projection. Default: 3. + kernel_size : int or tuple of ints + Convolution kernel size. Default: 3. + stride : int or tuple of ints + Stride of the convolution. Default: 1. + norm_layer : nn.Module, optional + Normalization layer to apply after the projection. Default: None. + act_layer : nn.Module, optional + Activation layer to apply after the projection. Default: None. + """ + super().__init__() + self.proj = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=kernel_size // 2), + ) + if act_layer is not None: + self.proj.add_module("activation", act_layer(inplace=True)) + if norm_layer is not None: + self.norm = norm_layer(out_channels) + else: + self.norm = None + self.in_channels = in_channels + self.out_channels = out_channels + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Performs forward pass of :class:`OutputProjection`. 
+ + Parameters + ---------- + x : torch.Tensor + + Returns + ------- + torch.Tensor + """ + B, L, C = x.shape + H = int(math.sqrt(L)) + W = int(math.sqrt(L)) + x = x.transpose(1, 2).view(B, C, H, W) + x = self.proj(x) + if self.norm is not None: + x = self.norm(x) + return x + + +class LeWinTransformerMLPTokenType(DirectEnum): + MLP = "mlp" + FFN = "ffn" + LEFF = "leff" + + +class LeWinTransformerBlock(nn.Module): + """Applies a window-based multi-head self-attention and MLP or LeFF on the input tensor. + + Parameters + ---------- + dim : int + Number of input channels. + input_resolution : tuple of ints + Input resolution. + num_heads : int + Number of attention heads. + win_size : int + Window size for the attention mechanism. Default: 8. + shift_size : int + The number of pixels to shift the window. Default: 0. + mlp_ratio : float + Ratio of the hidden dimension size to the embedding dimension size in the MLP layers. Default: 4.0. + qkv_bias : bool + Whether to use bias in the query, key, and value projections of the attention mechanism. Default: True. + qk_scale : float, optional + Scale factor for the query and key projection vectors. + If set to None, will use the default value of :math`1 / \sqrt(dim)`. Default: None. + drop : float + Dropout rate for the token-level dropout layer. Default: 0.0. + attn_drop : float + Dropout rate for the attention score matrix. Default: 0.0. + drop_path : float + Dropout rate for the stochastic depth regularization. Default: 0.0. + act_layer : nn.Module + The activation function to use. Default: nn.GELU. + norm_layer : nn.Module + The normalization layer to use. Default: nn.LayerNorm. + token_projection : AttentionTokenProjectionType + Type of token projection. Must be one of AttentionTokenProjectionType.LINEAR + or AttentionTokenProjectionType.CONV. Default: AttentionTokenProjectionType.LINEAR. + token_mlp : LeWinTransformerMLPTokenType + Type of token-level MLP. Must be one of LeWinTransformerMLPTokenType.LEFF or + LeWinTransformerMLPTokenType.MLP. Default: LeWinTransformerMLPTokenType.LEFF. + modulator : bool + Whether to use a modulator in the attention mechanism. Default: False. + cross_modulator : bool + Whether to use cross-modulation in the attention mechanism. Default: False. + """ + + def __init__( + self, + dim: int, + input_resolution: tuple[int, int], + num_heads: int, + win_size: int = 8, + shift_size: int = 0, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + drop: float = 0.0, + attn_drop: float = 0.0, + drop_path: float = 0.0, + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = nn.LayerNorm, + token_projection: AttentionTokenProjectionType = AttentionTokenProjectionType.LINEAR, + token_mlp: LeWinTransformerMLPTokenType = LeWinTransformerMLPTokenType.LEFF, + modulator: bool = False, + cross_modulator: bool = False, + ) -> None: + r"""Inits :class:`LeWinTransformerBlock`. + + Parameters + ---------- + dim : int + Number of input channels. + input_resolution : tuple of ints + Input resolution. + num_heads : int + Number of attention heads. + win_size : int + Window size for the attention mechanism. Default: 8. + shift_size : int + The number of pixels to shift the window. Default: 0. + mlp_ratio : float + Ratio of the hidden dimension size to the embedding dimension size in the MLP layers. Default: 4.0. + qkv_bias : bool + Whether to use bias in the query, key, and value projections of the attention mechanism. Default: True. 
+ qk_scale : float, optional + Scale factor for the query and key projection vectors. + If set to None, will use the default value of :math`1 / \sqrt(dim)`. Default: None. + drop : float + Dropout rate for the token-level dropout layer. Default: 0.0. + attn_drop : float + Dropout rate for the attention score matrix. Default: 0.0. + drop_path : float + Dropout rate for the stochastic depth regularization. Default: 0.0. + act_layer : nn.Module + The activation function to use. Default: nn.GELU. + norm_layer : nn.Module + The normalization layer to use. Default: nn.LayerNorm. + token_projection : AttentionTokenProjectionType + Type of token projection. Must be one of AttentionTokenProjectionType.LINEAR + or AttentionTokenProjectionType.CONV. Default: AttentionTokenProjectionType.LINEAR. + token_mlp : LeWinTransformerMLPTokenType + Type of token-level MLP. Must be one of LeWinTransformerMLPTokenType.LEFF or + LeWinTransformerMLPTokenType.MLP. Default: LeWinTransformerMLPTokenType.LEFF. + modulator : bool + Whether to use a modulator in the attention mechanism. Default: False. + cross_modulator : bool + Whether to use cross-modulation in the attention mechanism. Default: False. + """ + # pylint: disable=too-many-locals + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.win_size = win_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + self.token_mlp = token_mlp + if min(self.input_resolution) <= self.win_size: + self.shift_size = 0 + self.win_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.win_size, "shift_size must in 0-win_size" + + if modulator: + self.modulator = nn.Embedding(win_size * win_size, dim) # modulator + else: + self.modulator = None + + if cross_modulator: + self.cross_modulator = nn.Embedding(win_size * win_size, dim) # cross_modulator + self.cross_attn = AttentionModule( + dim, + num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.norm_cross = norm_layer(dim) + else: + self.cross_modulator = None + + self.norm1 = norm_layer(dim) + self.attn = WindowAttentionModule( + dim, + win_size=(self.win_size, self.win_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + token_projection=token_projection, + ) + + self.drop_path = DropoutPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + if token_mlp == LeWinTransformerMLPTokenType.MLP: + self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + else: + self.mlp = LeFF(dim, mlp_hidden_dim, act_layer=act_layer) + + def with_pos_embed(self, tensor: torch.Tensor, pos: Optional[torch.Tensor] = None) -> torch.Tensor: + """Add positional embeddings to the input tensor. + + Parameters + ---------- + tensor : torch.Tensor + The input tensor. + pos : torch.Tensor, optional + The positional embeddings to add to the input tensor. Default: None. 
+ + Returns + ------- + torch.Tensor + """ + return tensor if pos is None else tensor + pos + + def extra_repr(self) -> str: + return ( + f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " + f"win_size={self.win_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}, " + f"modulator={self.modulator}" + ) + + def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """Performs the forward pass of :class:`LeWinTransformerBlock`. + + Parameters + ---------- + x : torch.Tensor + The input tensor. + mask : torch.Tensor, optional + The mask tensor indicating which elements should be ignored. Default: None. + + Returns + ------- + torch.Tensor + """ + # pylint: disable=too-many-locals + B, L, C = x.shape + H = int(math.sqrt(L)) + W = int(math.sqrt(L)) + + ## input mask + if mask is not None: + input_mask = F.interpolate(mask, size=(H, W)).permute(0, 2, 3, 1) + input_mask_windows = window_partition(input_mask, self.win_size) # nW, win_size, win_size, 1 + attn_mask = input_mask_windows.view(-1, self.win_size * self.win_size) # nW, win_size*win_size + attn_mask = attn_mask.unsqueeze(2) * attn_mask.unsqueeze(1) # nW, win_size*win_size, win_size*win_size + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + ## shift mask + if self.shift_size > 0: + # calculate attention mask for SW-MSA + shift_mask = torch.zeros((1, H, W, 1)).type_as(x) + h_slices = ( + slice(0, -self.win_size), + slice(-self.win_size, -self.shift_size), + slice(-self.shift_size, None), + ) + w_slices = ( + slice(0, -self.win_size), + slice(-self.win_size, -self.shift_size), + slice(-self.shift_size, None), + ) + cnt = 0 + for h in h_slices: + for w in w_slices: + shift_mask[:, h, w, :] = cnt + cnt += 1 + shift_mask_windows = window_partition(shift_mask, self.win_size) # nW, win_size, win_size, 1 + shift_mask_windows = shift_mask_windows.view(-1, self.win_size * self.win_size) # nW, win_size*win_size + shift_attn_mask = shift_mask_windows.unsqueeze(1) - shift_mask_windows.unsqueeze( + 2 + ) # nW, win_size*win_size, win_size*win_size + shift_attn_mask = shift_attn_mask.masked_fill(shift_attn_mask != 0, float(-100.0)).masked_fill( + shift_attn_mask == 0, float(0.0) + ) + attn_mask = attn_mask + shift_attn_mask if attn_mask is not None else shift_attn_mask + if self.cross_modulator is not None: + shortcut = x + x_cross = self.norm_cross(x) + x_cross = self.cross_attn(x, self.cross_modulator.weight) + x = shortcut + x_cross + shortcut = x + + x = self.norm1(x) + x = x.view(B, H, W, C) + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + # partition windows + x_windows = window_partition(shifted_x, self.win_size) # nW*B, win_size, win_size, C N*C->C + x_windows = x_windows.view(-1, self.win_size * self.win_size, C) # nW*B, win_size*win_size, C + # with_modulator + if self.modulator is not None: + wmsa_in = self.with_pos_embed(x_windows, self.modulator.weight) + else: + wmsa_in = x_windows + + # W-MSA/SW-MSA + attn_windows = self.attn(wmsa_in, mask=attn_mask) # nW*B, win_size*win_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.win_size, self.win_size, C) + shifted_x = window_reverse(attn_windows, self.win_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), 
dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + del attn_mask + return x + + +class BasicUFormerLayer(nn.Module): + """Basic layer of U-Former. + + Parameters + ---------- + dim : int + Number of input channels. + input_resolution : tuple of ints + Input resolution. + num_heads : int + Number of attention heads. + win_size : int + Window size for the attention mechanism. Default: 8. + mlp_ratio : float + Ratio of the hidden dimension size to the embedding dimension size in the MLP layers. Default: 4.0. + qkv_bias : bool + Whether to use bias in the query, key, and value projections of the attention mechanism. Default: True. + qk_scale : float, optional + Scale factor for the query and key projection vectors. + If set to None, will use the default value of :math`1 / \sqrt(dim)`. Default: None. + drop : float + Dropout rate for the token-level dropout layer. Default: 0.0. + attn_drop : float + Dropout rate for the attention score matrix. Default: 0.0. + drop_path : float + Dropout rate for the stochastic depth regularization. Default: 0.0. + norm_layer : nn.Module + The normalization layer to use. Default: nn.LayerNorm. + token_projection : AttentionTokenProjectionType + Type of token projection. Must be one of AttentionTokenProjectionType.LINEAR + or AttentionTokenProjectionType.CONV. Default: AttentionTokenProjectionType.LINEAR. + token_mlp : LeWinTransformerMLPTokenType + Type of token-level MLP. Must be one of LeWinTransformerMLPTokenType.LEFF or + LeWinTransformerMLPTokenType.MLP. Default: LeWinTransformerMLPTokenType.LEFF. + shift_flag : bool + Whether to use shift in the attention sliding windows or not. Default: True. + modulator : bool + Whether to use a modulator in the attention mechanism. Default: False. + cross_modulator : bool + Whether to use cross-modulation in the attention mechanism. Default: False. + """ + + def __init__( + self, + dim: int, + input_resolution: tuple[int, int], + depth: int, + num_heads: int, + win_size: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_scale: Optional[bool] = None, + drop: float = 0.0, + attn_drop: float = 0.0, + drop_path: List[float] | float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, + token_projection: AttentionTokenProjectionType = AttentionTokenProjectionType.LINEAR, + token_mlp: LeWinTransformerMLPTokenType = LeWinTransformerMLPTokenType.MLP, + shift_flag: bool = True, + modulator: bool = False, + cross_modulator: bool = False, + ) -> None: + r"""Inits :class:`BasicUFormerLayer`. + + Parameters + ---------- + dim : int + Number of input channels. + input_resolution : tuple of ints + Input resolution. + num_heads : int + Number of attention heads. + win_size : int + Window size for the attention mechanism. Default: 8. + mlp_ratio : float + Ratio of the hidden dimension size to the embedding dimension size in the MLP layers. Default: 4.0. + qkv_bias : bool + Whether to use bias in the query, key, and value projections of the attention mechanism. Default: True. + qk_scale : float, optional + Scale factor for the query and key projection vectors. + If set to None, will use the default value of :math`1 / \sqrt(dim)`. Default: None. + drop : float + Dropout rate for the token-level dropout layer. Default: 0.0. + attn_drop : float + Dropout rate for the attention score matrix. Default: 0.0. + drop_path : float + Dropout rate for the stochastic depth regularization. Default: 0.0. 
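# Editor's note: illustrative sketch, not part of this diff (see the block defined above).
# LeWinTransformerBlock operates on a flattened square token map: input (B, L, C) with L = H * W
# and H = W = input_resolution[0]; shift_size > 0 selects the shifted-window (SW-MSA) variant that
# alternates with the non-shifted one when shift_flag is enabled.
import torch
from direct.nn.transformers.uformer import LeWinTransformerBlock

block = LeWinTransformerBlock(
    dim=32, input_resolution=(64, 64), num_heads=4, win_size=8, shift_size=4
)
tokens = torch.rand(2, 64 * 64, 32)
assert block(tokens).shape == tokens.shape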
+ norm_layer : nn.Module + The normalization layer to use. Default: nn.LayerNorm. + token_projection : AttentionTokenProjectionType + Type of token projection. Must be one of AttentionTokenProjectionType.LINEAR + or AttentionTokenProjectionType.CONV. Default: AttentionTokenProjectionType.LINEAR. + token_mlp : LeWinTransformerMLPTokenType + Type of token-level MLP. Must be one of LeWinTransformerMLPTokenType.LEFF or + LeWinTransformerMLPTokenType.MLP. Default: LeWinTransformerMLPTokenType.LEFF. + shift_flag : bool + Whether to use shift in the attention sliding windows or not. Default: True. + modulator : bool + Whether to use a modulator in the attention mechanism. Default: False. + cross_modulator : bool + Whether to use cross-modulation in the attention mechanism. Default: False. + """ + # pylint: disable=too-many-locals + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + + # build blocks + self.blocks = nn.ModuleList( + [ + LeWinTransformerBlock( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + win_size=win_size, + shift_size=(0 if (i % 2 == 0) else win_size // 2) if shift_flag else 0, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + token_projection=token_projection, + token_mlp=token_mlp, + modulator=modulator, + cross_modulator=cross_modulator, + ) + for i in range(depth) + ] + ) + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """Performs forward pass of :class:`BasicUFormerLayer`. + + Parameters + ---------- + x : torch.Tensor + mask : torch.Tensor, optional + + Returns + ------- + torch.Tensor + """ + for blk in self.blocks: + x = blk(x, mask) + return x + + +class UFormer(nn.Module): + """U-Former model based on [1]_, code originally implemented in [2]_. + + Parameters + ---------- + patch_size : int + Size of the patch. Default: 256. + in_channels : int + Number of input channels. Default: 2. + out_channels : int, optional + Number of output channels. Default: None. + embedding_dim : int + Size of the feature embedding. Default: 32. + encoder_depths : tuple + Number of layers for each stage of the encoder of the U-former, from top to bottom. Default: (2, 2, 2, 2). + encoder_num_heads : tuple + Number of attention heads for each layer of the encoder of the U-former, from top to bottom. + Default: (1, 2, 4, 8). + bottleneck_depth : int + Default: 16. + bottleneck_num_heads : int + Default: 2. + win_size : int + Window size for the attention mechanism. Default: 8. + mlp_ratio : float + Ratio of the hidden dimension size to the embedding dimension size in the MLP layers. Default: 4.0. + qkv_bias : bool + Whether to use bias in the query, key, and value projections of the attention mechanism. Default: True. + qk_scale : float + Scale factor for the query and key projection vectors. + If set to None, will use the default value of 1 / sqrt(embedding_dim). Default: None. + drop_rate : float + Dropout rate for the token-level dropout layer. Default: 0.0. + attn_drop_rate : float + Dropout rate for the attention score matrix. Default: 0.0. + drop_path_rate : float + Dropout rate for the stochastic depth regularization. Default: 0.1. 
+ patch_norm : bool + Whether to use normalization for the patch embeddings. Default: True. + token_projection : AttentionTokenProjectionType + Type of token projection. Must be one of AttentionTokenProjectionType.LINEAR + or AttentionTokenProjectionType.CONV. Default: AttentionTokenProjectionType.LINEAR. + token_mlp : LeWinTransformerMLPTokenType + Type of token-level MLP. Must be one of LeWinTransformerMLPTokenType.LEFF or + LeWinTransformerMLPTokenType.MLP. Default: LeWinTransformerMLPTokenType.LEFF. + shift_flag : bool + Whether to use shift operation in the local attention mechanism. Default: True. + modulator : bool + Whether to use a modulator in the attention mechanism. Default: False. + cross_modulator : bool + Whether to use cross-modulation in the attention mechanism. Default: False. + **kwargs: Other keyword arguments to pass to the parent constructor. + + References + ---------- + .. [1] Wang, Zhendong, et al. "Uformer: A general u-shaped transformer for image restoration." Proceedings of the + IEEE/CVF conference on computer vision and pattern recognition. 2022. + .. [2] https://github.com/ZhendongWang6/Uformer + """ + + def __init__( + self, + patch_size: int = 256, + in_channels: int = 2, + out_channels: Optional[int] = None, + embedding_dim: int = 32, + encoder_depths: tuple[int, ...] = (2, 2, 2, 2), + encoder_num_heads: tuple[int, ...] = (1, 2, 4, 8), + bottleneck_depth: int = 2, + bottleneck_num_heads: int = 16, + win_size: int = 8, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + drop_path_rate: float = 0.1, + patch_norm: bool = True, + token_projection: AttentionTokenProjectionType = AttentionTokenProjectionType.LINEAR, + token_mlp: LeWinTransformerMLPTokenType = LeWinTransformerMLPTokenType.LEFF, + shift_flag: bool = True, + modulator: bool = False, + cross_modulator: bool = False, + ) -> None: + """Inits :class:`UFormer`. + + Parameters + ---------- + patch_size : int + Size of the patch. Default: 256. + in_channels : int + Number of input channels. Default: 2. + out_channels : int, optional + Number of output channels. Default: None. + embedding_dim : int + Size of the feature embedding. Default: 32. + encoder_depths : tuple + Number of layers for each stage of the encoder of the U-former, from top to bottom. Default: (2, 2, 2, 2). + encoder_num_heads : tuple + Number of attention heads for each layer of the encoder of the U-former, from top to bottom. + Default: (1, 2, 4, 8). + bottleneck_depth : int + Default: 16. + bottleneck_num_heads : int + Default: 2. + win_size : int + Window size for the attention mechanism. Default: 8. + mlp_ratio : float + Ratio of the hidden dimension size to the embedding dimension size in the MLP layers. Default: 4.0. + qkv_bias : bool + Whether to use bias in the query, key, and value projections of the attention mechanism. Default: True. + qk_scale : float + Scale factor for the query and key projection vectors. + If set to None, will use the default value of 1 / sqrt(embedding_dim). Default: None. + drop_rate : float + Dropout rate for the token-level dropout layer. Default: 0.0. + attn_drop_rate : float + Dropout rate for the attention score matrix. Default: 0.0. + drop_path_rate : float + Dropout rate for the stochastic depth regularization. Default: 0.1. + patch_norm : bool + Whether to use normalization for the patch embeddings. Default: True. + token_projection : AttentionTokenProjectionType + Type of token projection. 
Must be one of AttentionTokenProjectionType.LINEAR + or AttentionTokenProjectionType.CONV. Default: AttentionTokenProjectionType.LINEAR. + token_mlp : LeWinTransformerMLPTokenType + Type of token-level MLP. Must be one of LeWinTransformerMLPTokenType.LEFF or + LeWinTransformerMLPTokenType.MLP. Default: LeWinTransformerMLPTokenType.LEFF. + shift_flag : bool + Whether to use shift operation in the local attention mechanism. Default: True. + modulator : bool + Whether to use a modulator in the attention mechanism. Default: False. + cross_modulator : bool + Whether to use cross-modulation in the attention mechanism. Default: False. + **kwargs: Other keyword arguments to pass to the parent constructor. + """ + # pylint: disable=too-many-locals + super().__init__() + if len(encoder_num_heads) != len(encoder_depths): + raise ValueError( + f"The number of heads for each layer should be the same as the number of layers. " + f"Got {len(encoder_num_heads)} for {len(encoder_depths)} layers." + ) + if patch_size < (2 ** len(encoder_depths) * win_size): + raise ValueError( + f"Patch size must be greater or equal than 2 ** number of scales * window size." + f" Received: patch_size={patch_size}, number of scales=={len(encoder_depths)}," + f" and window_size={win_size}." + ) + self.num_enc_layers = len(encoder_num_heads) + self.num_dec_layers = len(encoder_num_heads) + depths = (*encoder_depths, bottleneck_depth, *encoder_depths[::-1]) + num_heads = (*encoder_num_heads, bottleneck_num_heads, bottleneck_num_heads, *encoder_num_heads[::-1][:-1]) + self.embedding_dim = embedding_dim + self.patch_norm = patch_norm + self.mlp_ratio = mlp_ratio + self.token_projection = token_projection + self.mlp = token_mlp + self.win_size = win_size + self.reso = patch_size + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + enc_dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths[: self.num_enc_layers]))] + conv_dpr = [drop_path_rate] * depths[self.num_enc_layers + 1] + dec_dpr = enc_dpr[::-1] + + # Build layers + + # Input + self.input_proj = InputProjection( + in_channels=in_channels, out_channels=embedding_dim, kernel_size=3, stride=1, act_layer=nn.LeakyReLU + ) + out_channels = out_channels if out_channels else in_channels + # Output + self.output_proj = OutputProjection( + in_channels=2 * embedding_dim, out_channels=out_channels, kernel_size=3, stride=1 + ) + if in_channels != out_channels: + self.conv_out = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, padding=0) + self.in_channels = in_channels + self.out_channels = out_channels + + # Encoder + self.encoder_layers = nn.ModuleList() + self.downsamples = nn.ModuleList() + for i in range(self.num_enc_layers): + layer_name = f"encoderlayer_{i}" + layer_input_resolution = (patch_size // (2**i), patch_size // (2**i)) + layer_dim = embedding_dim * (2**i) + layer_depth = depths[i] + layer_drop_path = enc_dpr[sum(depths[:i]) : sum(depths[: i + 1])] + layer = BasicUFormerLayer( + dim=layer_dim, + input_resolution=layer_input_resolution, + depth=layer_depth, + num_heads=num_heads[i], + win_size=win_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=layer_drop_path, + norm_layer=nn.LayerNorm, + token_projection=token_projection, + token_mlp=token_mlp, + shift_flag=shift_flag, + ) + self.encoder_layers.add_module(layer_name, layer) + + downsample_layer_name = f"downsample_{i}" + downsample_layer = DownSampleBlock(layer_dim, embedding_dim 
* (2 ** (i + 1))) + self.downsamples.add_module(downsample_layer_name, downsample_layer) + # Bottleneck + self.bottleneck = BasicUFormerLayer( + dim=embedding_dim * (2**self.num_enc_layers), + input_resolution=(patch_size // (2**self.num_enc_layers), patch_size // (2**self.num_enc_layers)), + depth=depths[self.num_enc_layers], + num_heads=num_heads[self.num_enc_layers], + win_size=win_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=conv_dpr, + norm_layer=nn.LayerNorm, + token_projection=token_projection, + token_mlp=token_mlp, + shift_flag=shift_flag, + ) + # Decoder + self.upsamples = nn.ModuleList() + self.decoder_layers = nn.ModuleList() + for i in range(self.num_dec_layers, 0, -1): + upsample_layer_name = f"upsample_{self.num_dec_layers - i}" + if i == self.num_dec_layers: + upsample_in_channels = embedding_dim * (2**i) + else: + upsample_in_channels = embedding_dim * (2 ** (i + 1)) + upsample_out_channels = embedding_dim * (2 ** (i - 1)) + upsample_layer = UpSampleBlock(upsample_in_channels, upsample_out_channels) + self.upsamples.add_module(upsample_layer_name, upsample_layer) + + layer_name = f"decoderlayer_{self.num_dec_layers - i}" + layer_input_resolution = (patch_size // (2 ** (i - 1)), patch_size // (2 ** (i - 1))) + layer_dim = embedding_dim * (2**i) + layer_num = self.num_enc_layers + self.num_dec_layers - i + 1 + layer_depth = depths[layer_num] + if i == self.num_dec_layers: + layer_drop_path = dec_dpr[: depths[layer_num]] + else: + start = self.num_enc_layers + 1 + layer_drop_path = dec_dpr[sum(depths[start:layer_num]) : sum(depths[start : layer_num + 1])] + layer = BasicUFormerLayer( + dim=layer_dim, + input_resolution=layer_input_resolution, + depth=layer_depth, + num_heads=num_heads[layer_num], + win_size=win_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=layer_drop_path, + norm_layer=nn.LayerNorm, + token_projection=token_projection, + token_mlp=token_mlp, + shift_flag=shift_flag, + modulator=modulator, + cross_modulator=cross_modulator, + ) + self.decoder_layers.add_module(layer_name, layer) + + self.apply(init_weights) + + @torch.jit.ignore + def no_weight_decay(self): + return {"absolute_pos_embed"} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {"relative_position_bias_table"} + + def extra_repr(self) -> str: + return ( + f"embedding_dim={self.embedding_dim}, token_projection={self.token_projection}, " + + f"token_mlp={self.mlp},win_size={self.win_size}" + ) + + def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """Performs forward pass of :class:`UFormer`. + + Parameters + ---------- + x : torch.Tensor + Input tensor. + mask : torch.Tensor, optional + Mask tensor. Default: None. 
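# Editor's note: illustrative end-to-end sketch, not part of this diff. The constructor above
# requires patch_size >= 2 ** len(encoder_depths) * win_size, because each encoder stage halves the
# token-map resolution and the bottleneck must still hold one full attention window; with the
# default four stages and win_size=8 the smallest valid patch_size is 128. The sketch keeps
# in_channels == out_channels, so the residual sum of input and prediction is returned.
import torch
from direct.nn.transformers.uformer import UFormer

model = UFormer(patch_size=128, in_channels=2, embedding_dim=16, win_size=8)
x = torch.rand(1, 2, 128, 128)
assert model(x).shape == (1, 2, 128, 128)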
+ + Returns + ------- + torch.Tensor + """ + # Input Projection + output = self.input_proj(x) + output = self.pos_drop(output) + + # Encoder + stack = [] + for encoder_layer, downsample in zip(self.encoder_layers, self.downsamples): + output = encoder_layer(output, mask=mask) + stack.append(output) + output = downsample(output) + # Bottleneck + output = self.bottleneck(output, mask=mask) + + # Decoder + for decoder_layer, upsample in zip(self.decoder_layers, self.upsamples): + downsampled_output = stack.pop() + output = upsample(output) + + output = torch.cat([output, downsampled_output], -1) + output = decoder_layer(output, mask=mask) + + # Output Projection + output = self.output_proj(output) + if self.in_channels != self.out_channels: + x = self.conv_out(input) + return x + output + + +class UFormerModel(nn.Module): + """U-Former model with normalization and padding operations. + + Parameters + ---------- + patch_size : int + Size of the patch. Default: 256. + in_channels : int + Number of input channels. Default: 2. + out_channels : int, optional + Number of output channels. Default: None. + embedding_dim : int + Size of the feature embedding. Default: 32. + encoder_depths : tuple + Number of layers for each stage of the encoder of the U-former, from top to bottom. Default: (2, 2, 2, 2). + encoder_num_heads : tuple + Number of attention heads for each layer of the encoder of the U-former, from top to bottom. + Default: (1, 2, 4, 8). + bottleneck_depth : int + Default: 16. + bottleneck_num_heads : int + Default: 2. + win_size : int + Window size for the attention mechanism. Default: 8. + mlp_ratio : float + Ratio of the hidden dimension size to the embedding dimension size in the MLP layers. Default: 4.0. + qkv_bias : bool + Whether to use bias in the query, key, and value projections of the attention mechanism. Default: True. + qk_scale : float + Scale factor for the query and key projection vectors. + If set to None, will use the default value of 1 / sqrt(embedding_dim). Default: None. + drop_rate : float + Dropout rate for the token-level dropout layer. Default: 0.0. + attn_drop_rate : float + Dropout rate for the attention score matrix. Default: 0.0. + drop_path_rate : float + Dropout rate for the stochastic depth regularization. Default: 0.1. + patch_norm : bool + Whether to use normalization for the patch embeddings. Default: True. + token_projection : AttentionTokenProjectionType + Type of token projection. Must be one of AttentionTokenProjectionType.LINEAR + or AttentionTokenProjectionType.CONV. Default: AttentionTokenProjectionType.LINEAR. + token_mlp : LeWinTransformerMLPTokenType + Type of token-level MLP. Must be one of LeWinTransformerMLPTokenType.LEFF or + LeWinTransformerMLPTokenType.MLP. Default: LeWinTransformerMLPTokenType.LEFF. + shift_flag : bool + Whether to use shift operation in the local attention mechanism. Default: True. + modulator : bool + Whether to use a modulator in the attention mechanism. Default: False. + cross_modulator : bool + Whether to use cross-modulation in the attention mechanism. Default: False. + normalized : bool + Whether to apply normalization before and denormalization after the forward pass. Default: True. + **kwargs: Other keyword arguments to pass to the parent constructor. + """ + + def __init__( + self, + patch_size: int = 256, + in_channels: int = 2, + out_channels: Optional[int] = None, + embedding_dim: int = 32, + encoder_depths: tuple[int, ...] = (2, 2, 2, 2), + encoder_num_heads: tuple[int, ...] 
= (1, 2, 4, 8), + bottleneck_depth: int = 2, + bottleneck_num_heads: int = 16, + win_size: int = 8, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + drop_path_rate: float = 0.1, + patch_norm: bool = True, + token_projection: AttentionTokenProjectionType = AttentionTokenProjectionType.LINEAR, + token_mlp: LeWinTransformerMLPTokenType = LeWinTransformerMLPTokenType.LEFF, + shift_flag: bool = True, + modulator: bool = False, + cross_modulator: bool = False, + normalized: bool = True, + ) -> None: + """Inits :class:`UFormer`. + + Parameters + ---------- + patch_size : int + Size of the patch. Default: 256. + in_channels : int + Number of input channels. Default: 2. + out_channels : int, optional + Number of output channels. Default: None. + embedding_dim : int + Size of the feature embedding. Default: 32. + encoder_depths : tuple + Number of layers for each stage of the encoder of the U-former, from top to bottom. Default: (2, 2, 2, 2). + encoder_num_heads : tuple + Number of attention heads for each layer of the encoder of the U-former, from top to bottom. + Default: (1, 2, 4, 8). + bottleneck_depth : int + Default: 16. + bottleneck_num_heads : int + Default: 2. + win_size : int + Window size for the attention mechanism. Default: 8. + mlp_ratio : float + Ratio of the hidden dimension size to the embedding dimension size in the MLP layers. Default: 4.0. + qkv_bias : bool + Whether to use bias in the query, key, and value projections of the attention mechanism. Default: True. + qk_scale : float + Scale factor for the query and key projection vectors. + If set to None, will use the default value of 1 / sqrt(embedding_dim). Default: None. + drop_rate : float + Dropout rate for the token-level dropout layer. Default: 0.0. + attn_drop_rate : float + Dropout rate for the attention score matrix. Default: 0.0. + drop_path_rate : float + Dropout rate for the stochastic depth regularization. Default: 0.1. + patch_norm : bool + Whether to use normalization for the patch embeddings. Default: True. + token_projection : AttentionTokenProjectionType + Type of token projection. Must be one of AttentionTokenProjectionType.LINEAR + or AttentionTokenProjectionType.CONV. Default: AttentionTokenProjectionType.LINEAR. + token_mlp : LeWinTransformerMLPTokenType + Type of token-level MLP. Must be one of LeWinTransformerMLPTokenType.LEFF or + LeWinTransformerMLPTokenType.MLP. Default: LeWinTransformerMLPTokenType.LEFF. + shift_flag : bool + Whether to use shift operation in the local attention mechanism. Default: True. + modulator : bool + Whether to use a modulator in the attention mechanism. Default: False. + cross_modulator : bool + Whether to use cross-modulation in the attention mechanism. Default: False. + normalized : bool + Whether to apply normalization before and denormalization after the forward pass. Default: True. + **kwargs: Other keyword arguments to pass to the parent constructor. 
+ """ + # pylint: disable=too-many-locals + super().__init__() + + self.uformer = UFormer( + patch_size, + in_channels, + out_channels, + embedding_dim, + encoder_depths, + encoder_num_heads, + bottleneck_depth, + bottleneck_num_heads, + win_size, + mlp_ratio, + qkv_bias, + qk_scale, + drop_rate, + attn_drop_rate, + drop_path_rate, + patch_norm, + token_projection, + token_mlp, + shift_flag, + modulator, + cross_modulator, + ) + self.normalized = normalized + self.padding_factor = win_size * (2 ** len(encoder_depths)) + + def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """Performs forward pass of :class:`UFormer`. + + Parameters + ---------- + x : torch.Tensor + mask : torch.Tensor, optional + + Returns + ------- + torch.Tensor + """ + x, _, wpad, hpad = pad_to_square(x, self.padding_factor) + if self.normalized: + x, mean, std = norm(x) + x = self.uformer(x, mask) + if self.normalized: + x = unnorm(x, mean, std) + x = unpad_to_original(x, hpad, wpad) + return x diff --git a/direct/nn/transformers/utils.py b/direct/nn/transformers/utils.py new file mode 100644 index 00000000..220838be --- /dev/null +++ b/direct/nn/transformers/utils.py @@ -0,0 +1,223 @@ +# Copyright (c) DIRECT Contributors + +"""DIRECT module containing utility functions for the transformers models.""" + +from __future__ import annotations + +from math import ceil, floor + +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn import init + +__all__ = ["init_weights", "norm", "pad_to_divisible", "pad_to_square", "unnorm", "unpad_to_original", "DropoutPath"] + + +def pad_to_divisible(x: torch.Tensor, pad_size: tuple[int, ...]) -> tuple[torch.Tensor, tuple[tuple[int, int], ...]]: + """Pad the input tensor with zeros to make its spatial dimensions divisible by the specified pad size. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (*, spatial_1, spatial_2, ..., spatial_N), where spatial dimensions can vary in number. + pad_size : tuple[int, ...] + Patch size to make each spatial dimension divisible by. This is a tuple of integers for each spatial dimension. + + Returns + ------- + tuple + Containing the padded tensor and a tuple of tuples indicating the number of pixels padded in + each spatial dimension. + """ + pads = [] + for dim, p_dim in zip(x.shape[-len(pad_size) :], pad_size): + pad_before = (p_dim - dim % p_dim) % p_dim / 2 + pads.append((floor(pad_before), ceil(pad_before))) + + # Reverse and flatten pads to match torch's expected + # (pad_n_before, pad_n_after, ..., pad_1_before, pad_1_after) format + flat_pads = tuple(val for sublist in pads[::-1] for val in sublist) + x = F.pad(x, flat_pads) + + return x, tuple(pads) + + +def unpad_to_original(x: torch.Tensor, *pads: tuple[int, int]) -> torch.Tensor: + """Remove the padding added to the input tensor. + + Parameters + ---------- + x : torch.Tensor + Input tensor with padded spatial dimensions. + pads : tuple[int, int] + A tuple of (pad_before, pad_after) for each spatial dimension. + + Returns + ------- + torch.Tensor + Tensor with the padding removed, matching the shape of the original input tensor before padding. 
+ """ + slices = [slice(None)] * (x.ndim - len(pads)) # Keep the batch and channel dimensions + for i, (pad_before, pad_after) in enumerate(pads): + slices.append(slice(pad_before, x.shape[-len(pads) + i] - pad_after)) + + return x[tuple(slices)] + + +def pad_to_square( + inp: torch.Tensor, factor: float +) -> tuple[torch.Tensor, torch.Tensor, tuple[int, int], tuple[int, int]]: + """Pad a tensor to a square shape with a given factor. + + Parameters + ---------- + inp : torch.Tensor + The input tensor to pad to square shape. Expected shape is (\*, height, width). + factor : float + The factor to which the input tensor will be padded. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor, tuple[int, int], tuple[int, int]] + A tuple of two tensors, the first is the input tensor padded to a square shape, and the + second is the corresponding mask for the padded tensor. + + Examples + -------- + 1. + >>> x = torch.rand(1, 3, 224, 192) + >>> padded_x, mask, wpad, hpad = pad_to_square(x, factor=16.0) + >>> padded_x.shape, mask.shape + (torch.Size([1, 3, 224, 224]), torch.Size([1, 1, 224, 224])) + 2. + >>> x = torch.rand(3, 13, 2, 234, 180) + >>> padded_x, mask, wpad, hpad = pad_to_square(x, factor=16.0) + >>> padded_x.shape, wpad, hpad + (torch.Size([3, 13, 2, 240, 240]), (30, 30), (3, 3)) + """ + channels, h, w = inp.shape[-3:] + + # Calculate the maximum size and pad to the next multiple of the factor + x = int(ceil(max(h, w) / float(factor)) * factor) + + # Create a tensor of zeros with the maximum size and copy the input tensor into the center + img = torch.zeros(*inp.shape[:-3], channels, x, x, device=inp.device).type_as(inp) + mask = torch.zeros(*((1,) * (img.ndim - 3)), 1, x, x, device=inp.device).type_as(inp) + + # Compute the offset and copy the input tensor into the center of the zero tensor + offset_h = (x - h) // 2 + offset_w = (x - w) // 2 + hpad = (offset_h, offset_h + h) + wpad = (offset_w, offset_w + w) + img[..., hpad[0] : hpad[1], wpad[0] : wpad[1]] = inp.clone() + mask[..., hpad[0] : hpad[1], wpad[0] : wpad[1]].fill_(1.0) + # Return the padded tensor and the corresponding mask, and padding in spatial dimensions + return ( + img, + 1 - mask, + (wpad[0], wpad[1] - w + (1 if w % 2 != 0 else 0)), + (hpad[0], hpad[1] - h + (1 if h % 2 != 0 else 0)), + ) + + +def norm(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Normalize the input tensor by subtracting the mean and dividing by the standard deviation + across each channel and pixel for arbitrary spatial dimensions. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (B, C, *spatial_dims), where spatial_dims can vary in number (e.g., 2D, 3D, etc.). + + Returns + ------- + tuple + Containing the normalized tensor, mean tensor, and standard deviation tensor. + """ + # Flatten spatial dimensions and compute mean and std across them + spatial_dims = x.shape[2:] # Get all spatial dimensions + flattened = x.view(x.shape[0], x.shape[1], -1) # Flatten the spatial dimensions for mean/std calculation + + mean = flattened.mean(-1, keepdim=True).view(x.shape[0], x.shape[1], *([1] * len(spatial_dims))) + std = flattened.std(-1, keepdim=True).view(x.shape[0], x.shape[1], *([1] * len(spatial_dims))) + + # Normalize + x = (x - mean) / std + + return x, mean, std + + +def unnorm(x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor: + """Denormalize the input tensor by multiplying by the standard deviation and adding the mean + for arbitrary spatial dimensions. 
+ + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (B, C, *spatial_dims), where spatial_dims can vary in number. + mean : torch.Tensor + Mean tensor obtained during normalization. + std : torch.Tensor + Standard deviation tensor obtained during normalization. + + Returns + ------- + torch.Tensor + Tensor with the same shape as the original input tensor, but denormalized. + """ + return x * std + mean + + +def init_weights(m: nn.Module) -> None: + """Initializes the weights of the network using a truncated normal distribution. + + Parameters + ---------- + m : nn.Module + A module of the network whose weights need to be initialized. + """ + + if isinstance(m, nn.Linear): + init.trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + init.constant_(m.bias, 0) + init.constant_(m.weight, 1.0) + + +class DropoutPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True): + """Inits :class:`DropoutPath`. + + Parameters + ---------- + drop_prob : float + Probability of dropping a residual connection. Default: 0.0. + scale_by_keep : bool + Whether to scale the remaining activations by 1 / (1 - drop_prob) to maintain the expected value of + the activations. Default: True. + """ + super().__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + @staticmethod + def _dropout_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + def forward(self, x): + return self._dropout_path(x, self.drop_prob, self.training, self.scale_by_keep) + + def extra_repr(self): + return f"dropout_prob={round(self.drop_prob, 3):0.3f}" diff --git a/direct/nn/transformers/vit.py b/direct/nn/transformers/vit.py new file mode 100644 index 00000000..127fd072 --- /dev/null +++ b/direct/nn/transformers/vit.py @@ -0,0 +1,1323 @@ +# Copyright (c) DIRECT Contributors + +"""DIRECT Vision Transformer module. + +Implementation of Vision Transformer model [1, 2]_ in PyTorch. + +Code borrowed from [3]_ which uses code from timm [4]_. + +References +---------- +.. [1] Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., Dehghani, M., Minderer, + M., Heigold, G., Gelly, S., Uszkoreit, J., Houlsby, N.: An Image is Worth 16x16 Words: + Transformers for Image Recognition at Scale, http://arxiv.org/abs/2010.11929, (2021). +.. [2] Steiner, A., Kolesnikov, A., Zhai, X., Wightman, R., Uszkoreit, J., Beyer, L.: How to train your ViT? Data, + Augmentation, and Regularization in Vision Transformers, http://arxiv.org/abs/2106.10270, (2022). +.. [3] https://github.com/facebookresearch/convit +.. 
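# A brief sketch of DropoutPath (stochastic depth) defined above, assuming the package is
# importable: in training mode whole residual branches are zeroed per sample and the
# survivors are rescaled by 1 / keep_prob; in eval mode the module is the identity.
import torch
from direct.nn.transformers.utils import DropoutPath

drop = DropoutPath(drop_prob=0.5)
drop.train()
x = torch.ones(8, 4, 16)                          # (batch, tokens, features)
y = drop(x)
print(sorted(y[:, 0, 0].unique().tolist()))       # subset of [0.0, 2.0]
drop.eval()
print(torch.equal(drop(x), x))                    # True, identity at inference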
[4] https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +""" + +from __future__ import annotations + +from abc import abstractmethod +from typing import Optional + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn import init + +from direct.constants import COMPLEX_SIZE +from direct.nn.transformers.utils import DropoutPath, init_weights, norm, pad_to_divisible, unnorm, unpad_to_original +from direct.types import DirectEnum + +__all__ = ["VisionTransformer2D", "VisionTransformer3D"] + + +class VisionTransformerDimensionality(DirectEnum): + + TWO_DIMENSIONAL = "2D" + THREE_DIMENSIONAL = "3D" + + +class MLP(nn.Module): + """MLP layer with dropout and activation for Vision Transformer. + + Parameters + ---------- + in_features : int + Size of the input feature. + hidden_features : int, optional + Size of the hidden layer feature. If None, then hidden_features = in_features. Default: None. + out_features : int, optional + Size of the output feature. If None, then out_features = in_features. Default: None. + act_layer : nn.Module, optional + Activation layer to be used. Default: nn.GELU. + drop : float, optional + Dropout probability. Default: 0. + """ + + def __init__( + self, + in_features: int, + hidden_features: int = None, + out_features: int = None, + act_layer: nn.Module = nn.GELU, + drop: float = 0.0, + ) -> None: + """Inits :class:`MLP`. + + Parameters + ---------- + in_features : int + Size of the input feature. + hidden_features : int, optional + Size of the hidden layer feature. If None, then hidden_features = in_features. Default: None. + out_features : int, optional + Size of the output feature. If None, then out_features = in_features. Default: None. + act_layer : nn.Module, optional + Activation layer to be used. Default: nn.GELU. + drop : float, optional + Dropout probability. Default: 0. + """ + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + self.apply(init_weights) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass of :class:`MLP`. + + Parameters + ---------- + x : torch.Tensor + Input tensor to the network. + + Returns + ------- + torch.Tensor + Output tensor of the network. + + """ + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class GPSA(nn.Module): + """Gated Positional Self-Attention module for Vision Transformer. + + Parameters + ---------- + dim : int + Dimensionality of the input embeddings. + num_heads : int + Number of attention heads. + qkv_bias : bool + If True, include bias terms in the query, key, and value projections. + qk_scale : float + Scale factor for query and key. + attn_drop : float + Dropout probability for attention weights. + proj_drop : float + Dropout probability for output tensor. + locality_strength : float + Strength of locality assumption in initialization. + use_local_init : bool + If True, use the locality-based initialization. + grid_size : tuple[int,int], optional + The size of the grid (height, width) for relative position encoding. 
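# A usage sketch of the MLP block above, assuming the package is importable: it is the
# standard transformer feed-forward block (Linear -> GELU -> Dropout -> Linear -> Dropout)
# applied token-wise, so the (batch, tokens, features) shape is preserved.
import torch
from direct.nn.transformers.vit import MLP

mlp = MLP(in_features=64, hidden_features=256, drop=0.1)
tokens = torch.rand(2, 100, 64)
print(mlp(tokens).shape)                          # torch.Size([2, 100, 64])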
+ """ + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_scale: float = None, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + locality_strength: float = 1.0, + use_local_init: bool = True, + grid_size=None, + ) -> None: + """Inits :class:`GPSA`. + + Parameters + ---------- + dim : int + Dimensionality of the input embeddings. + num_heads : int + Number of attention heads. + qkv_bias : bool + If True, include bias terms in the query, key, and value projections. + qk_scale : float + Scale factor for query and key. + attn_drop : float + Dropout probability for attention weights. + proj_drop : float + Dropout probability for output tensor. + locality_strength : float + Strength of locality assumption in initialization. + use_local_init : bool + If True, use the locality-based initialization. + grid_size : tuple[int,int], optional + The size of the grid (height, width) for relative position encoding. + """ + super().__init__() + self.num_heads = num_heads + self.dim = dim + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.k = nn.Linear(dim, dim, bias=qkv_bias) + self.v = nn.Linear(dim, dim, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.pos_proj = nn.Linear(3, num_heads) + self.proj_drop = nn.Dropout(proj_drop) + self.locality_strength = locality_strength + self.gating_param = nn.Parameter(torch.ones(self.num_heads)) + self.apply(init_weights) + if use_local_init: + self.local_init(locality_strength=locality_strength) + self.current_grid_size = grid_size + + def get_attention(self, x: torch.Tensor) -> torch.Tensor: + """Compute the attention scores for each patch in x. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (B, N, C). + + Returns + ------- + torch.Tensor + Attention scores for each patch in x. + """ + B, N, C = x.shape + + k = self.k(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + pos_score = self.pos_proj(self.get_rel_indices()).expand(B, -1, -1, -1).permute(0, 3, 1, 2) + patch_score = (q @ k.transpose(-2, -1)) * self.scale + patch_score = patch_score.softmax(dim=-1) + pos_score = pos_score.softmax(dim=-1) + + gating = self.gating_param.view(1, -1, 1, 1) + attn = (1.0 - torch.sigmoid(gating)) * patch_score + torch.sigmoid(gating) * pos_score + attn = attn / attn.sum(dim=-1).unsqueeze(-1) + attn = self.attn_drop(attn) + return attn + + @abstractmethod + def local_init(self, locality_strength: Optional[float] = 1.0) -> None: + pass + + @abstractmethod + def get_rel_indices(self) -> None: + pass + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass of :class:`GPSA`. + + Parameters + ---------- + x : torch.Tensor + Input tensor. + + Returns + ------- + torch.Tensor: + """ + B, N, C = x.shape + + attn = self.get_attention(x) + v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class GPSA2D(GPSA): + """Gated Positional Self-Attention module for Vision Transformer (2D variant). + + Parameters + ---------- + dim : int + Dimensionality of the input embeddings. + num_heads : int + Number of attention heads. + qkv_bias : bool + If True, include bias terms in the query, key, and value projections. 
+ qk_scale : float + Scale factor for query and key. + attn_drop : float + Dropout probability for attention weights. + proj_drop : float + Dropout probability for output tensor. + locality_strength : float + Strength of locality assumption in initialization. + use_local_init : bool + If True, use the locality-based initialization. + grid_size : tuple[int,int], optional + The size of the grid (height, width) for relative position encoding. + """ + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_scale: float = None, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + locality_strength: float = 1.0, + use_local_init: bool = True, + grid_size=None, + ) -> None: + """Inits :class:`GPSA2D`. + + Parameters + ---------- + dim : int + Dimensionality of the input embeddings. + num_heads : int + Number of attention heads. + qkv_bias : bool + If True, include bias terms in the query, key, and value projections. + qk_scale : float + Scale factor for query and key. + attn_drop : float + Dropout probability for attention weights. + proj_drop : float + Dropout probability for output tensor. + locality_strength : float + Strength of locality assumption in initialization. + use_local_init : bool + If True, use the locality-based initialization. + grid_size : tuple[int,int], optional + The size of the grid (height, width) for relative position encoding. + """ + super().__init__( + dim=dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=proj_drop, + locality_strength=locality_strength, + use_local_init=use_local_init, + grid_size=grid_size, + ) + + def local_init(self, locality_strength: Optional[float] = 1.0) -> None: + """Initializes the positional projection weights with locality distance. + + Parameters + ---------- + locality_strength : float, optional + Determines how focused the attention is around its center. + """ + self.v.weight.data.copy_(torch.eye(self.dim)) + + kernel_size = int(self.num_heads**0.5) + center = (kernel_size - 1) / 2 if kernel_size % 2 == 0 else kernel_size // 2 + + for h1 in range(kernel_size): + for h2 in range(kernel_size): + position = h1 + kernel_size * h2 + self.pos_proj.weight.data[position, 2] = -1 + self.pos_proj.weight.data[position, 1] = 2 * (h1 - center) + self.pos_proj.weight.data[position, 0] = 2 * (h2 - center) + + self.pos_proj.weight.data *= locality_strength + + def get_rel_indices(self) -> None: + """Get relative indices for 2D grid of patches.""" + H, W = self.current_grid_size + N = H * W + + rel_indices = torch.zeros(1, N, N, 3) + + indx = torch.arange(W).view(1, -1) - torch.arange(W).view(-1, 1) + indx = indx.repeat(H, H) + indy = torch.arange(H).view(1, -1) - torch.arange(H).view(-1, 1) + indy = indy.repeat_interleave(W, dim=0).repeat_interleave(W, dim=1) + indd = indx**2 + indy**2 + + rel_indices[:, :, :, 2] = indd.unsqueeze(0) + rel_indices[:, :, :, 1] = indy.unsqueeze(0) + rel_indices[:, :, :, 0] = indx.unsqueeze(0) + + return rel_indices.to(self.v.weight.device) + + +class GPSA3D(GPSA): + """Gated Positional Self-Attention module for Vision Transformer (3D variant). + + Parameters + ---------- + dim : int + Dimensionality of the input embeddings. + num_heads : int + Number of attention heads. + qkv_bias : bool + If True, include bias terms in the query, key, and value projections. + qk_scale : float + Scale factor for query and key. + attn_drop : float + Dropout probability for attention weights. + proj_drop : float + Dropout probability for output tensor. 
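# A small check, assuming the package is importable, of what GPSA2D.get_rel_indices above
# encodes for a 2x2 grid of patches: channel 0 is the horizontal offset, channel 1 the
# vertical offset, and channel 2 the squared Euclidean distance between patches.
import torch
from direct.nn.transformers.vit import GPSA2D

gpsa = GPSA2D(dim=16, num_heads=4, grid_size=(2, 2))
rel = gpsa.get_rel_indices()
print(rel.shape)                                  # torch.Size([1, 4, 4, 3])
print(rel[0, 0, :, 2])                            # tensor([0., 1., 1., 2.])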
+ locality_strength : float + Strength of locality assumption in initialization. + use_local_init : bool + If True, use the locality-based initialization. + grid_size : tuple[int, int, int], optional + The size of the grid (depth, height, width) for relative position encoding. + """ + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_scale: float = None, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + locality_strength: float = 1.0, + use_local_init: bool = True, + grid_size=None, + ) -> None: + super().__init__( + dim=dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=proj_drop, + locality_strength=locality_strength, + use_local_init=use_local_init, + grid_size=grid_size, + ) + + def local_init(self, locality_strength: Optional[float] = 1.0) -> None: + """Initializes the positional projection weights with locality distance. + + Parameters + ---------- + locality_strength : float, optional + Determines how focused the attention is around its center. + """ + self.v.weight.data.copy_(torch.eye(self.dim)) + + kernel_size = int(self.num_heads ** (1 / 3)) + center = (kernel_size - 1) / 2 if kernel_size % 2 == 0 else kernel_size // 2 + + for h1 in range(kernel_size): + for h2 in range(kernel_size): + for h3 in range(kernel_size): + position = h1 + kernel_size * (h2 + kernel_size * h3) + self.pos_proj.weight.data[position, 2] = -1 + self.pos_proj.weight.data[position, 1] = 2 * (h2 - center) + self.pos_proj.weight.data[position, 0] = 2 * (h3 - center) + + self.pos_proj.weight.data *= locality_strength + + def get_rel_indices(self) -> torch.Tensor: + """Get relative indices for 3D grid of patches.""" + D, H, W = self.current_grid_size + N = D * H * W + rel_indices = torch.zeros(1, N, N, 3) + + indz = torch.arange(D).view(1, -1) - torch.arange(D).view(-1, 1) + indz = indz.repeat(H * W, H * W) + + indx = torch.arange(W).view(1, -1) - torch.arange(W).view(-1, 1) + indx = indx.repeat(D * H, D * H) + + indy = torch.arange(H).view(1, -1) - torch.arange(H).view(-1, 1) + indy = indy.repeat(D * W, D * W) + + indd = indz**2 + indx**2 + indy**2 + rel_indices[:, :, :, 2] = indd.unsqueeze(0) + rel_indices[:, :, :, 1] = indy.unsqueeze(0) + rel_indices[:, :, :, 0] = indx.unsqueeze(0) + + return rel_indices.to(self.v.weight.device) + + +class MHSA(nn.Module): + """Multi-Head Self-Attention (MHSA) module. + + Parameters + ---------- + dim : int + Number of input features. + num_heads : int + Number of heads in the attention mechanism. Default is 8. + qkv_bias : bool + If True, bias is added to the query, key and value projections. Default is False. + qk_scale : float or None + Scaling factor for the query-key dot product. If None, it is set to + head_dim ** -0.5 where head_dim = dim // num_heads. Default is None. + attn_drop : float + Dropout rate for the attention weights. Default is 0. + proj_drop : float + Dropout rate for the output of the module. Default is 0. + """ + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_scale: float = None, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + """Inits :class:`MHSA`. + + Parameters + ---------- + dim : int + Number of input features. + num_heads : int + Number of heads in the attention mechanism. Default is 8. + qkv_bias : bool + If True, bias is added to the query, key and value projections. Default is False. + qk_scale : float or None + Scaling factor for the query-key dot product. 
If None, it is set to + head_dim ** -0.5 where head_dim = dim // num_heads. Default is None. + attn_drop : float + Dropout rate for the attention weights. Default is 0. + proj_drop : float + Dropout rate for the output of the module. Default is 0. + """ + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.apply(init_weights) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass of :class:`MHSA`. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (B, N, C). + + Returns + ------- + torch.Tensor + Output tensor of shape (B, N, C). + """ + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class VisionTransformerBlock(nn.Module): + """A single transformer block used in the VisionTransformer model. + + Parameters + ---------- + dimensionality : VisionTransformerDimensionality + The dimensionality of the input data. + dim : int + The feature dimension. + num_heads : int + The number of attention heads. + mlp_ratio : float, optional + The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0. + qkv_bias : bool, optional + Whether to add bias to the query, key, and value projections. Default: False. + qk_scale : float, optional + The scale factor for the query-key dot product. Default: None. + drop : float, optional + The dropout probability for all dropout layers except dropout_path. Default: 0.0. + attn_drop : float, optional + The dropout probability for the attention layer. Default: 0.0. + dropout_path : float, optional + The dropout probability for the dropout path. Default: 0.0. + act_layer : nn.Module, optional + The activation layer used in the MLP. Default: nn.GELU. + norm_layer : nn.Module, optional + The normalization layer used in the block. Default: nn.LayerNorm. + use_gpsa : bool, optional + Whether to use the GPSA attention layer. If set to False, the MHSA layer will be used. Default: True. + **kwargs: Additional arguments for the attention layer. + """ + + def __init__( + self, + dimensionality: VisionTransformerDimensionality, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + qk_scale: float = None, + drop: float = 0.0, + attn_drop: float = 0.0, + dropout_path: float = 0.0, + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = nn.LayerNorm, + use_gpsa: bool = True, + **kwargs, + ) -> None: + """Inits :class:`VisionTransformerBlock`. + + Parameters + ---------- + dimensionality : VisionTransformerDimensionality + The dimensionality of the input data. + dim : int + The feature dimension. + num_heads : int + The number of attention heads. + mlp_ratio : float, optional + The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0. + qkv_bias : bool, optional + Whether to add bias to the query, key, and value projections. Default: False. + qk_scale : float, optional + The scale factor for the query-key dot product. Default: None. 
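# A usage sketch of the MHSA block above, assuming the package is importable: standard
# multi-head self-attention over a token sequence, preserving (batch, tokens, features).
import torch
from direct.nn.transformers.vit import MHSA

attn = MHSA(dim=64, num_heads=8)
tokens = torch.rand(2, 100, 64)
print(attn(tokens).shape)                         # torch.Size([2, 100, 64])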
+ drop : float, optional + The dropout probability for all dropout layers except dropout_path. Default: 0.0. + attn_drop : float, optional + The dropout probability for the attention layer. Default: 0.0. + dropout_path : float, optional + The dropout probability for the dropout path. Default: 0.0. + act_layer : nn.Module, optional + The activation layer used in the MLP. Default: nn.GELU. + norm_layer : nn.Module, optional + The normalization layer used in the block. Default: nn.LayerNorm. + use_gpsa : bool, optional + Whether to use the GPSA attention layer. If set to False, the MHSA layer will be used. Default: True. + **kwargs: Additional arguments for the attention layer. + """ + super().__init__() + self.norm1 = norm_layer(dim) + self.use_gpsa = use_gpsa + if self.use_gpsa: + self.attn = (GPSA2D if dimensionality == VisionTransformerDimensionality.TWO_DIMENSIONAL else GPSA3D)( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + **kwargs, + ) + else: + self.attn = MHSA( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + **kwargs, + ) + self.dropout_path = DropoutPath(dropout_path) if dropout_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x: torch.Tensor, grid_size: tuple[int, int]) -> torch.Tensor: + """Forward pass for the :class:`VisionTransformerBlock`. + + Parameters + ---------- + x : torch.Tensor + The input tensor. + grid_size : tuple[int, int] + The size of the grid used by the attention layer. + + Returns + ------- + torch.Tensor: The output tensor. + """ + self.attn.current_grid_size = grid_size + x = x + self.dropout_path(self.attn(self.norm1(x))) + x = x + self.dropout_path(self.mlp(self.norm2(x))) + + return x + + +class PatchEmbedding(nn.Module): + """Image to Patch Embedding.""" + + def __init__( + self, patch_size, in_channels, embedding_dim, dimensionality: VisionTransformerDimensionality + ) -> None: + """Inits :class:`PatchEmbedding` module for Vision Transformer. + + Parameters + ---------- + patch_size : int or tuple[int, int] + The patch size. If an int is provided, the patch will be a square. + in_channels : int + Number of input channels. + embedding_dim : int + Dimension of the output embedding. + dimensionality : VisionTransformerDimensionality + The dimensionality of the input data. + """ + super().__init__() + self.proj = (nn.Conv2d if dimensionality == VisionTransformerDimensionality.TWO_DIMENSIONAL else nn.Conv3d)( + in_channels, embedding_dim, kernel_size=patch_size, stride=patch_size + ) + self.apply(init_weights) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass of :class:`PatchEmbedding`. + + Parameters + ---------- + x : torch.Tensor + + Returns + ------- + torch.Tensor + Patch embedding. + """ + x = self.proj(x) + return x + + +class VisionTransformer(nn.Module): + """Vision Transformer model. + + Parameters + ---------- + dimensionality : VisionTransformerDimensionality + The dimensionality of the input data. + average_img_size : int or tuple[int, int] or tuple[int, int, int] + The average size of the input image. If an int is provided, this will be determined by the + `dimensionality`, i.e., (average_img_size, average_img_size) for 2D and + (average_img_size, average_img_size, average_img_size) for 3D. Default: 320. 
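# A usage sketch of PatchEmbedding above, assuming the package is importable: the strided
# convolution turns a (B, C, H, W) image into a (B, embedding_dim, H / p, W / p) grid of
# patch tokens, which VisionTransformer later flattens into a sequence.
import torch
from direct.nn.transformers.vit import PatchEmbedding, VisionTransformerDimensionality

embed = PatchEmbedding(
    patch_size=16,
    in_channels=2,
    embedding_dim=64,
    dimensionality=VisionTransformerDimensionality.TWO_DIMENSIONAL,
)
img = torch.rand(1, 2, 64, 48)
print(embed(img).shape)                           # torch.Size([1, 64, 4, 3])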
+ patch_size : int or tuple[int, int] or tuple[int, int, int] + The size of the patch. If an int is provided, this will be determined by the `dimensionality`, i.e., + (patch_size, patch_size) for 2D and (patch_size, patch_size, patch_size) for 3D. Default: 16. + in_channels : int + Number of input channels. Default: COMPLEX_SIZE. + out_channels : int or None + Number of output channels. If None, this will be set to `in_channels`. Default: None. + embedding_dim : int + Dimension of the output embedding. + depth : int + Number of transformer blocks. + num_heads : int + Number of attention heads. + mlp_ratio : float + The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0. + qkv_bias : bool + Whether to add bias to the query, key, and value projections. Default: False. + qk_scale : float + The scale factor for the query-key dot product. Default: None. + drop_rate : float + The dropout probability for all dropout layers except dropout_path. Default: 0.0. + attn_drop_rate : float + The dropout probability for the attention layer. Default: 0.0. + dropout_path_rate : float + The dropout probability for the dropout path. Default: 0.0. + use_gpsa: bool + Whether to use GPSA layer. Default: True. + locality_strength : float + The strength of the locality assumption in initialization. Default: 1.0. + use_pos_embedding : bool + Whether to use positional embeddings. Default: True. + normalized : bool + Whether to normalize the input tensor. Default: True. + """ + + def __init__( + self, + dimensionality: VisionTransformerDimensionality, + average_img_size: int | tuple[int, int] | tuple[int, int, int] = 320, + patch_size: int | tuple[int, int] | tuple[int, int, int] = 16, + in_channels: int = COMPLEX_SIZE, + out_channels: int = None, + embedding_dim: int = 64, + depth: int = 8, + num_heads: int = 9, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + qk_scale: float = None, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + dropout_path_rate: float = 0.0, + use_gpsa: bool = True, + locality_strength: float = 1.0, + use_pos_embedding: bool = True, + normalized: bool = True, + ) -> None: + """Inits :class:`VisionTransformer`. + + Parameters + ---------- + dimensionality : VisionTransformerDimensionality + The dimensionality of the input data. + average_img_size : int or tuple[int, int] or tuple[int, int, int] + The average size of the input image. If an int is provided, this will be determined by the + `dimensionality`, i.e., (average_img_size, average_img_size) for 2D and + (average_img_size, average_img_size, average_img_size) for 3D. Default: 320. + patch_size : int or tuple[int, int] or tuple[int, int, int] + The size of the patch. If an int is provided, this will be determined by the `dimensionality`, i.e., + (patch_size, patch_size) for 2D and (patch_size, patch_size, patch_size) for 3D. Default: 16. + in_channels : int + Number of input channels. Default: COMPLEX_SIZE. + out_channels : int or None + Number of output channels. If None, this will be set to `in_channels`. Default: None. + embedding_dim : int + Dimension of the output embedding. + depth : int + Number of transformer blocks. + num_heads : int + Number of attention heads. + mlp_ratio : float + The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0. + qkv_bias : bool + Whether to add bias to the query, key, and value projections. Default: False. + qk_scale : float + The scale factor for the query-key dot product. Default: None. 
+ drop_rate : float + The dropout probability for all dropout layers except dropout_path. Default: 0.0. + attn_drop_rate : float + The dropout probability for the attention layer. Default: 0.0. + dropout_path_rate : float + The dropout probability for the dropout path. Default: 0.0. + use_gpsa: bool + Whether to use GPSA layer. Default: True. + locality_strength : float + The strength of the locality assumption in initialization. Default: 1.0. + use_pos_embedding : bool + Whether to use positional embeddings. Default: True. + normalized : bool + Whether to normalize the input tensor. Default: True. + """ + # pylint: disable=too-many-locals + super().__init__() + + self.dimensionality = dimensionality + + self.depth = depth + embedding_dim *= num_heads + self.num_features = embedding_dim # num_features for consistency with other models + self.locality_strength = locality_strength + self.use_pos_embedding = use_pos_embedding + + if isinstance(average_img_size, int): + if self.dimensionality == VisionTransformerDimensionality.TWO_DIMENSIONAL: + img_size = (average_img_size, average_img_size) + else: + img_size = (average_img_size, average_img_size, average_img_size) + else: + if len(average_img_size) != ( + 2 if self.dimensionality == VisionTransformerDimensionality.TWO_DIMENSIONAL else 3 + ): + raise ValueError( + f"average_img_size should have length 2 for 2D and 3 for 3D, got {len(average_img_size)}." + ) + img_size = average_img_size + + if isinstance(patch_size, int): + if self.dimensionality == VisionTransformerDimensionality.TWO_DIMENSIONAL: + self.patch_size = (patch_size, patch_size) + else: + self.patch_size = (patch_size, patch_size, patch_size) + else: + if len(patch_size) != (2 if self.dimensionality == VisionTransformerDimensionality.TWO_DIMENSIONAL else 3): + raise ValueError(f"patch_size should have length 2 for 2D and 3 for 3D, got {len(patch_size)}.") + self.patch_size = patch_size + + self.in_channels = in_channels + self.out_channels = out_channels if out_channels else in_channels + + self.patch_embed = PatchEmbedding( + patch_size=self.patch_size, + in_channels=in_channels, + embedding_dim=embedding_dim, + dimensionality=dimensionality, + ) + + self.pos_drop = nn.Dropout(p=drop_rate) + + if self.use_pos_embedding: + self.pos_embed = nn.Parameter( + torch.zeros(1, embedding_dim, *[img_size[i] // self.patch_size[i] for i in range(len(img_size))]) + ) + + init.trunc_normal_(self.pos_embed, std=0.02) + + dpr = [x.item() for x in torch.linspace(0, dropout_path_rate, depth)] # stochastic depth decay rule + + self.blocks = nn.ModuleList( + [ + VisionTransformerBlock( + dimensionality=dimensionality, + dim=embedding_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + dropout_path=dpr[i], + norm_layer=nn.LayerNorm, + use_gpsa=use_gpsa, + **({"locality_strength": locality_strength} if use_gpsa else {}), + ) + for i in range(depth) + ] + ) + + self.normalized = normalized + + self.norm = nn.LayerNorm(embedding_dim) + # head + self.feature_info = [{"num_chs": embedding_dim, "reduction": 0, "module": "head"}] + self.head = nn.Linear(self.num_features, self.out_channels * np.prod(self.patch_size)) + + self.head.apply(init_weights) + + def get_head(self) -> nn.Module: + """Returns the head of the model. 
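# A numeric sketch of the "stochastic depth decay rule" used in VisionTransformer.__init__
# above: the per-block DropoutPath probability grows linearly from 0 at the first block to
# dropout_path_rate at the last block.
import torch

depth, dropout_path_rate = 8, 0.1
dpr = [x.item() for x in torch.linspace(0, dropout_path_rate, depth)]
print([round(p, 3) for p in dpr])                 # [0.0, 0.014, 0.029, 0.043, 0.057, 0.071, 0.086, 0.1]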
+ + Returns + ------- + nn.Module + """ + return self.head + + def reset_head(self) -> None: + """Resets the head of the model.""" + self.head = nn.Linear(self.num_features, self.out_channels * np.prod(self.patch_size)) + + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass of the feature extraction part of the model. + + Parameters + ---------- + x : torch.Tensor + The input tensor. + + Returns + ------- + torch.Tensor + """ + x = self.patch_embed(x) + size = x.shape[2:] + + if self.use_pos_embedding: + pos_embed = F.interpolate( + self.pos_embed, + size=size, + mode=( + "bilinear" + if self.dimensionality == VisionTransformerDimensionality.TWO_DIMENSIONAL + else "trilinear" + ), + align_corners=False, + ) + x = x + pos_embed + + x = x.flatten(2).transpose(1, 2) + x = self.pos_drop(x) + + for _, block in enumerate(self.blocks): + x = block(x, size) + + x = self.norm(x) + + return x + + @abstractmethod + def seq2img(self, x: torch.Tensor, img_size: tuple[int, ...]) -> torch.Tensor: + """Converts the sequence patches tensor to an image tensor. + + Parameters + ---------- + x : torch.Tensor + The sequence tensor. + img_size : tuple[int, ...] + The size of the image tensor. + + Returns + ------- + torch.Tensor + The image tensor. + """ + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Performs forward pass of :class:`VisionTransformer`. + + Parameters + ---------- + x : torch.Tensor + + Returns + ------- + torch.Tensor + """ + x, pads = pad_to_divisible(x, self.patch_size) + + size = x.shape[2:] + + if self.normalized: + x, mean, std = norm(x) + + x = self.forward_features(x) + x = self.head(x) + x = self.seq2img(x, size) + + if self.normalized: + x = unnorm(x, mean, std) + + x = unpad_to_original(x, *pads) + + return x + + +class VisionTransformer2D(VisionTransformer): + """Vision Transformer model for 2D data. + + Parameters + ---------- + average_img_size : int or tuple[int, int] + The average size of the input image. If an int is provided, this will be determined by the + `dimensionality`, i.e., (average_img_size, average_img_size) for 2D and + (average_img_size, average_img_size, average_img_size) for 3D. Default: 320. + patch_size : int or tuple[int, int] + The size of the patch. If an int is provided, this will be determined by the `dimensionality`, i.e., + (patch_size, patch_size) for 2D and (patch_size, patch_size, patch_size) for 3D. Default: 16. + in_channels : int + Number of input channels. Default: COMPLEX_SIZE. + out_channels : int or None + Number of output channels. If None, this will be set to `in_channels`. Default: None. + embedding_dim : int + Dimension of the output embedding. + depth : int + Number of transformer blocks. + num_heads : int + Number of attention heads. + mlp_ratio : float + The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0. + qkv_bias : bool + Whether to add bias to the query, key, and value projections. Default: False. + qk_scale : float + The scale factor for the query-key dot product. Default: None. + drop_rate : float + The dropout probability for all dropout layers except dropout_path. Default: 0.0. + attn_drop_rate : float + The dropout probability for the attention layer. Default: 0.0. + dropout_path_rate : float + The dropout probability for the dropout path. Default: 0.0. + use_gpsa: bool + Whether to use GPSA layer. Default: True. + locality_strength : float + The strength of the locality assumption in initialization. Default: 1.0. 
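# An end-to-end usage sketch of the forward pass above, assuming the package is importable
# and using illustrative hyperparameters: because the input is padded to a multiple of
# patch_size and cropped afterwards, the output shape matches the input shape even when
# the image size is not divisible by the patch size.
import torch
from direct.nn.transformers.vit import VisionTransformer2D

model = VisionTransformer2D(
    average_img_size=64,
    patch_size=16,
    in_channels=2,
    embedding_dim=8,
    depth=2,
    num_heads=4,
)
x = torch.rand(1, 2, 50, 70)                      # 50 and 70 are not multiples of 16
print(model(x).shape)                             # torch.Size([1, 2, 50, 70])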
+ use_pos_embedding : bool + Whether to use positional embeddings. Default: True. + normalized : bool + Whether to normalize the input tensor. Default: True. + """ + + def __init__( + self, + average_img_size: int | tuple[int, int] = 320, + patch_size: int | tuple[int, int] = 16, + in_channels: int = COMPLEX_SIZE, + out_channels: int = None, + embedding_dim: int = 64, + depth: int = 8, + num_heads: int = 9, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + qk_scale: float = None, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + dropout_path_rate: float = 0.0, + use_gpsa: bool = True, + locality_strength: float = 1.0, + use_pos_embedding: bool = True, + normalized: bool = True, + ) -> None: + """Inits :class:`VisionTransformer2D`. + + Parameters + ---------- + average_img_size : int or tuple[int, int] + The average size of the input image. If an int is provided, this will be defined as + (average_img_size, average_img_size). Default: 320. + patch_size : int or tuple[int, int] + The size of the patch. If an int is provided, this will be defined as (patch_size, patch_size). Default: 16. + in_channels : int + Number of input channels. Default: COMPLEX_SIZE. + out_channels : int or None + Number of output channels. If None, this will be set to `in_channels`. Default: None. + embedding_dim : int + Dimension of the output embedding. + depth : int + Number of transformer blocks. + num_heads : int + Number of attention heads. + mlp_ratio : float + The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0. + qkv_bias : bool + Whether to add bias to the query, key, and value projections. Default: False. + qk_scale : float + The scale factor for the query-key dot product. Default: None. + drop_rate : float + The dropout probability for all dropout layers except dropout_path. Default: 0.0. + attn_drop_rate : float + The dropout probability for the attention layer. Default: 0.0. + dropout_path_rate : float + The dropout probability for the dropout path. Default: 0.0. + use_gpsa: bool + Whether to use GPSA layer. Default: True. + locality_strength : float + The strength of the locality assumption in initialization. Default: 1.0. + use_pos_embedding : bool + Whether to use positional embeddings. Default: True. + normalized : bool + Whether to normalize the input tensor. Default: True. + """ + # pylint: disable=too-many-locals + super().__init__( + dimensionality=VisionTransformerDimensionality.TWO_DIMENSIONAL, + average_img_size=average_img_size, + patch_size=patch_size, + in_channels=in_channels, + out_channels=out_channels, + embedding_dim=embedding_dim, + depth=depth, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + dropout_path_rate=dropout_path_rate, + use_gpsa=use_gpsa, + locality_strength=locality_strength, + use_pos_embedding=use_pos_embedding, + normalized=normalized, + ) + + def seq2img(self, x: torch.Tensor, img_size: tuple[int, ...]) -> torch.Tensor: + """Converts the sequence patches tensor to an image tensor. + + Parameters + ---------- + x : torch.Tensor + The sequence tensor. + img_size : tuple[int, ...] + The size of the image tensor. + + Returns + ------- + torch.Tensor + The image tensor. 
+ """ + x = x.view(x.shape[0], x.shape[1], self.out_channels, self.patch_size[0], self.patch_size[1]) + x = x.chunk(x.shape[1], dim=1) + x = torch.cat(x, dim=4).permute(0, 1, 2, 4, 3) + x = x.chunk(img_size[0] // self.patch_size[0], dim=3) + x = torch.cat(x, dim=4).permute(0, 1, 2, 4, 3).squeeze(1) + + return x + + +class VisionTransformer3D(VisionTransformer): + """Vision Transformer model for 3D data. + + Parameters + ---------- + average_img_size : int or tuple[int, int, int] + The average size of the input image. If an int is provided, this will be defined as + (average_img_size, average_img_size, average_img_size). Default: 320. + patch_size : int or tuple[int, int, int] + The size of the patch. If an int is provided, this will be defined as (patch_size, patch_size, patch_size). + Default: 16. + in_channels : int + Number of input channels. Default: COMPLEX_SIZE. + out_channels : int or None + Number of output channels. If None, this will be set to `in_channels`. Default: None. + embedding_dim : int + Dimension of the output embedding. + depth : int + Number of transformer blocks. + num_heads : int + Number of attention heads. + mlp_ratio : float + The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0. + qkv_bias : bool + Whether to add bias to the query, key, and value projections. Default: False. + qk_scale : float + The scale factor for the query-key dot product. Default: None. + drop_rate : float + The dropout probability for all dropout layers except dropout_path. Default: 0.0. + attn_drop_rate : float + The dropout probability for the attention layer. Default: 0.0. + dropout_path_rate : float + The dropout probability for the dropout path. Default: 0.0. + use_gpsa: bool + Whether to use GPSA layer. Default: True. + locality_strength : float + The strength of the locality assumption in initialization. Default: 1.0. + use_pos_embedding : bool + Whether to use positional embeddings. Default: True. + normalized : bool + Whether to normalize the input tensor. Default: True. + """ + + def __init__( + self, + average_img_size: int | tuple[int, int, int] = 320, + patch_size: int | tuple[int, int, int] = 16, + in_channels: int = COMPLEX_SIZE, + out_channels: int = None, + embedding_dim: int = 64, + depth: int = 8, + num_heads: int = 9, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + qk_scale: float = None, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + dropout_path_rate: float = 0.0, + use_gpsa: bool = True, + locality_strength: float = 1.0, + use_pos_embedding: bool = True, + normalized: bool = True, + ) -> None: + """Inits :class:`VisionTransformer3D`. + + Parameters + ---------- + average_img_size : int or tuple[int, int, int] + The average size of the input image. If an int is provided, this will be defined as + (average_img_size, average_img_size, average_img_size). Default: 320. + patch_size : int or tuple[int, int, int] + The size of the patch. If an int is provided, this will be defined as (patch_size, patch_size, patch_size). + Default: 16. + in_channels : int + Number of input channels. Default: COMPLEX_SIZE. + out_channels : int or None + Number of output channels. If None, this will be set to `in_channels`. Default: None. + embedding_dim : int + Dimension of the output embedding. + depth : int + Number of transformer blocks. + num_heads : int + Number of attention heads. + mlp_ratio : float + The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0. 
+ qkv_bias : bool + Whether to add bias to the query, key, and value projections. Default: False. + qk_scale : float + The scale factor for the query-key dot product. Default: None. + drop_rate : float + The dropout probability for all dropout layers except dropout_path. Default: 0.0. + attn_drop_rate : float + The dropout probability for the attention layer. Default: 0.0. + dropout_path_rate : float + The dropout probability for the dropout path. Default: 0.0. + use_gpsa: bool + Whether to use GPSA layer. Default: True. + locality_strength : float + The strength of the locality assumption in initialization. Default: 1.0. + use_pos_embedding : bool + Whether to use positional embeddings. Default: True. + normalized : bool + Whether to normalize the input tensor. Default: True. + """ + # pylint: disable=too-many-locals + super().__init__( + dimensionality=VisionTransformerDimensionality.THREE_DIMENSIONAL, + average_img_size=average_img_size, + patch_size=patch_size, + in_channels=in_channels, + out_channels=out_channels, + embedding_dim=embedding_dim, + depth=depth, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + dropout_path_rate=dropout_path_rate, + use_gpsa=use_gpsa, + locality_strength=locality_strength, + use_pos_embedding=use_pos_embedding, + normalized=normalized, + ) + + def seq2img(self, x: torch.Tensor, img_size: tuple[int, ...]) -> torch.Tensor: + """Converts the sequence of 3D patches to a 3D image tensor. + + Parameters + ---------- + x : torch.Tensor + The sequence tensor, where each entry corresponds to a flattened 3D patch. + img_size : tuple of ints + The size of the 3D image tensor (depth, height, width). + + Returns + ------- + torch.Tensor + The reconstructed 3D image tensor. 
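# A usage sketch of the 3D variant, assuming the package is importable and mirroring the
# shapes exercised in the tests below: the model operates on (batch, channels, depth,
# height, width) volumes and preserves the input shape.
import torch
from direct.nn.transformers.vit import VisionTransformer3D

model = VisionTransformer3D(
    average_img_size=32,
    patch_size=8,
    in_channels=2,
    embedding_dim=8,
    depth=2,
    num_heads=6,
)
x = torch.rand(1, 2, 32, 32, 32)
print(model(x).shape)                             # torch.Size([1, 2, 32, 32, 32])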
+ """ + # Reshape the sequence into patches of shape (batch, num_patches, out_channels, D, H, W) + x = x.view( + x.shape[0], x.shape[1], self.out_channels, self.patch_size[0], self.patch_size[1], self.patch_size[2] + ) + + # Chunk along the sequence dimension (depth, height, width) + depth_chunks = img_size[0] // self.patch_size[0] # Number of chunks along depth + height_chunks = img_size[1] // self.patch_size[1] # Number of chunks along height + width_chunks = img_size[2] // self.patch_size[2] # Number of chunks along width + + # First, chunk along the sequence dimension (width axis) + x = torch.cat(x.chunk(width_chunks, dim=1), dim=5).permute(0, 1, 2, 3, 4, 5) + + # Now, chunk along the height axis + x = torch.cat(x.chunk(height_chunks, dim=1), dim=4).permute(0, 1, 2, 3, 4, 5) + + # Finally, chunk along the depth axis + x = torch.cat(x.chunk(depth_chunks, dim=1), dim=3).permute(0, 1, 2, 3, 4, 5).squeeze(1) + + return x diff --git a/setup.py b/setup.py index 9c1c5f14..4f0d10dd 100644 --- a/setup.py +++ b/setup.py @@ -59,12 +59,13 @@ def finalize_options(self): "h5py==3.11.0", "omegaconf==2.3.0", "torch>=2.2.0", - "torchvision==0.18.0", + "torchvision", "scikit-image>=0.19.0", "scikit-learn>=1.0.1", "tensorboard>=2.7.0", "tqdm", "protobuf==3.20.2", + "einops", ], extras_require={ "dev": [ diff --git a/tests/tests_nn/test_transformers.py b/tests/tests_nn/test_transformers.py new file mode 100644 index 00000000..2d33ef98 --- /dev/null +++ b/tests/tests_nn/test_transformers.py @@ -0,0 +1,275 @@ +# Copyright (c) DIRECT Contributors + +"""Tests for transformers models.""" + +import pytest +import torch + +from direct.nn.transformers.transformers import * +from direct.nn.transformers.uformer import AttentionTokenProjectionType, LeWinTransformerMLPTokenType, UFormerModel +from direct.nn.transformers.vit import VisionTransformer2D, VisionTransformer3D + + +def create_input(shape): + data = torch.rand(shape).float() + + return data + + +@pytest.mark.parametrize( + "shape", + [ + [3, 2, 32, 32], + [3, 2, 16, 16], + ], +) +@pytest.mark.parametrize( + "embedding_dim", + [20], +) +@pytest.mark.parametrize( + "patch_size", + [140], +) +@pytest.mark.parametrize( + "encoder_depths, encoder_num_heads, bottleneck_depth, bottleneck_num_heads", + [ + [(2, 2, 2), (1, 2, 4), 1, 8], + [(2, 2, 2, 2), (1, 2, 4, 8), 2, 8], + ], +) +@pytest.mark.parametrize( + "patch_norm", + [True, False], +) +@pytest.mark.parametrize( + "win_size", + [8], +) +@pytest.mark.parametrize( + "mlp_ratio", + [2], +) +@pytest.mark.parametrize( + "qkv_bias", + [True, False], +) +@pytest.mark.parametrize( + "qk_scale", + [None, 0.5], +) +@pytest.mark.parametrize( + "token_projection", + [AttentionTokenProjectionType.LINEAR, AttentionTokenProjectionType.CONV], +) +@pytest.mark.parametrize( + "token_mlp", + [LeWinTransformerMLPTokenType.MLP, LeWinTransformerMLPTokenType.LEFF], +) +def test_uformer( + shape, + patch_size, + embedding_dim, + encoder_depths, + encoder_num_heads, + bottleneck_depth, + bottleneck_num_heads, + win_size, + mlp_ratio, + patch_norm, + qkv_bias, + qk_scale, + token_projection, + token_mlp, +): + model = UFormerModel( + patch_size=patch_size, + in_channels=2, + embedding_dim=embedding_dim, + encoder_depths=encoder_depths, + encoder_num_heads=encoder_num_heads, + bottleneck_depth=bottleneck_depth, + bottleneck_num_heads=bottleneck_num_heads, + win_size=win_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + patch_norm=patch_norm, + token_projection=token_projection, + token_mlp=token_mlp, + ) + data 
= create_input(shape).cpu() + out = model(data) + assert list(out.shape) == shape + + +@pytest.mark.parametrize( + "shape, average_img_size", + [ + [[1, 3, 128, 128], 128], + [[3, 2, 64, 50], (64, 50)], + ], +) +@pytest.mark.parametrize( + "patch_size", + [16, 8, (16, 10)], +) +@pytest.mark.parametrize( + "embedding_dim", + [6, 12], +) +@pytest.mark.parametrize( + "depth", + [2, 4], +) +@pytest.mark.parametrize( + "num_heads", + [3, 4], +) +@pytest.mark.parametrize( + "mlp_ratio", + [4.0, 2.0], +) +@pytest.mark.parametrize( + "qkv_bias", + [True, False], +) +@pytest.mark.parametrize( + "qk_scale", + [None, 0.5], +) +@pytest.mark.parametrize( + "use_gpsa", + [True, False], +) +@pytest.mark.parametrize( + "locality_strength", + [0.5], +) +@pytest.mark.parametrize( + "use_pos_embedding", + [True, False], +) +@pytest.mark.parametrize( + "normalized", + [True, False], +) +def test_vision_transformer_2d( + shape, + average_img_size, + patch_size, + embedding_dim, + depth, + num_heads, + mlp_ratio, + qkv_bias, + qk_scale, + use_gpsa, + locality_strength, + use_pos_embedding, + normalized, +): + model = VisionTransformer2D( + average_img_size=average_img_size, + patch_size=patch_size, + in_channels=shape[1], + embedding_dim=embedding_dim, + depth=depth, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + use_gpsa=use_gpsa, + locality_strength=locality_strength, + use_pos_embedding=use_pos_embedding, + normalized=normalized, + ) + data = create_input(shape).cpu() + out = model(data) + assert list(out.shape) == [shape[0], shape[1], shape[2], shape[3]] + + +@pytest.mark.parametrize( + "shape, average_img_size", + [ + [[1, 3, 64, 64, 64], 64], + [[2, 2, 32, 32, 32], (32, 32, 32)], + ], +) +@pytest.mark.parametrize( + "patch_size", + [8, (8, 6, 8)], +) +@pytest.mark.parametrize( + "embedding_dim", + [8, 16], +) +@pytest.mark.parametrize( + "depth", + [4, 8], +) +@pytest.mark.parametrize( + "num_heads", + [6], +) +@pytest.mark.parametrize( + "mlp_ratio", + [4.0], +) +@pytest.mark.parametrize( + "qkv_bias", + [True, False], +) +@pytest.mark.parametrize( + "qk_scale", + [None, 0.5], +) +@pytest.mark.parametrize( + "use_gpsa", + [True, False], +) +@pytest.mark.parametrize( + "locality_strength", + [1.0], +) +@pytest.mark.parametrize( + "use_pos_embedding", + [True, False], +) +@pytest.mark.parametrize( + "normalized", + [True, False], +) +def test_vision_transformer_3d( + shape, + average_img_size, + patch_size, + embedding_dim, + depth, + num_heads, + mlp_ratio, + qkv_bias, + qk_scale, + use_gpsa, + locality_strength, + use_pos_embedding, + normalized, +): + model = VisionTransformer3D( + average_img_size=average_img_size, + patch_size=patch_size, + in_channels=shape[1], + embedding_dim=embedding_dim, + depth=depth, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + use_gpsa=use_gpsa, + locality_strength=locality_strength, + use_pos_embedding=use_pos_embedding, + normalized=normalized, + ) + data = create_input(shape).cpu() + out = model(data) + assert list(out.shape) == [shape[0], shape[1], shape[2], shape[3], shape[4]] diff --git a/tests/tests_nn/test_transformers_engine.py b/tests/tests_nn/test_transformers_engine.py new file mode 100644 index 00000000..66b5f758 --- /dev/null +++ b/tests/tests_nn/test_transformers_engine.py @@ -0,0 +1,674 @@ +# Copyright (c) DIRECT Contributors + +"""Tests for `direct.nn.transformers.transformers_engine` module.""" + +import functools + +import numpy as np +import pytest +import torch + +from 
direct.config.defaults import DefaultConfig, FunctionConfig, LossConfig, TrainingConfig, ValidationConfig +from direct.data.transforms import fft2, ifft2 +from direct.nn.transformers.config import ( + ImageDomainMRIUFormerConfig, + ImageDomainMRIViT2DConfig, + ImageDomainMRIViT3DConfig, + KSpaceDomainMRIViT2DConfig, + KSpaceDomainMRIViT3DConfig, +) +from direct.nn.transformers.transformers import ( + ImageDomainMRIUFormer, + ImageDomainMRIViT2D, + ImageDomainMRIViT3D, + KSpaceDomainMRIViT2D, + KSpaceDomainMRIViT3D, +) +from direct.nn.transformers.transformers_engine import ( + ImageDomainMRIUFormerEngine, + ImageDomainMRIViT2DEngine, + ImageDomainMRIViT3DEngine, + KSpaceDomainMRIViT2DEngine, + KSpaceDomainMRIViT3DEngine, +) +from direct.nn.transformers.uformer import AttentionTokenProjectionType, LeWinTransformerMLPTokenType + + +def create_sample(shape, **kwargs): + """Create a dummy sample with random k-space, masked k-space and sensitivity maps; extra kwargs are stored as given.""" + sample = dict() + sample["masked_kspace"] = torch.from_numpy(np.random.randn(*shape)).float() + sample["kspace"] = torch.from_numpy(np.random.randn(*shape)).float() + sample["sensitivity_map"] = torch.from_numpy(np.random.randn(*shape)).float() + for k, v in kwargs.items(): + sample[k] = v + return sample + + +@pytest.mark.parametrize( + "shape", + [(4, 3, 10, 16, 2), (5, 1, 10, 12, 2)], +) +@pytest.mark.parametrize( + "loss_fns", + [["l1_loss", "kspace_nmse_loss", "kspace_nmae_loss"]], +) +@pytest.mark.parametrize( + "embedding_dim", + [20], +) +@pytest.mark.parametrize( + "patch_size", + [140], +) +@pytest.mark.parametrize( + "encoder_depths, encoder_num_heads, bottleneck_depth, bottleneck_num_heads", + [ + [(2, 2, 2), (1, 2, 4), 1, 8], + ], +) +@pytest.mark.parametrize( + "patch_norm", + [True], +) +@pytest.mark.parametrize( + "win_size", + [8], +) +@pytest.mark.parametrize( + "mlp_ratio", + [2], +) +@pytest.mark.parametrize( + "qkv_bias", + [False], +) +@pytest.mark.parametrize( + "qk_scale", + [0.5], +) +@pytest.mark.parametrize( + "token_projection", + [AttentionTokenProjectionType.CONV], +) +@pytest.mark.parametrize( + "token_mlp", + [LeWinTransformerMLPTokenType.MLP], +) +def test_image_uformer_engine( + shape, + loss_fns, + embedding_dim, + patch_size, + encoder_depths, + encoder_num_heads, + bottleneck_depth, + bottleneck_num_heads, + patch_norm, + win_size, + mlp_ratio, + qkv_bias, + qk_scale, + token_projection, + token_mlp, +): + # Operators + forward_operator = functools.partial(fft2, centered=True) + backward_operator = functools.partial(ifft2, centered=True) + # Configs + loss_config = LossConfig(losses=[FunctionConfig(loss) for loss in loss_fns]) + training_config = TrainingConfig(loss=loss_config) + validation_config = ValidationConfig(crop=None) + model_config = ImageDomainMRIUFormerConfig( + patch_size=patch_size, + embedding_dim=embedding_dim, + encoder_depths=encoder_depths, + encoder_num_heads=encoder_num_heads, + bottleneck_depth=bottleneck_depth, + bottleneck_num_heads=bottleneck_num_heads, + win_size=win_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + patch_norm=patch_norm, + token_projection=token_projection, + token_mlp=token_mlp, + ) + config = DefaultConfig(training=training_config, validation=validation_config, model=model_config) + # Models + model = ImageDomainMRIUFormer( + forward_operator, + backward_operator, + patch_size=model_config.patch_size, + embedding_dim=model_config.embedding_dim, + encoder_depths=model_config.encoder_depths, + encoder_num_heads=model_config.encoder_num_heads, + bottleneck_depth=model_config.bottleneck_depth, +
bottleneck_num_heads=model_config.bottleneck_num_heads, + win_size=model_config.win_size, + mlp_ratio=model_config.mlp_ratio, + qkv_bias=model_config.qkv_bias, + qk_scale=model_config.qk_scale, + patch_norm=model_config.patch_norm, + token_projection=model_config.token_projection, + token_mlp=model_config.token_mlp, + ) + sensitivity_model = torch.nn.Conv2d(2, 2, kernel_size=1) + # Define engine + engine = ImageDomainMRIUFormerEngine(config, model, "cpu", fft2, ifft2, sensitivity_model=sensitivity_model) + engine.ndim = 2 + # Test _do_iteration function with a single data batch + data = create_sample( + shape, + sampling_mask=torch.from_numpy(np.random.randn(1, 1, shape[2], shape[3], 1)).bool(), + target=torch.from_numpy(np.random.randn(shape[0], shape[2], shape[3])).float(), + scaling_factor=torch.ones(shape[0]), + ) + loss_fns = engine.build_loss() + out = engine._do_iteration(data, loss_fns) + assert out.output_image.shape == (shape[0],) + tuple(shape[2:-1]) + + +@pytest.mark.parametrize( + "shape", + [(4, 3, 10, 16, 2), (5, 1, 10, 12, 2)], +) +@pytest.mark.parametrize( + "loss_fns", + [["l1_loss", "kspace_nmse_loss", "kspace_nmae_loss"]], +) +@pytest.mark.parametrize( + "patch_size", + [8, (8, 10)], +) +@pytest.mark.parametrize( + "embedding_dim", + [8], +) +@pytest.mark.parametrize( + "depth", + [4], +) +@pytest.mark.parametrize( + "num_heads", + [6], +) +@pytest.mark.parametrize( + "mlp_ratio", + [4.0], +) +@pytest.mark.parametrize( + "qkv_bias", + [False], +) +@pytest.mark.parametrize( + "qk_scale", + [None], +) +@pytest.mark.parametrize( + "use_gpsa", + [False], +) +@pytest.mark.parametrize( + "locality_strength", + [1.0], +) +@pytest.mark.parametrize( + "use_pos_embedding", + [True], +) +@pytest.mark.parametrize( + "normalized", + [True], +) +def test_image_vit2d_engine( + shape, + loss_fns, + patch_size, + embedding_dim, + depth, + num_heads, + mlp_ratio, + qkv_bias, + qk_scale, + use_gpsa, + locality_strength, + use_pos_embedding, + normalized, +): + # Operators + forward_operator = functools.partial(fft2, centered=True) + backward_operator = functools.partial(ifft2, centered=True) + # Configs + loss_config = LossConfig(losses=[FunctionConfig(loss) for loss in loss_fns]) + training_config = TrainingConfig(loss=loss_config) + validation_config = ValidationConfig(crop=None) + model_config = ImageDomainMRIViT2DConfig( + patch_size=patch_size, + embedding_dim=embedding_dim, + depth=depth, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + use_gpsa=use_gpsa, + locality_strength=locality_strength, + use_pos_embedding=use_pos_embedding, + normalized=normalized, + ) + config = DefaultConfig(training=training_config, validation=validation_config, model=model_config) + # Models + model = ImageDomainMRIViT2D( + forward_operator, + backward_operator, + patch_size=model_config.patch_size, + embedding_dim=model_config.embedding_dim, + depth=model_config.depth, + num_heads=model_config.num_heads, + mlp_ratio=model_config.mlp_ratio, + qkv_bias=model_config.qkv_bias, + qk_scale=model_config.qk_scale, + use_gpsa=model_config.use_gpsa, + locality_strength=model_config.locality_strength, + use_pos_embedding=model_config.use_pos_embedding, + normalized=model_config.normalized, + ) + sensitivity_model = torch.nn.Conv2d(2, 2, kernel_size=1) + # Define engine + engine = ImageDomainMRIViT2DEngine(config, model, "cpu", fft2, ifft2, sensitivity_model=sensitivity_model) + engine.ndim = 2 + # Test _do_iteration function with a single data batch + data = create_sample( +
shape, + sampling_mask=torch.from_numpy(np.random.randn(1, 1, shape[2], shape[3], 1)).bool(), + target=torch.from_numpy(np.random.randn(shape[0], shape[2], shape[3])).float(), + ) + loss_fns = engine.build_loss() + out = engine._do_iteration(data, loss_fns) + assert out.output_image.shape == (shape[0],) + tuple(shape[2:-1]) + + +@pytest.mark.parametrize( + "shape", + [(4, 3, 10, 16, 2), (5, 1, 10, 12, 2)], +) +@pytest.mark.parametrize( + "loss_fns", + [["l1_loss", "kspace_nmse_loss", "kspace_nmae_loss"]], +) +@pytest.mark.parametrize( + "patch_size", + [(10, 10)], +) +@pytest.mark.parametrize( + "embedding_dim", + [8], +) +@pytest.mark.parametrize( + "depth", + [4], +) +@pytest.mark.parametrize( + "num_heads", + [6], +) +@pytest.mark.parametrize( + "mlp_ratio", + [4.0], +) +@pytest.mark.parametrize( + "qkv_bias", + [False], +) +@pytest.mark.parametrize( + "qk_scale", + [None], +) +@pytest.mark.parametrize( + "use_gpsa", + [True], +) +@pytest.mark.parametrize( + "locality_strength", + [1.0], +) +@pytest.mark.parametrize( + "use_pos_embedding", + [True], +) +@pytest.mark.parametrize( + "normalized", + [True], +) +@pytest.mark.parametrize( + "compute_per_coil", + [True, False], +) +def test_kspace_vit2d_engine( + shape, + loss_fns, + patch_size, + embedding_dim, + depth, + num_heads, + mlp_ratio, + qkv_bias, + qk_scale, + use_gpsa, + locality_strength, + use_pos_embedding, + normalized, + compute_per_coil, +): + # Operators + forward_operator = functools.partial(fft2, centered=True) + backward_operator = functools.partial(ifft2, centered=True) + # Configs + loss_config = LossConfig(losses=[FunctionConfig(loss) for loss in loss_fns]) + training_config = TrainingConfig(loss=loss_config) + validation_config = ValidationConfig(crop=None) + model_config = KSpaceDomainMRIViT2DConfig( + patch_size=patch_size, + embedding_dim=embedding_dim, + depth=depth, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + use_gpsa=use_gpsa, + locality_strength=locality_strength, + use_pos_embedding=use_pos_embedding, + normalized=normalized, + compute_per_coil=compute_per_coil, + ) + config = DefaultConfig(training=training_config, validation=validation_config, model=model_config) + # Models + model = KSpaceDomainMRIViT2D( + forward_operator, + backward_operator, + patch_size=model_config.patch_size, + embedding_dim=model_config.embedding_dim, + depth=model_config.depth, + num_heads=model_config.num_heads, + mlp_ratio=model_config.mlp_ratio, + qkv_bias=model_config.qkv_bias, + qk_scale=model_config.qk_scale, + use_gpsa=model_config.use_gpsa, + locality_strength=model_config.locality_strength, + use_pos_embedding=model_config.use_pos_embedding, + normalized=model_config.normalized, + compute_per_coil=model_config.compute_per_coil, + ) + sensitivity_model = torch.nn.Conv2d(2, 2, kernel_size=1) + # Define engine + engine = KSpaceDomainMRIViT2DEngine(config, model, "cpu", fft2, ifft2, sensitivity_model=sensitivity_model) + engine.ndim = 2 + # Test _do_iteration function with a single data batch + data = create_sample( + shape, + sampling_mask=torch.from_numpy(np.random.randn(1, 1, shape[2], shape[3], 1)).bool(), + target=torch.from_numpy(np.random.randn(shape[0], shape[2], shape[3])).float(), + ) + loss_fns = engine.build_loss() + out = engine._do_iteration(data, loss_fns) + assert out.output_image.shape == (shape[0],) + tuple(shape[2:-1]) + + +@pytest.mark.parametrize( + "shape", + [(2, 3, 4, 10, 16, 2), (1, 11, 8, 12, 16, 2)], +) +@pytest.mark.parametrize( + "loss_fns", + [ + [ + "l1_loss", +
"snr_loss", + "kspace_nmae_loss", + "ssim_3d_loss", + ] + ], +) +@pytest.mark.parametrize( + "patch_size", + [(4, 8, 10)], +) +@pytest.mark.parametrize( + "embedding_dim", + [8], +) +@pytest.mark.parametrize( + "depth", + [4], +) +@pytest.mark.parametrize( + "num_heads", + [6], +) +@pytest.mark.parametrize( + "mlp_ratio", + [4.0], +) +@pytest.mark.parametrize( + "qkv_bias", + [False], +) +@pytest.mark.parametrize( + "qk_scale", + [None], +) +@pytest.mark.parametrize( + "use_gpsa", + [False], +) +@pytest.mark.parametrize( + "locality_strength", + [1.0], +) +@pytest.mark.parametrize( + "use_pos_embedding", + [True], +) +@pytest.mark.parametrize( + "normalized", + [False], +) +def test_image_vit3d_engine( + shape, + loss_fns, + patch_size, + embedding_dim, + depth, + num_heads, + mlp_ratio, + qkv_bias, + qk_scale, + use_gpsa, + locality_strength, + use_pos_embedding, + normalized, +): + # Operators + forward_operator = functools.partial(fft2, centered=True) + backward_operator = functools.partial(ifft2, centered=True) + # Configs + loss_config = LossConfig(losses=[FunctionConfig(loss) for loss in loss_fns]) + training_config = TrainingConfig(loss=loss_config) + validation_config = ValidationConfig(crop=None) + model_config = ImageDomainMRIViT3DConfig( + patch_size=patch_size, + embedding_dim=embedding_dim, + depth=depth, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + use_gpsa=use_gpsa, + locality_strength=locality_strength, + use_pos_embedding=use_pos_embedding, + normalized=normalized, + ) + config = DefaultConfig(training=training_config, validation=validation_config, model=model_config) + # Models + model = ImageDomainMRIViT3D( + forward_operator, + backward_operator, + patch_size=model_config.patch_size, + embedding_dim=model_config.embedding_dim, + depth=model_config.depth, + num_heads=model_config.num_heads, + mlp_ratio=model_config.mlp_ratio, + qkv_bias=model_config.qkv_bias, + qk_scale=model_config.qk_scale, + use_gpsa=model_config.use_gpsa, + locality_strength=model_config.locality_strength, + use_pos_embedding=model_config.use_pos_embedding, + normalized=model_config.normalized, + ) + sensitivity_model = torch.nn.Conv2d(2, 2, kernel_size=1) + # Define engine + engine = ImageDomainMRIViT3DEngine(config, model, "cpu", fft2, ifft2, sensitivity_model=sensitivity_model) + engine.ndim = 3 + # Test _do_iteration function with a single data batch + data = create_sample( + shape, + sampling_mask=torch.from_numpy(np.random.randn(1, 1, 1, shape[3], shape[4], 1)).bool(), + target=torch.from_numpy(np.random.randn(shape[0], shape[2], shape[3], shape[4])).float(), + scaling_factor=torch.ones(shape[0]), + ) + loss_fns = engine.build_loss() + out = engine._do_iteration(data, loss_fns) + assert out.output_image.shape == (shape[0],) + tuple(shape[2:-1]) + + +@pytest.mark.parametrize( + "shape", + [(2, 3, 4, 10, 16, 2), (1, 11, 8, 12, 16, 2)], +) +@pytest.mark.parametrize( + "loss_fns", + [ + [ + "l1_loss", + "snr_loss", + "kspace_nmse_loss", + "ssim_3d_loss", + ] + ], +) +@pytest.mark.parametrize( + "patch_size", + [6], +) +@pytest.mark.parametrize( + "embedding_dim", + [12], +) +@pytest.mark.parametrize( + "depth", + [4], +) +@pytest.mark.parametrize( + "num_heads", + [6], +) +@pytest.mark.parametrize( + "mlp_ratio", + [2.0], +) +@pytest.mark.parametrize( + "qkv_bias", + [False], +) +@pytest.mark.parametrize( + "qk_scale", + [None], +) +@pytest.mark.parametrize( + "use_gpsa", + [True], +) +@pytest.mark.parametrize( + "locality_strength", + [1.0], +)
+@pytest.mark.parametrize( + "use_pos_embedding", + [False], +) +@pytest.mark.parametrize( + "normalized", + [True], +) +@pytest.mark.parametrize( + "compute_per_coil", + [True, False], +) +def test_kspace_vit3d_engine( + shape, + loss_fns, + patch_size, + embedding_dim, + depth, + num_heads, + mlp_ratio, + qkv_bias, + qk_scale, + use_gpsa, + locality_strength, + use_pos_embedding, + normalized, + compute_per_coil, +): + # Operators + forward_operator = functools.partial(fft2, centered=True) + backward_operator = functools.partial(ifft2, centered=True) + # Configs + loss_config = LossConfig(losses=[FunctionConfig(loss) for loss in loss_fns]) + training_config = TrainingConfig(loss=loss_config) + validation_config = ValidationConfig(crop=None) + model_config = KSpaceDomainMRIViT3DConfig( + patch_size=patch_size, + embedding_dim=embedding_dim, + depth=depth, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + use_gpsa=use_gpsa, + locality_strength=locality_strength, + use_pos_embedding=use_pos_embedding, + normalized=normalized, + compute_per_coil=compute_per_coil, + ) + config = DefaultConfig(training=training_config, validation=validation_config, model=model_config) + # Models + model = KSpaceDomainMRIViT3D( + forward_operator, + backward_operator, + patch_size=model_config.patch_size, + embedding_dim=model_config.embedding_dim, + depth=model_config.depth, + num_heads=model_config.num_heads, + mlp_ratio=model_config.mlp_ratio, + qkv_bias=model_config.qkv_bias, + qk_scale=model_config.qk_scale, + use_gpsa=model_config.use_gpsa, + locality_strength=model_config.locality_strength, + use_pos_embedding=model_config.use_pos_embedding, + normalized=model_config.normalized, + compute_per_coil=model_config.compute_per_coil, + ) + sensitivity_model = torch.nn.Conv2d(2, 2, kernel_size=1) + # Define engine + engine = KSpaceDomainMRIViT3DEngine(config, model, "cpu", fft2, ifft2, sensitivity_model=sensitivity_model) + engine.ndim = 3 + # Test _do_iteration function with a single data batch + data = create_sample( + shape, + sampling_mask=torch.from_numpy(np.random.randn(1, 1, 1, shape[3], shape[4], 1)).bool(), + target=torch.from_numpy(np.random.randn(shape[0], shape[2], shape[3], shape[4])).float(), + scaling_factor=torch.ones(shape[0]), + ) + loss_fns = engine.build_loss() + out = engine._do_iteration(data, loss_fns) + assert out.output_image.shape == (shape[0],) + tuple(shape[2:-1])
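Note on the engine tests above: every test rebuilds the same loss/training/validation scaffolding before wrapping its model config in a DefaultConfig. The sketch below is not part of this PR; it is a possible refactor of that shared setup, the helper name _make_engine_config is hypothetical, and it assumes the config classes accept exactly the arguments used in the tests above.

# Hypothetical helper (not part of this PR): builds the DefaultConfig used by the engine tests
# from a model config and a list of loss-function names, mirroring the setup repeated above.
from direct.config.defaults import DefaultConfig, FunctionConfig, LossConfig, TrainingConfig, ValidationConfig


def _make_engine_config(model_config, loss_fns):
    # Same three-step construction as in each engine test: losses -> training -> validation.
    loss_config = LossConfig(losses=[FunctionConfig(loss) for loss in loss_fns])
    training_config = TrainingConfig(loss=loss_config)
    validation_config = ValidationConfig(crop=None)
    return DefaultConfig(training=training_config, validation=validation_config, model=model_config)


# Example usage inside a test such as test_image_vit2d_engine:
# config = _make_engine_config(model_config, ["l1_loss", "kspace_nmse_loss", "kspace_nmae_loss"])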