diff --git a/GeneralRelativity/Interpolation.py b/GeneralRelativity/Interpolation.py index dc2e6f4..bc8f530 100644 --- a/GeneralRelativity/Interpolation.py +++ b/GeneralRelativity/Interpolation.py @@ -151,6 +151,8 @@ def __init__( learnable: bool = False, align_grids_with_lower_dim_values: bool = False, dtype: Type[torch.dtype] = torch.float, + device: str = 'cpu', + ): """ Initialize the Interp class. @@ -161,6 +163,8 @@ def __init__( num_channels (int): Number of channels in the input tensor. learnable (bool): If True, the interpolation parameters are learnable. align_grids_with_lower_dim_values (bool): If True, aligns grid points with lower-dimensional values. + dtype (torch.dtype): dtype of weights, defaults to float + device (str): defaults to 'cpu' """ self.num_points = num_points self.max_degree = max_degree @@ -214,7 +218,7 @@ def __init__( # Create a convolutional kernel with zeros kernel = torch.zeros( (num_channels, 1, kernel_size, kernel_size, kernel_size), dtype=dtype - ) + ).to(device) # Find the minimum index for displacements to adjust kernel indexing min_index = torch.min(displacements) @@ -228,7 +232,7 @@ def __init__( conv_layer = nn.Conv3d( num_channels, num_channels, kernel_size, groups=num_channels, bias=False - ) + ).to(device) conv_layer.weight = nn.Parameter(kernel) conv_layer.weight.requires_grad = learnable self.conv_layers.append(conv_layer)