Skip to content

Commit

Permalink
fixed interpolation code
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasHelfer committed Jan 26, 2024
1 parent 27e4d1d commit 99febbb
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions tests/test_interpolations.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,12 @@ def test_interpolation_on_grid():
dx (float): Differential step to scale the grid positions.
"""
for centering in [True, False]:
for num_points in [4, 6, 8]:
tol = 1e-10
for num_points in [4,6,8]:
tol = 1e-9
channels = 25
interpolation = interp(
num_points=num_points,
max_degree=num_points//2,
max_degree=num_points // 2,
num_channels=channels,
learnable=False,
align_grids_with_lower_dim_values=centering,
Expand Down Expand Up @@ -113,7 +113,7 @@ def test_interpolation_on_grid():
positions = interpolation.get_postion(x)

# Preparing ground truth for comparison
ghosts = int(math.ceil(6 / 2))
ghosts = int(math.ceil(num_points / 2))
shape = x.shape
ground_truth = torch.zeros(
shape[0],
Expand All @@ -133,14 +133,13 @@ def test_interpolation_on_grid():
pos = dx * (positions[i, j, k])
ground_truth[:, :, i, j, k] = sinusoidal_function(*pos)

print(num_points)
# Comparing interpolated and ground truth values
assert (
torch.mean(torch.abs(interpolated - ground_truth))
) < tol
#assert((torch.mean(torch.abs(interpolated - ground_truth))))
assert (torch.mean(torch.abs(interpolated - ground_truth))) < tol

# Comparing old and new interpolation
assert(torch.mean(torch.abs(interpolated_old - interpolated))< tol)

assert torch.mean(torch.abs(interpolated_old - interpolated)) < tol


if __name__ == "__main__":
Expand Down

0 comments on commit 99febbb

Please sign in to comment.