Skip to content

Commit

Permalink
added such that interpolation can choose device
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasHelfer committed Feb 1, 2024
1 parent fe28e62 commit ffed913
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions GeneralRelativity/Interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ def __init__(
learnable: bool = False,
align_grids_with_lower_dim_values: bool = False,
dtype: Type[torch.dtype] = torch.float,
device: str = 'cpu',

):
"""
Initialize the Interp class.
Expand All @@ -161,6 +163,8 @@ def __init__(
num_channels (int): Number of channels in the input tensor.
learnable (bool): If True, the interpolation parameters are learnable.
align_grids_with_lower_dim_values (bool): If True, aligns grid points with lower-dimensional values.
dtype (torch.dtype): dtype of weights, defaults to float
device (str): defaults to 'cpu'
"""
self.num_points = num_points
self.max_degree = max_degree
Expand Down Expand Up @@ -214,7 +218,7 @@ def __init__(
# Create a convolutional kernel with zeros
kernel = torch.zeros(
(num_channels, 1, kernel_size, kernel_size, kernel_size), dtype=dtype
)
).to(device)

# Find the minimum index for displacements to adjust kernel indexing
min_index = torch.min(displacements)
Expand All @@ -228,7 +232,7 @@ def __init__(

conv_layer = nn.Conv3d(
num_channels, num_channels, kernel_size, groups=num_channels, bias=False
)
).to(device)
conv_layer.weight = nn.Parameter(kernel)
conv_layer.weight.requires_grad = learnable
self.conv_layers.append(conv_layer)
Expand Down

0 comments on commit ffed913

Please sign in to comment.