# Source code for pymaid.stack

"""Access to image data as ``xarray.DataArray``s.

CATMAID's image source conventions are documented here
https://catmaid.readthedocs.io/en/stable/tile_sources.html
"""
from __future__ import annotations
from functools import wraps
from io import BytesIO
from typing import Any, Literal, Optional, Sequence, Type, Union, Dict
from abc import ABC
import sys

import numpy as np
from numpy.typing import DTypeLike, ArrayLike
import json
import requests

from pymaid.client import CatmaidInstance

from . import utils
from .fetch.stack import (
    StackInfo,
    MirrorInfo,
    get_stacks,
    get_stack_info,
    get_mirror_info,
)

try:
    import aiohttp
    from dask import array as da
    import imageio.v3 as iio
    import xarray as xr
    import zarr
    from zarr.storage import BaseStore
except ImportError as e:
    raise ImportError(
        'Optional dependencies for stack viewing are not available. '
        'Make sure the appropriate extra is installed: `pip install pymaid[stack]`. '
        f'Original error: "{str(e)}"'
    )

#: Names of the spatial axes, as used for coordinate dicts (see ``to_array``).
Dimension = Literal["x", "y", "z"]
# Orientation = Literal["xy", "xz", "zy"]
#: Half a pixel (0.5); presumably used for pixel-centre offsets elsewhere.
HALF_PX = 0.5
#: Byte-order prefix for dtype strings matching this machine's native order.
ENDIAN = {"little": "<", "big": ">"}[sys.byteorder]


@wraps(print)
def eprint(*args, **kwargs):
    """Call ``print`` with ``file`` defaulting to ``sys.stderr``.

    An explicitly passed ``file`` keyword is respected.
    """
    if "file" not in kwargs:
        # Looked up at call time so redirections of sys.stderr are honoured.
        kwargs["file"] = sys.stderr
    return print(*args, **kwargs)


def select_cli(prompt: str, options: Dict[int, str]) -> Optional[int]:
    """Ask the user on the command line to pick one of ``options``.

    The prompt and a numbered list of options (sorted by key) are written to
    stderr; the answer is read from ``input()``. Invalid answers (non-integers
    or keys not in ``options``) are reported and the question is repeated.

    Parameters
    ----------
    prompt : str
        Heading shown before the list of options.
    options : Dict[int, str]
        Map from selectable integer key to its description.

    Returns
    -------
    Optional[int]
        The chosen key, or None if the user entered an empty line.
    """
    eprint(prompt)
    for key in sorted(options):
        eprint(f"\t{key}.\t{options[key]}")

    question = "Type number and press enter (empty to cancel): "
    while True:
        answer = input(question).strip()
        if not answer:
            return None
        try:
            choice = int(answer)
        except ValueError:
            eprint("Not an integer, try again")
            continue
        if choice in options:
            return choice
        eprint("Not a valid option, try again")


def to_array(
    coord: Union[Dict[Dimension, Any], ArrayLike],
    dtype: DTypeLike = np.float64,
    order: Sequence[Dimension] = ("z", "y", "x"),
) -> np.ndarray:
    """Convert a coordinate into a numpy array.

    Parameters
    ----------
    coord : Union[Dict[Dimension, Any], ArrayLike]
        Either a dict keyed by dimension name, whose values are laid out in
        the sequence given by ``order``, or any array-like, used as-is.
    dtype : DTypeLike
        dtype of the result, by default ``np.float64``.
    order : Sequence[Dimension]
        Dimension names giving the output layout for dict input (ignored
        for array-likes), by default ``("z", "y", "x")``.

    Returns
    -------
    np.ndarray
    """
    if not isinstance(coord, dict):
        return np.asarray(coord, dtype=dtype)
    values_in_order = [coord[dim] for dim in order]
    return np.asarray(values_in_order, dtype=dtype)


