
WIP: iter
thewtex committed Dec 23, 2024
1 parent e6c77e1 commit 1d575fe
Showing 1 changed file with 90 additions and 12 deletions.
102 changes: 90 additions & 12 deletions ngff_zarr/methods/_itkwasm.py
@@ -1,7 +1,8 @@
 from typing import Tuple
+from itertools import product
 
 import numpy as np
-from dask.array import concatenate, expand_dims, map_blocks, map_overlap, take
+from dask.array import concatenate, expand_dims, map_blocks, map_overlap, take, stack, from_array
 
 from ..ngff_image import NgffImage
 from ._support import (
@@ -181,9 +182,7 @@ def _downsample_itkwasm(
 
         # Compute overlap for Gaussian blurring for all blocks
         is_vector = previous_image.dims[-1] == "c"
-        block_0_image = itkwasm.image_from_array(
-            np.ones_like(block_0_input), is_vector=is_vector
-        )
+        block_0_image = itkwasm.image_from_array(np.ones_like(block_0_input), is_vector=is_vector)
         input_spacing = [previous_image.scale[d] for d in spatial_dims]
         block_0_image.spacing = input_spacing
         input_origin = [previous_image.translation[d] for d in spatial_dims]
@@ -214,25 +213,18 @@ def _downsample_itkwasm(
             block_output.size[dim] == computed_size[dim]
             for dim in range(block_output.data.ndim)
         )
-        breakpoint()
         output_chunks = list(previous_image.data.chunks)
         dims = list(previous_image.dims)
         output_chunks_start = 0
         while dims[output_chunks_start] not in _spatial_dims:
             output_chunks_start += 1
         output_chunks = output_chunks[output_chunks_start:]
-        # if "t" in previous_image.dims:
-        #     dims = list(previous_image.dims)
-        #     t_index = dims.index("t")
-        #     output_chunks.pop(t_index)
         for i, c in enumerate(output_chunks):
             output_chunks[i] = [
                 block_output.data.shape[i],
             ] * len(c)
         # Compute output size for block N-1
-        block_neg1_image = itkwasm.image_from_array(
-            np.ones_like(block_neg1_input), is_vector=is_vector
-        )
+        block_neg1_image = itkwasm.image_from_array(np.ones_like(block_neg1_input), is_vector=is_vector)
@@ -251,6 +243,92 @@ def _downsample_itkwasm(
             output_chunks[i] = tuple(output_chunks[i])
         output_chunks = tuple(output_chunks)
 
+        non_spatial_dims = [d for d in dims if d not in _spatial_dims]
+        if "c" in non_spatial_dims and dims[-1] == "c":
+            non_spatial_dims.remove("c")
+
+        # Iterate over each index of the non-spatial dimensions, run
+        # map_overlap over the remaining spatial axes, and aggregate the
+        # outputs into a final result (a standalone sketch of this pattern
+        # follows the diff).
+
+        # Determine the full extent of each non-spatial dimension
+        non_spatial_shapes = [
+            previous_image.data.shape[dims.index(dim)] for dim in non_spatial_dims
+        ]
+
+        # Collect results for each sub-block
+        aggregated_blocks = []
+        for idx in product(*(range(s) for s in non_spatial_shapes)):
+            # Build the slice object for indexing
+            slice_obj = []
+            non_spatial_index = 0
+            for dim in dims:
+                if dim in non_spatial_dims:
+                    # Take a single index (t=0,1,...) along the non-spatial dimension
+                    slice_obj.append(idx[non_spatial_index])
+                    non_spatial_index += 1
+                else:
+                    # Keep the full slice for spatial / channel dims
+                    slice_obj.append(slice(None))
+
+            # Extract the sub-block data at the chosen non-spatial index
+            sub_block_data = previous_image.data[tuple(slice_obj)]
+
+            downscaled_sub_block = map_overlap(
+                _itkwasm_blur_and_downsample,
+                sub_block_data,
+                shrink_factors=shrink_factors,
+                kernel_radius=kernel_radius,
+                smoothing=smoothing,
+                dtype=dtype,
+                depth=dict(enumerate(np.flip(kernel_radius))),  # overlap is in tzyx
+                boundary="nearest",
+                trim=False,  # Overlapped region is trimmed in blur_and_downsample to output size
+                chunks=output_chunks,
+            )
+            # Alternative: run downsample_bin_shrink directly on the sub-block:
+            # sub_block_image = itkwasm.image_from_array(sub_block_data, is_vector=is_vector)
+            # sub_block_image.spacing = input_spacing
+            # sub_block_image.origin = input_origin
+            # sub_block_output = downsample_bin_shrink(
+            #     sub_block_image, shrink_factors, information_only=False
+            # )
+
+            # Collect the result for later aggregation
+            aggregated_blocks.append(downscaled_sub_block)
+
+        # Stitch the processed sub-blocks back together. Each item in
+        # aggregated_blocks is already a lazy dask array produced by
+        # map_overlap, so it can be stacked directly; the leading axis
+        # enumerates the non-spatial indices and is reshaped back to the
+        # original non-spatial dimensions.
+        stacked = stack(aggregated_blocks, axis=0)
+        final_dask_array = stacked.reshape(
+            tuple(non_spatial_shapes) + stacked.shape[1:]
+        )
if "t" in previous_image.dims:
all_timepoints = []
for timepoint in range(previous_image.data.shape[t_index]):
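
The pattern the added code introduces — index each non-spatial dimension, blur and downsample the remaining spatial axes with dask's map_overlap, then stack the lazy results — can be exercised in isolation. The sketch below is a minimal, self-contained approximation: blur_and_shrink is a hypothetical stand-in for _itkwasm_blur_and_downsample, and scipy's gaussian_filter substitutes for the itkwasm smoothing.

from itertools import product

import dask.array as da
from scipy.ndimage import gaussian_filter  # assumption: scipy is available


def blur_and_shrink(block, depth=2, sigma=1.0, factor=2):
    # Smooth using the overlapped data, trim the overlap added by
    # map_overlap (trim=False below), then shrink by taking every
    # `factor`-th sample.
    smoothed = gaussian_filter(block, sigma=sigma)
    core = smoothed[depth:-depth, depth:-depth]
    return core[::factor, ::factor]


dims = ("t", "y", "x")  # one non-spatial dim, two spatial dims
data = da.random.random((3, 64, 64), chunks=(1, 32, 32))

t_size = data.shape[dims.index("t")]
downsampled = []
for (t,) in product(range(t_size)):
    sub_block = data[t]  # full slices over the spatial dims
    result = da.map_overlap(
        blur_and_shrink,
        sub_block,
        depth=2,             # overlap ~ blur kernel radius
        boundary="nearest",
        trim=False,          # blur_and_shrink trims the overlap itself
        dtype=sub_block.dtype,
        chunks=(16, 16),     # each 32x32 input chunk shrinks to 16x16
    )
    downsampled.append(result)

# Stitch the lazy per-timepoint results back together
final = da.stack(downsampled, axis=0)
print(final.shape)  # (3, 32, 32)

With boundary="nearest", edge blocks are padded as well, so every block the function receives carries the full overlap on all sides; trimming depth from each side before shrinking keeps the output aligned with the declared chunks.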
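Two smaller details in the diff can also be checked with made-up numbers: the per-axis overlap passed to map_overlap is a dict keyed by dask axis index (the flip converts an assumed itkwasm xyz radius into zyx axis order, matching the "overlap is in tzyx" comment), and the output chunks repeat the downsampled reference-block extent once per input chunk. All values below are illustrative assumptions, not values taken from the library.

import numpy as np

kernel_radius = [2, 2, 4]  # assumed itkwasm xyz radii
depth = dict(enumerate(np.flip(kernel_radius)))  # dask axis -> overlap: 4 for z, 2 for y, 2 for x

# Each input chunk maps to one output chunk with the downsampled block's extent
input_chunks = ((32, 32), (32, 32, 16))  # spatial chunk runs (y, x)
block_output_shape = (16, 16)
output_chunks = tuple(
    tuple(block_output_shape[i] for _ in run) for i, run in enumerate(input_chunks)
)
print(output_chunks)  # ((16, 16), (16, 16, 16))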
