Skip to content

Commit

Permalink
BROKEN stash initial work on new Channel model (ref #386)
Browse files Browse the repository at this point in the history
  • Loading branch information
tcompa committed Jun 8, 2023
1 parent 9cc9078 commit ed2205c
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 48 deletions.
82 changes: 47 additions & 35 deletions fractal_tasks_core/lib_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,30 @@
"""
import logging
from typing import Any
from typing import Dict
from typing import List
from typing import Sequence
from typing import Optional

import zarr
from pydantic import BaseModel


class ChannelWindow(BaseModel):
    """
    Intensity window of a channel (the `window` attribute of `Channel`).

    The values are numeric rather than strings: `define_omero_channels` in
    this module fills `min`/`max` with `0` and `2**bit_depth - 1`, and a
    `str` annotation would silently turn such numbers into strings when a
    window is loaded from existing metadata.
    """

    # Bounds of the intensity range (0 and 2**bit_depth - 1 when written by
    # define_omero_channels).
    min: int
    max: int
    # Optional sub-range; presumably the display range used by viewers --
    # TODO confirm against the OME-NGFF `omero` metadata description.
    start: Optional[int] = None
    end: Optional[int] = None


class Channel(BaseModel):
    """
    A single channel, identified by its `wavelength_id`.

    Only `wavelength_id` is required. Every optional attribute has an
    explicit default, so that omitting it behaves the same way regardless of
    how the installed pydantic version treats a bare `Optional[...]`
    annotation.
    """

    # Identifier that must be unique across a list of channels (see
    # validate_allowed_channel_input).
    wavelength_id: str
    label: Optional[str] = None
    # Position of the channel within a channel list; filled in by
    # get_channel_from_list.
    index: Optional[int] = None
    active: bool = True
    coefficient: int = 1
    # When left unset, define_omero_channels picks a default colormap.
    colormap: Optional[str] = None
    family: str = "linear"
    inverted: bool = False
    window: Optional[ChannelWindow] = None


class ChannelNotFoundError(ValueError):
Expand All @@ -31,19 +50,11 @@ class ChannelNotFoundError(ValueError):
pass


def validate_allowed_channel_input(allowed_channels: Sequence[Dict[str, Any]]):
def validate_allowed_channel_input(allowed_channels: List[Channel]):
"""
Check that (1) each channel has a wavelength_id key, and (2) the
wavelength_id values are unique.
Check that the `wavelength_id` values are unique across channels
"""
try:
wavelength_ids = [c["wavelength_id"] for c in allowed_channels]
except KeyError as e:
raise KeyError(
"Missing wavelength_id key in some channel.\n"
f"{allowed_channels=}\n"
f"Original error: {str(e)}"
)
wavelength_ids = [c.wavelength_id for c in allowed_channels]
if len(set(wavelength_ids)) < len(wavelength_ids):
raise ValueError(
f"Non-unique wavelength_id's in {wavelength_ids}\n"
Expand Down Expand Up @@ -73,16 +84,16 @@ def check_well_channel_labels(*, well_zarr_path: str) -> None:

# For each pair of channel-labels lists, verify they do not overlap
for ind_1, channels_1 in enumerate(list_of_channel_lists):
labels_1 = set([c["label"] for c in channels_1])
labels_1 = set([c.label for c in channels_1])
for ind_2 in range(ind_1):
channels_2 = list_of_channel_lists[ind_2]
labels_2 = set([c["label"] for c in channels_2])
labels_2 = set([c.label for c in channels_2])
intersection = labels_1 & labels_2
if intersection:
hint = (
"Are you parsing fields of view into separate OME-Zarr"
" images? This could lead to non-unique channel labels"
", and then could be the reason of the error"
"Are you parsing fields of view into separate OME-Zarr "
"images? This could lead to non-unique channel labels, "
"and then could be the reason of the error"
)
raise ValueError(
"Non-unique channel labels\n"
Expand All @@ -92,7 +103,7 @@ def check_well_channel_labels(*, well_zarr_path: str) -> None:

def get_channel_from_image_zarr(
*, image_zarr_path: str, label: str = None, wavelength_id: str = None
) -> Dict[str, Any]:
) -> Channel:
"""
Extract a channel from OME-NGFF zarr attributes
Expand All @@ -112,20 +123,23 @@ def get_channel_from_image_zarr(
return channel


def get_omero_channel_list(*, image_zarr_path: str) -> List[Dict[str, Any]]:
def get_omero_channel_list(*, image_zarr_path: str) -> List[Channel]:
"""
Extract the list of channels from OME-NGFF zarr attributes
:param image_zarr_path: Path to an OME-NGFF image zarr group
:returns: A list of `Channel` objects
"""
group = zarr.open_group(image_zarr_path, mode="r+")
return group.attrs["omero"]["channels"]
channels_dicts = group.attrs["omero"]["channels"]
# FIXME what is the type of channels_dicts??
channels = [Channel(**c) for c in channels_dicts]
return channels


def get_channel_from_list(
*, channels: Sequence[Dict], label: str = None, wavelength_id: str = None
) -> Dict[str, Any]:
*, channels: List[Channel], label: str = None, wavelength_id: str = None
) -> Channel:
"""
Find matching channel in a list
Expand All @@ -147,16 +161,14 @@ def get_channel_from_list(
matching_channels = [
c
for c in channels
if (
c["label"] == label and c["wavelength_id"] == wavelength_id
)
if (c.label == label and c.wavelength_id == wavelength_id)
]
else:
matching_channels = [c for c in channels if c["label"] == label]
matching_channels = [c for c in channels if c.label == label]
else:
if wavelength_id:
matching_channels = [
c for c in channels if c["wavelength_id"] == wavelength_id
c for c in channels if c.wavelength_id == wavelength_id
]
else:
raise ValueError(
Expand All @@ -178,16 +190,16 @@ def get_channel_from_list(
raise ValueError(f"Inconsistent set of channels: {channels}")

channel = matching_channels[0]
channel["index"] = channels.index(channel)
channel.index = channels.index(channel)
return channel


def define_omero_channels(
*,
channels: Sequence[Dict[str, Any]],
channels: List[Channel],
bit_depth: int,
label_prefix: str = None,
) -> List[Dict[str, Any]]:
) -> List[dict[str, Any]]:
"""
Update a channel list to use it in the OMERO/channels metadata
Expand All @@ -211,11 +223,11 @@ def define_omero_channels(
default_colormaps = ["00FFFF", "FF00FF", "FFFF00"]

for channel in channels:
wavelength_id = channel["wavelength_id"]
wavelength_id = channel.wavelength_id

# Always set a label
try:
label = channel["label"]
label = channel.label
except KeyError:
default_label = wavelength_id
if label_prefix:
Expand All @@ -227,7 +239,7 @@ def define_omero_channels(

# Set colormap attribute. If not specified, use the default ones (for
# the first three channels) or gray
colormap = channel.get("colormap", None)
colormap = channel.colormap
if colormap is None:
try:
colormap = default_colormaps.pop()
Expand All @@ -239,7 +251,7 @@ def define_omero_channels(
"min": 0,
"max": 2**bit_depth - 1,
}
if "start" in channel.keys() and "end" in channel.keys():
if "start" in channel.dict().keys() and "end" in channel.dict().keys():
window["start"] = channel["start"]
window["end"] = channel["end"]

Expand Down
11 changes: 8 additions & 3 deletions fractal_tasks_core/tasks/create_ome_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@
import pandas as pd
import zarr
from anndata.experimental import write_elem
from devtools import debug

import fractal_tasks_core
from fractal_tasks_core.lib_channels import Channel
from fractal_tasks_core.lib_channels import check_well_channel_labels
from fractal_tasks_core.lib_channels import define_omero_channels
from fractal_tasks_core.lib_channels import validate_allowed_channel_input
Expand All @@ -51,7 +53,7 @@ def create_ome_zarr(
metadata: Dict[str, Any],
image_extension: str = "tif",
image_glob_patterns: Optional[list[str]] = None,
allowed_channels: Sequence[Dict[str, Any]],
allowed_channels: List[Channel],
num_levels: int = 2,
coarsening_xy: int = 2,
metadata_table: str = "mrf_mlf",
Expand Down Expand Up @@ -108,6 +110,9 @@ def create_ome_zarr(
dict_plate_prefixes: Dict[str, Any] = {}

# Preliminary checks on allowed_channels argument
allowed_channels_raw = allowed_channels.copy()
allowed_channels = [Channel(**c) for c in allowed_channels_raw]
debug(allowed_channels)
validate_allowed_channel_input(allowed_channels)

for in_path_str in input_paths:
Expand Down Expand Up @@ -188,7 +193,7 @@ def create_ome_zarr(

# Check that all channels are in the allowed_channels
allowed_wavelength_ids = [
channel["wavelength_id"] for channel in allowed_channels
channel.wavelength_id for channel in allowed_channels
]
if not set(actual_wavelength_ids).issubset(set(allowed_wavelength_ids)):
msg = "ERROR in create_ome_zarr\n"
Expand All @@ -201,7 +206,7 @@ def create_ome_zarr(
actual_channels = [
channel
for channel in allowed_channels
if channel["wavelength_id"] in actual_wavelength_ids
if channel.wavelength_id in actual_wavelength_ids
]

zarrurls: Dict[str, List[str]] = {"plate": [], "well": [], "image": []}
Expand Down
22 changes: 12 additions & 10 deletions tests/test_unit_channels_addressing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,34 +3,36 @@

from devtools import debug

from fractal_tasks_core.lib_channels import Channel
from fractal_tasks_core.lib_channels import get_channel_from_list


def test_get_channel(testdata_path: Path):
with (testdata_path / "omero/channels_list.json").open("r") as f:
omero_channels = json.load(f)
omero_channels_dict = json.load(f)
omero_channels = [Channel(**c) for c in omero_channels_dict]
debug(omero_channels)

channel = get_channel_from_list(channels=omero_channels, label="label_1")
debug(channel)
assert channel["label"] == "label_1"
assert channel["wavelength_id"] == "wavelength_id_1"
assert channel["index"] == 0
assert channel.label == "label_1"
assert channel.wavelength_id == "wavelength_id_1"
assert channel.index == 0

channel = get_channel_from_list(
channels=omero_channels, wavelength_id="wavelength_id_2"
)
debug(channel)
assert channel["label"] == "label_2"
assert channel["wavelength_id"] == "wavelength_id_2"
assert channel["index"] == 1
assert channel.label == "label_2"
assert channel.wavelength_id == "wavelength_id_2"
assert channel.index == 1

channel = get_channel_from_list(
channels=omero_channels,
label="label_2",
wavelength_id="wavelength_id_2",
)
debug(channel)
assert channel["label"] == "label_2"
assert channel["wavelength_id"] == "wavelength_id_2"
assert channel["index"] == 1
assert channel.label == "label_2"
assert channel.wavelength_id == "wavelength_id_2"
assert channel.index == 1

0 comments on commit ed2205c

Please sign in to comment.