class ImageIoStore(BaseStore, ABC):
    """Read-only zarr (v2) store exposing a CATMAID tile stack as an array.

    Every zarr chunk is one image tile: chunk keys look like
    ``"<slice>.<row>.<col>"`` and chunks have shape
    ``(1, tile_height, tile_width)``. Tiles are fetched over HTTP on demand,
    decoded with imageio as single-channel ("L" mode) images and served as
    raw ``uint8`` bytes; a 404 response is served as an all-zero tile.

    Subclasses must define the class attributes ``tile_source_type`` (the
    CATMAID tile source type they implement) and ``fmt``, a format string
    which may use the fields
    image_base, zoom_level, file_extension, row, col, slice_idx.
    """

    tile_source_type: int
    fmt: str
    _writeable = False
    _erasable = False
    _listable = False

    def __init__(
        self,
        stack_info: StackInfo,
        mirror_info: MirrorInfo,
        zoom_level: int,
        session: Union[requests.Session, CatmaidInstance, None] = None,
    ) -> None:
        """
        Parameters
        ----------
        stack_info : StackInfo
            Metadata of the stack (dimensions, orientation, broken slices).
        mirror_info : MirrorInfo
            Mirror to read from; its tile source type must match this class.
        zoom_level : int
            Scale level to read, 0 being full resolution.
        session : Union[requests.Session, CatmaidInstance, None], optional
            HTTP session used to fetch tiles. A ``CatmaidInstance``
            contributes its underlying session; None (default) creates a
            fresh ``requests.Session``.

        Raises
        ------
        ValueError
            If the mirror's tile source type does not match this class.
        TypeError
            If ``session`` is of an unsupported type.
        """
        if mirror_info.tile_source_type != self.tile_source_type:
            raise ValueError("Mismatched tile source type")
        self.stack_info = stack_info
        self.mirror_info = mirror_info
        self.zoom_level = zoom_level

        if isinstance(session, CatmaidInstance):
            self.session = session._session
        elif isinstance(session, requests.Session):
            self.session = session
        elif session is None:
            # Must be stored on the instance: __getitem__ reads self.session.
            self.session = requests.Session()
        else:
            raise TypeError(
                "session must be a requests.Session, CatmaidInstance or None; "
                f"got {type(session)}"
            )

        self.broken_slices = self._chain_broken_slices(
            self.stack_info.broken_slices
        )

        order = self.stack_info.orientation.full_orientation(reverse=True)
        self.metadata_bytes = json.dumps(
            {
                "zarr_format": 2,
                "shape": to_array(stack_info.dimension, int, order).tolist(),
                "chunks": [1, mirror_info.tile_height, mirror_info.tile_width],
                "dtype": ENDIAN + "u1",
                "compressor": None,
                "fill_value": 0,
                "order": "C",
                "filters": None,
                "dimension_separator": ".",
            }
        ).encode()
        self.attrs_bytes = json.dumps(
            {
                "stack_info": self.stack_info.to_jso(),
                "mirror_info": self.mirror_info.to_jso(),
                "scale_level": self.zoom_level,
            }
        ).encode()

        # Bytes of one all-zero chunk, laid out like the "chunks" entry above.
        self.empty = np.zeros(
            (
                1,
                self.mirror_info.tile_height,
                self.mirror_info.tile_width,
            ),
            "uint8",
        ).tobytes()

    @staticmethod
    def _chain_broken_slices(broken_slices) -> Dict[int, int]:
        """Map each broken slice index to the slice that should be read instead.

        Each entry ``k: v`` of ``broken_slices`` points slice ``k`` at slice
        ``k + v``; chains of consecutive broken slices are followed to their
        end. Cycles (including a self-reference from ``v == 0``) are cut off
        instead of looping forever.
        """
        next_slice = {int(k): int(k) + v for k, v in broken_slices.items()}
        resolved = dict()
        for start, target in next_slice.items():
            seen = {start}
            while target in next_slice and target not in seen:
                seen.add(target)
                target = next_slice[target]
            resolved[start] = target
        return resolved

    @staticmethod
    def _parse_chunk_key(last: str) -> Optional[tuple]:
        """Parse ``"<slice>.<row>.<col>"`` into three ints, or return None."""
        parts = last.split(".")
        if len(parts) != 3:
            return None
        try:
            return tuple(int(p) for p in parts)
        except ValueError:
            return None

    def _format_url(self, row: int, col: int, slice_idx: int) -> str:
        """Build the URL of one tile from this class's ``fmt`` template."""
        return self.fmt.format(
            image_base=self.mirror_info.image_base,
            zoom_level=self.zoom_level,
            slice_idx=slice_idx,
            row=row,
            col=col,
            file_extension=self.mirror_info.file_extension,
        )

    def __delitem__(self, __key) -> None:
        raise NotImplementedError()

    def __iter__(self):
        raise NotImplementedError()

    def __len__(self) -> int:
        raise NotImplementedError()

    def __setitem__(self, __key, __value) -> None:
        raise NotImplementedError()

    def __contains__(self, key) -> bool:
        """Report key presence without downloading any tile.

        The inherited Mapping ``__contains__`` would call ``__getitem__`` and
        therefore fetch the tile just to answer a membership test. Metadata
        keys and well-formed chunk keys count as present; tiles missing on
        the server are only discovered on read (and then read as zeros).
        """
        last = key.split("/")[-1]
        if last in (".zarray", ".zattrs"):
            return True
        return self._parse_chunk_key(last) is not None

    def _resolve_broken_slices(self, slice_idx: int) -> int:
        """Return the slice to read in place of ``slice_idx``."""
        return self.broken_slices.get(slice_idx, slice_idx)

    def __getitem__(self, key):
        """Return zarr metadata, or the raw bytes of one tile.

        Raises
        ------
        KeyError
            For keys that are neither metadata nor chunk keys.
        requests.HTTPError
            For HTTP errors other than 404.
        """
        last = key.split("/")[-1]
        if last == ".zarray":
            return self.metadata_bytes
        if last == ".zattrs":
            return self.attrs_bytes

        chunk_idx = self._parse_chunk_key(last)
        if chunk_idx is None:
            # Mapping contract: unknown keys (e.g. ".zgroup") raise KeyError.
            raise KeyError(key)
        # todo: check order
        slice_idx, row, col = chunk_idx
        slice_idx = self._resolve_broken_slices(slice_idx)

        url = self._format_url(row, col, slice_idx)
        response = self.session.get(url)
        if response.status_code == 404:
            # Missing tiles are served as zero-filled chunks.
            return self.empty
        response.raise_for_status()
        # Drop any query string so imageio gets a plain extension.
        ext = self.mirror_info.file_extension.split("?")[0]
        if not ext.startswith("."):
            ext = "." + ext
        # NOTE(review): assumes every tile decodes to exactly
        # (tile_height, tile_width) pixels; a smaller edge tile would not fit
        # the chunk shape — confirm for the mirrors in use.
        arr = iio.imread(
            BytesIO(response.content),
            extension=ext,
            mode="L",
        )
        return arr.tobytes()

    def to_zarr_array(self) -> zarr.Array:
        """Open this store as a read-only zarr array."""
        return zarr.open_array(self, "r")

    def to_dask_array(self) -> da.Array:
        """Wrap the zarr array in a lazily-evaluated dask array."""
        # todo: transpose?
        as_zarr = self.to_zarr_array()
        return da.from_zarr(as_zarr)

    def to_xarray(self) -> xr.DataArray:
        """Wrap the dask array in an ``xr.DataArray`` with stack coordinates."""
        as_dask = self.to_dask_array()
        return xr.DataArray(
            as_dask,
            coords=self.stack_info.get_coords(self.zoom_level),
            dims=self.stack_info.orientation.full_orientation(True),
        )


class TileStore1(ImageIoStore):
    """File-based image stack (CATMAID tile source type 1).

    Tile URLs: ``<image_base><slice>/<row>_<col>_<zoom_level>.<file_extension>``.
    """

    fmt = "{image_base}{slice_idx}/{row}_{col}_{zoom_level}.{file_extension}"
    tile_source_type = 1


class TileStore4(ImageIoStore):
    """File-based image stack with zoom level directories (tile source type 4).

    Tile URLs: ``<image_base><slice>/<zoom_level>/<row>_<col>.<file_extension>``.
    """

    fmt = "{image_base}{slice_idx}/{zoom_level}/{row}_{col}.{file_extension}"
    tile_source_type = 4


class TileStore5(ImageIoStore):
    """Directory-based image stack with zoom, z, and row directories.

    CATMAID tile source type 5. Tile URLs:
    ``<image_base><zoom_level>/<slice>/<row>/<col>.<file_extension>``.
    """

    fmt = "{image_base}{zoom_level}/{slice_idx}/{row}/{col}.{file_extension}"
    tile_source_type = 5


# class TileStore10(ImageIoStore):
#     """H2N5 tile stack."""
#     tile_source_type = 10
#     fmt = "{image_base}.{file_extension}"

#     def _format_url(self, row: int, col: int, slice_idx: int) -> str:
#         s = self.fmt.format(
#             image_base=self.mirror_info.image_base,
#             # todo: manually change quality?
#             file_extension=self.mirror_info.file_extension,
#         )
#         s = s.replace("%SCALE_DATASET%", f"s{self.zoom_level}")
#         s = s.replace("%AXIS_0%", str(col * self.mirror_info.tile_width))
#         s = s.replace("%AXIS_1%", str(row * self.mirror_info.tile_height))
#         s = s.replace("%AXIS_2%", str(slice_idx))
#         return s


#: Map from CATMAID tile source type to the store class that reads it.
tile_stores: Dict[int, Type[ImageIoStore]] = dict(
    (store_cls.tile_source_type, store_cls)
    for store_cls in (
        TileStore1,
        TileStore4,
        TileStore5,
        # TileStore10
    )
)
#: Map from tile source type to the HTTP client type(s) accepted for it.
source_client_types = dict.fromkeys(tile_stores, (requests.Session,))
source_client_types[11] = (aiohttp.ClientSession,)  # N5 mirrors

#: Any HTTP client type accepted for a mirror session.
Client = Union[requests.Session, aiohttp.ClientSession]


def select_stack(remote_instance=None) -> Optional[int]:
    """Interactively choose one of the stacks available on a CATMAID instance.

    Parameters
    ----------
    remote_instance : CatmaidInstance, optional
        Instance to query; by default the global one.

    Returns
    -------
    Optional[int]
        ID of the chosen stack, or None if the user cancelled.
    """
    titles_by_id = dict()
    for stack in get_stacks(remote_instance):
        titles_by_id[stack.id] = stack.title
    return select_cli("Select stack:", titles_by_id)


class Stack:
    """Class representing a CATMAID stack of images.

    Stacks are usually a scale pyramid. This class can, for certain stack
    mirror types, allow access to individual scale levels as arrays which can
    be queried in voxel or world coordinates.

    HTTP requests to fetch stack data are often configured differently for
    different stack mirrors and tile source types. For most non-public
    mirrors, you will need to set the object to make these requests: see the
    ``my_stack.set_mirror_session()`` method; if you just need to set HTTP
    Basic authentication headers, see the ``my_stack.set_mirror_auth()``
    convenience method.

    See the ``my_stack.get_scale()`` method for getting an
    `xarray.DataArray <https://docs.xarray.dev/en/stable/generated/xarray.DataArray.html#xarray.DataArray>`_
    representing that scale level. This can be queried in stack/ voxel or
    project/ world coordinates, efficiently sliced and transposed etc.
    """

    def __init__(
        self,
        stack_info: StackInfo,
        mirror: Optional[Union[int, str]] = None,
    ):
        """The :func:`Stack.from_catmaid` constructor may be more convenient.

        Parameters
        ----------
        stack_info : StackInfo
            Metadata of the stack, e.g. from ``get_stack_info``.
        mirror : Optional[Union[int, str]], optional
            Integer mirror ID or string mirror title to use by default.
            If None (default), no default mirror is set; see ``set_mirror()``
            or ``select_mirror()``.
        """
        self.stack_info = stack_info
        self.mirror_info: Optional[MirrorInfo] = None
        # Mirror ID -> HTTP client used to fetch that mirror's image data.
        self.mirror_session: Dict[int, Any] = dict()
        if mirror is not None:
            self.set_mirror(mirror)

    def set_mirror_auth(
        self,
        mirror: Union[int, str, None, MirrorInfo],
        http_user: str,
        http_password: str,
    ):
        """Set the HTTP Basic credentials for a particular stack mirror.

        This will replace any other session configured for that mirror.
        For more fine-grained control (e.g. setting other headers),
        or to re-use the session object from a ``CatmaidInstance``,
        see ``my_stack.set_mirror_session()``.

        Parameters
        ----------
        mirror : Union[int, str, None, MirrorInfo]
            Mirror, as MirrorInfo, integer ID, string title,
            or None (use default)
        http_user : str
            HTTP Basic username
        http_password : str
            HTTP Basic password

        Raises
        ------
        ValueError
            If the given mirror is not supported.
        """
        minfo = self._get_mirror_info(mirror)
        if minfo.tile_source_type == 11:
            # NOTE(review): this ClientSession is created outside the event
            # loop that will later use it; confirm aiohttp/fsspec accept that.
            session = aiohttp.ClientSession(
                auth=aiohttp.BasicAuth(http_user, http_password)
            )
        elif minfo.tile_source_type in source_client_types:
            session = requests.Session()
            session.auth = (http_user, http_password)
        else:
            raise ValueError(
                "Mirror's tile source type is unsupported: "
                f"{minfo.tile_source_type}"
            )
        return self.set_mirror_session(minfo, session)

    def set_mirror_session(
        self,
        mirror: Union[int, str, None, MirrorInfo],
        session: Client,
    ):
        """Set the HTTP client used to fetch image data for a mirror.

        This replaces any session previously set for that mirror. The session
        type is not checked here, but when data is requested (``get_scale``).

        For most tile stacks, this is a
        `requests.Session <https://requests.readthedocs.io/en/latest/api/#requests.Session>`_.
        See ``get_remote_instance_session`` to use the session from a given
        ``CatmaidInstance`` (including the global).

        For N5 (tile source 11), this is an
        `aiohttp.ClientSession <https://docs.aiohttp.org/en/stable/client_reference.html#aiohttp.ClientSession>`_.

        Parameters
        ----------
        mirror : Union[int, str, None, MirrorInfo]
            Mirror, as MirrorInfo, integer ID, string title, or None to use
            the default mirror set on this object.
        session : Union[requests.Session, aiohttp.ClientSession]
            HTTP session of the appropriate type.

        For example, to re-use the ``requests.Session`` from the global
        ``CatmaidInstance`` for mirror with ID 1, use
        ``my_stack.set_mirror_session(1, get_remote_instance_session())``.
        To use HTTP basic auth for an N5 stack mirror (tile source 11)
        with ID 2, use
        ``my_stack.set_mirror_session(2, aiohttp.ClientSession(auth=aiohttp.BasicAuth("myusername", "mypassword")))``,
        or simply ``my_stack.set_mirror_auth(2, "myusername", "mypassword")``.
        """
        minfo = self._get_mirror_info(mirror)
        self.mirror_session[minfo.id] = session

    @classmethod
    def select_from_catmaid(cls, remote_instance=None):
        """Interactively select a stack and mirror from those available.

        Parameters
        ----------
        remote_instance : CatmaidInstance, optional
            By default global.

        Returns
        -------
        Optional[Stack]
            None if the user cancelled the stack selection.
        """
        sid = select_stack(remote_instance)
        # Compare against None: any integer (even 0) is a valid selection.
        if sid is None:
            return None
        out = cls.from_catmaid(sid, remote_instance=remote_instance)
        out.select_mirror()
        return out

    @classmethod
    def from_catmaid(
        cls,
        stack: Union[str, int],
        mirror: Optional[Union[int, str]] = None,
        remote_instance=None,
    ):
        """Fetch relevant data from CATMAID and build a Stack.

        Parameters
        ----------
        stack : Union[str, int]
            Integer stack ID or string stack title.
        mirror : Optional[int, str], optional
            Integer mirror ID or string mirror title, by default None
        remote_instance : CatmaidInstance, optional
            By default global.
        """
        sinfo = get_stack_info(stack, remote_instance)
        return cls(sinfo, mirror)

    def _get_mirror_info(
        self, mirror: Union[int, str, None, MirrorInfo] = None
    ) -> MirrorInfo:
        """Resolve ``mirror`` (info, ID, title or None for default) to MirrorInfo.

        Raises
        ------
        ValueError
            If ``mirror`` is None and no default mirror is set.
        """
        if mirror is None:
            if self.mirror_info is None:
                raise ValueError("No default mirror ID set")
            return self.mirror_info
        if isinstance(mirror, MirrorInfo):
            return mirror
        return get_mirror_info(self.stack_info, mirror)

    def set_mirror(self, mirror: Union[int, str]):
        """Set the mirror using its int ID or str title."""
        self.mirror_info = self._get_mirror_info(mirror)

    def select_mirror(self):
        """Interactively select a mirror from those available.

        Only mirrors whose tile source type is supported are offered.
        The default mirror is left unchanged if the user cancels.
        """
        options = {
            m.id: m.title
            for m in self.stack_info.mirrors
            if m.tile_source_type in source_client_types
        }
        if not options:
            eprint("No mirrors with supported tile source type")
            return

        result = select_cli(
            f"Select mirror for stack '{self.stack_info.stitle}':",
            options,
        )
        if result is not None:
            self.set_mirror(result)

    def _get_session(self, mirror_id: int, default: Optional[Any] = None):
        """Return the session set for ``mirror_id``.

        If none is set, return ``default``, or raise KeyError when
        ``default`` is None.
        """
        try:
            return self.mirror_session[mirror_id]
        except KeyError:
            if default is None:
                raise
            return default

    def get_scale(
        self, scale_level: int, mirror_id: Optional[int] = None
    ) -> xr.DataArray:
        """Get an xarray.DataArray representing the given scale level.

        Note that depending on the metadata available,
        missing scale levels may throw different errors.

        Parameters
        ----------
        scale_level : int
            0 for full resolution
        mirror_id : Optional[int], optional
            By default the one set on the class.

        Returns
        -------
        xr.DataArray
            Can be queried in voxel or world space.

        Raises
        ------
        ValueError
            Scale level is negative or does not exist according to metadata,
            no mirror is given or set, or the mirror's session has the wrong
            type.
        NotImplementedError
            Unknown tile source type for this mirror
        """
        if scale_level < 0:
            raise ValueError(
                f"Scale level must be non-negative, got {scale_level}"
            )
        mirror_info = self._get_mirror_info(mirror_id)
        if (
            self.stack_info.num_zoom_levels > 0
            and scale_level > self.stack_info.num_zoom_levels
        ):
            raise ValueError(
                f"Scale level {scale_level} does not exist "
                f"for stack {self.stack_info.sid} "
                f"with {self.stack_info.num_zoom_levels} stack levels"
            )

        if mirror_info.tile_source_type in tile_stores:
            store_class = tile_stores[mirror_info.tile_source_type]
            # Only build a fresh (unauthenticated) session when none is set.
            session = self.mirror_session.get(mirror_info.id)
            if session is None:
                session = requests.Session()
            check_session_type(session, mirror_info.tile_source_type)
            store = store_class(self.stack_info, mirror_info, scale_level, session)
            return store.to_xarray()
        elif mirror_info.tile_source_type == 11:
            return self._get_n5(mirror_info, scale_level)

        raise NotImplementedError(
            f"Tile source type {mirror_info.tile_source_type} not implemented"
        )

    def _get_n5(
        self,
        mirror_info: MirrorInfo,
        scale_level: int,
    ) -> xr.DataArray:
        """Open one scale level of an N5 (tile source type 11) mirror.

        ``image_base`` is expected to contain ``%SCALE_DATASET%`` (replaced by
        ``s<scale_level>``), a path component ending in ``.n5`` (the
        container), the dataset path within it, and a final component of
        underscore-separated axis indices used to transpose the array.

        Raises
        ------
        ValueError
            If the mirror is not N5, the path has no dataset after the
            ``.n5`` container, or a configured session has the wrong type.
        """
        if mirror_info.tile_source_type != 11:
            raise ValueError("Mirror info not from an N5 tile source")

        formatted = mirror_info.image_base.replace(
            "%SCALE_DATASET%", f"s{scale_level}"
        )
        # Tolerate a trailing slash so the last component is the axis order.
        *components, transpose_str = formatted.rstrip("/").split("/")
        transpose = [int(t) for t in transpose_str.split("_")]

        # Components up to and including the one ending in ".n5" form the
        # container path; the rest is the dataset path inside it.
        container_comp = []
        arr_comp = []
        this = container_comp
        for comp in components:
            this.append(comp)
            if comp.lower().endswith(".n5"):
                this = arr_comp

        if not arr_comp:
            raise ValueError("N5 container must have '.n5' suffix")

        kwargs = dict()
        session = self.mirror_session.get(mirror_info.id)
        if session is not None:
            check_session_type(session, 11)

            async def get_client(**_client_kwargs):
                # fsspec's HTTP filesystem awaits get_client(loop=..., **kw);
                # ignore those arguments and hand back the configured session.
                return session

            kwargs["get_client"] = get_client

        store = zarr.N5FSStore("/".join(container_comp), **kwargs)
        container = zarr.open(store, "r")
        as_zarr = container["/".join(arr_comp)]
        # todo: check this transpose
        as_dask = da.from_zarr(as_zarr).transpose(transpose)
        return xr.DataArray(
            as_dask,
            coords=self.stack_info.get_coords(scale_level),
            dims=self.stack_info.orientation.full_orientation(True),
        )
def check_session_type(session, tile_source_type: int):
    """Ensure ``session`` is an HTTP client accepted for ``tile_source_type``.

    Raises
    ------
    ValueError
        If ``session`` is not an instance of the client type(s) registered
        for this tile source type.
    KeyError
        If the tile source type has no registered client types.
    """
    allowed = source_client_types[tile_source_type]
    if isinstance(session, allowed):
        return
    raise ValueError(
        f"Incorrect HTTP client type for tile source {tile_source_type}. "
        f"Got {type(session)} but expected one of {allowed}."
    )


def get_remote_instance_session(remote_instance: Optional[CatmaidInstance] = None):
    """Get the ``requests.Session`` from the given ``CatmaidInstance``.

    If ``None`` is given, use the global ``CatmaidInstance``.
    """
    return utils._eval_remote_instance(remote_instance)._session