from __future__ import annotations
import abc
import math
from typing import Final, ClassVar
import attrs
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from matplotlib.image import AxesImage
from matplotlib.transforms import CompositeGenericTransform
from typing_extensions import Self
from neuralib.atlas.data import DATA_SOURCE_TYPE, load_ccf_annotation, load_ccf_template, load_allensdk_annotation
from neuralib.atlas.util import PLANE_TYPE, ALLEN_CCF_10um_BREGMA
from neuralib.imglib.factory import ImageProcFactory
from neuralib.typing import PathLike
__all__ = [
'load_slice_view',
'AbstractSliceView',
'SlicePlane'
]
[docs]
def load_slice_view(source: DATA_SOURCE_TYPE,
plane_type: PLANE_TYPE,
*,
output_dir: PathLike | None = None,
allen_annotation_res: int = 10) -> AbstractSliceView:
"""
Load the mouse brain slice view
:param source: ``DATA_SOURCE_TYPE``. {'ccf_annotation', 'ccf_template', 'allensdk_annotation'}
:param plane_type: ``PLANE_TYPE``. {'coronal', 'sagittal', 'transverse'}
:param output_dir: Output directory for caching
:param allen_annotation_res: Volume resolution in um. default is 10 um
:return: :class:`AbstractSliceView`
"""
if source == 'ccf_annotation':
data = load_ccf_annotation(output_dir)
res = 10
elif source == 'ccf_template':
data = load_ccf_template(output_dir)
res = 10
elif source == 'allensdk_annotation':
data = load_allensdk_annotation(resolution=allen_annotation_res, output_dir=output_dir)
res = allen_annotation_res
else:
raise ValueError('')
return AbstractSliceView(source, plane_type, res, data)
[docs]
class AbstractSliceView(metaclass=abc.ABCMeta):
"""
SliceView ABC for different `plane type`
`Dimension parameters`:
AP = anterior-posterior
DV = dorsal-ventral
ML = medial-lateral
W = view width
H = view height
"""
REFERENCE_FROM: ClassVar[str] = ''
"""reference from which axis"""
source_type: Final[DATA_SOURCE_TYPE]
"""``DATA_SOURCE_TYPE``. {'ccf_annotation', 'ccf_template', 'allensdk_annotation'}"""
plane_type: Final[PLANE_TYPE]
"""`PLANE_TYPE``. {'coronal', 'sagittal', 'transverse'}"""
resolution: Final[int]
"""um/pixel"""
reference: Final[np.ndarray]
"""Array[float, [AP, DV, ML]]"""
grid_x: Final[np.ndarray]
"""Array[int, [W, H]]"""
grid_y: Final[np.ndarray]
"""Array[int, [W, H]]"""
[docs]
def __new__(cls, source_type: DATA_SOURCE_TYPE,
plane: PLANE_TYPE,
resolution: int,
reference: np.ndarray):
if plane == 'coronal':
return object.__new__(CoronalSliceView)
elif plane == 'sagittal':
return object.__new__(SagittalSliceView)
elif plane == 'transverse':
return object.__new__(TransverseSliceView)
else:
raise ValueError(f'invalid plane: {plane}')
[docs]
def __init__(self, source_type: DATA_SOURCE_TYPE,
plane: PLANE_TYPE,
resolution: int,
reference: np.ndarray):
"""
:param source_type: ``DATA_SOURCE_TYPE``. {'ccf_annotation', 'ccf_template', 'allensdk_annotation'}
:param plane: `PLANE_TYPE``. {'coronal', 'sagittal', 'transverse'}
:param resolution: um/pixel
:param reference: Array[uint16, [AP, DV, ML]]
"""
self.source_type = source_type
self.plane_type = plane
self.resolution = resolution
self.reference = reference
self.grid_y, self.grid_x = np.mgrid[0:self.height, 0:self.width]
self._check_attrs()
def _check_attrs(self):
if self.resolution == 10:
assert self.reference.shape == (1320, 800, 1140)
elif self.resolution == 25:
assert self.reference.shape == (528, 320, 456)
@property
def bregma(self) -> np.ndarray:
if self.resolution == 10:
return ALLEN_CCF_10um_BREGMA
raise NotImplementedError('')
@property
def n_ap(self) -> int:
"""number of slices along AP axis"""
return self.reference.shape[0]
@property
def n_dv(self) -> int:
"""number of slices along DV axis"""
return self.reference.shape[1]
@property
def n_ml(self) -> int:
"""number of slices along ML axis"""
return self.reference.shape[2]
@property
@abc.abstractmethod
def n_planes(self) -> int:
"""number of planes in a specific plane view"""
pass
@property
@abc.abstractmethod
def width(self) -> int:
"""width (pixel) in a specific plane view"""
pass
@property
@abc.abstractmethod
def height(self) -> int:
"""height (pixel) in a specific plane view"""
pass
@property
def width_mm(self) -> float:
"""width (um) in a specific plane view"""
return self.width * self.resolution / 1000
@property
def height_mm(self) -> float:
"""height (um) in a specific plane view"""
return self.height * self.resolution / 1000
@property
@abc.abstractmethod
def reference_point(self) -> int:
"""reference point in a specific plane view. aka, bregma plane index"""
pass
@property
@abc.abstractmethod
def project_index(self) -> tuple[int, int, int]:
"""plane(p), x, y of index order in (AP, DV, ML)
:return: (p, x, y)
"""
pass
[docs]
def plane_at(self, slice_index: int) -> SlicePlane:
return SlicePlane(slice_index, int(self.width // 2), int(self.height // 2), 0, 0, self)
[docs]
def offset(self, h: int, v: int) -> np.ndarray:
"""
:param h: horizontal plane diff to the center. right side positive.
:param v: vertical plane diff to the center. bottom side positive.
:return: (H, W) array
"""
x_frame = np.round(np.linspace(-h, h, self.width)).astype(int)
y_frame = np.round(np.linspace(-v, v, self.height)).astype(int)
return np.add.outer(y_frame, x_frame)
[docs]
def plane(self, offset: int | tuple[int, int, int] | np.ndarray) -> np.ndarray:
"""Get image plane.
:param offset: Array[int, height, width] or tuple (plane, dh, dv)
:return:
"""
if isinstance(offset, int):
offset = np.full_like((self.height, self.width), offset)
elif isinstance(offset, tuple):
offset = offset[0] + self.offset(offset[1], offset[2])
elif not isinstance(offset, np.ndarray):
raise TypeError(str(type(offset)))
offset[offset < 0] = 0
offset[offset > self.n_planes] = self.n_planes - 1
return self.reference[self.coor_on(offset, (self.grid_x, self.grid_y))]
[docs]
def coor_on(self, plane: np.ndarray,
o: tuple[np.ndarray, np.ndarray]) -> tuple[np.ndarray, ...]:
"""
map slice point (x, y) at plane *plane* back to volume point (ap, dv, ml)
:param plane: plane number of array
:param o: tuple of (x, y)
:return: (ap, dv, ml)
"""
pidx, xidx, yidx = self.project_index
ret = [0, 0, 0]
ret[pidx] = plane
ret[xidx] = o[0]
ret[yidx] = o[1]
return tuple(ret)
class CoronalSliceView(AbstractSliceView):
REFERENCE_FROM: ClassVar[str] = 'AP'
@property
def n_planes(self) -> int:
return self.n_ap
@property
def width(self) -> int:
return self.n_ml
@property
def height(self) -> int:
return self.n_dv
@property
def reference_point(self) -> int:
return int(self.bregma[0])
@property
def project_index(self) -> tuple[int, int, int]:
return 0, 2, 1
class SagittalSliceView(AbstractSliceView):
REFERENCE_FROM: ClassVar[str] = 'ML'
@property
def n_planes(self) -> int:
return self.n_ml
@property
def width(self) -> int:
return self.n_ap
@property
def height(self) -> int:
return self.n_dv
@property
def reference_point(self) -> int:
return int(self.bregma[2])
@property
def project_index(self) -> tuple[int, int, int]:
return 2, 0, 1 # p=ML, x=AP, y=DV
class TransverseSliceView(AbstractSliceView):
REFERENCE_FROM: ClassVar[str] = 'DV'
@property
def n_planes(self) -> int:
return self.n_dv
@property
def width(self) -> int:
return self.n_ml
@property
def height(self) -> int:
return self.n_ap
@property
def reference_point(self) -> int:
return int(self.bregma[1])
@property
def project_index(self) -> tuple[int, int, int]:
return 1, 2, 0
[docs]
@attrs.define
class SlicePlane:
"""2D Wrapper for a specific plane"""
slice_index: int
"""anchor index"""
ax: int
"""anchor x"""
ay: int
"""anchor y"""
dw: int
"""dw in um"""
dh: int
"""dh in um"""
view: AbstractSliceView
"""``AbstractSliceView``"""
@property
def image(self) -> np.ndarray:
return self.view.plane(self.plane_offset)
@property
def plane_offset(self) -> np.ndarray:
offset = self.view.offset(self.dw, self.dh)
return self.slice_index + offset - offset[self.ay, self.ax]
@property
def reference_value(self) -> float:
"""relative to reference point"""
factor = 1000 / self.view.resolution
return round((self.view.reference_point - self.slice_index) / factor, 2)
[docs]
def with_offset(self, dw: int, dh: int, debug: bool = False) -> Self:
if debug:
deg_x, deg_y = self._value_to_angle(dw, dh)
print(f'{dw=}, {dh=}')
print(f'{deg_x=}, {deg_y=}')
return attrs.evolve(self, dw=dw, dh=dh)
def _value_to_angle(self, dw: int, dh: int) -> tuple[float, float]:
"""delta value to degree"""
rx = math.atan(2 * dw / self.view.width)
ry = math.atan(2 * dh / self.view.height)
deg_x = np.rad2deg(rx)
deg_y = np.rad2deg(ry)
return deg_x, deg_y
[docs]
def with_angle_offset(self, deg_x: float, deg_y: float) -> Self:
"""
with degree offset
:param deg_x: degree in x axis (width)
:param deg_y: degree in y axis (height)
:return:
"""
rx = np.deg2rad(deg_x)
ry = np.deg2rad(deg_y)
dw = int(self.view.width * math.tan(rx) / 2)
dh = int(self.view.height * math.tan(ry) / 2)
return self.with_offset(dw, dh)
[docs]
def plot(self,
ax: Axes | None = None,
to_um: bool = True,
with_annotation: bool = False,
cbar: bool = False,
with_title: bool = False,
affine_transform: bool = False,
customized_trans: bool = False,
extent: tuple[float, float, float, float] | None = None,
**kwargs) -> tuple[AxesImage, AxesImage | None, CompositeGenericTransform]:
"""
:param ax: The Axes object on which to plot. If None, a new figure and axes are created.
:param to_um: A boolean flag indicating whether the coordinates should be converted to micrometers. Defaults to True.
Only applicable if ``extent`` is None.
:param with_annotation: A boolean indicating whether to include annotations in the plot.
:param cbar: A boolean indicating whether to include a color bar in the plot.
:param with_title: A boolean indicating whether to include a title in the plot.
:param affine_transform: A boolean indicating whether to apply an affine transformation to the plot.
:param customized_trans: A boolean indicating whether to use a customized affine transformation.
:param extent: A tuple defining the image boundaries (left, right, bottom, top). If None, boundaries are computed internally.
:param kwargs: Additional keyword arguments passed to ``ax.imshow``.
:return: A tuple containing the main AxesImage, the annotation AxesImage if any, and the Affine2D transformation.
"""
if ax is None:
_, ax = plt.subplots()
if extent is None:
extent = self._get_xy_range(to_um)
#
if affine_transform:
import matplotlib.transforms as mtransforms
#
if customized_trans:
aff = mtransforms.Affine2D(self._customized_affine_transform())
else:
aff = mtransforms.Affine2D().skew_deg(-20, 0)
aff_trans = aff + ax.transData
else:
aff_trans = ax.transData
#
image = self.image.astype(float)
image[image <= 10] = np.nan
im_view = ax.imshow(image, cmap='Greys', extent=extent, clip_on=False, transform=aff_trans, **kwargs)
# annotation
if with_annotation:
im_ann = self.plot_annotation(ax, aff_trans=aff_trans, to_um=to_um)
else:
im_ann = None
#
if cbar:
ax.figure.colorbar(im_view)
#
if with_title:
ax.set_title(f'{self.view.REFERENCE_FROM}: {self.reference_value} mm')
ax.set(xlabel=self.unit, ylabel=self.unit)
return im_view, im_ann, aff_trans
unit: str = 'a.u.'
def _get_xy_range(self, to_um: bool = True) -> tuple[float, float, float, float]:
if to_um:
x0 = -self.view.width_mm / 2 * 1000
x1 = self.view.width_mm / 2 * 1000
y0 = self.view.height_mm * 1000
y1 = 0
self.unit = 'um'
else:
x0 = -self.view.width_mm / 2
x1 = self.view.width_mm / 2
y0 = self.view.height
y1 = 0
self.unit = 'mm'
return x0, x1, y0, y1
def _customized_affine_transform(self) -> np.ndarray:
# translation
Y = np.array([[1, 0, 0], [0, 1, -4000], [0, 0, 1]])
# shear
tt = np.array([
[1, 0, 0],
[0, 1, 0],
[-0.03 / self.view.width, 0, 1]
])
# translation
_Y = np.array([[1, 0, 0], [0, 1, 3500], [0, 0, 1]])
return Y @ tt @ _Y
[docs]
def plot_annotation(self,
ax: Axes,
*,
aff_trans: Axes.transData | None = None,
to_um: bool = True,
cmap: str = 'binary',
alpha: float = 0.3,
extent: tuple[float, float, float, float] | None = None) -> AxesImage:
"""
Plot the annotation image
:param ax: ``Axes``
:param aff_trans: The transformation applied to the annotation image. If None, defaults to the Axes' transformation.
:param to_um: A boolean flag indicating whether the coordinates should be converted to micrometers. Defaults to True.
Only applicable if ``extent`` is None.
:param cmap: Colormap to be used for the annotation image. Defaults to 'binary'.
:param alpha: The imshow alpha, between 0 (transparent) and 1 (opaque). Defaults to 0.3.
:param extent: A tuple defining the image boundaries (left, right, bottom, top). If None, boundaries are computed internally.
:return: The AxesImage object created by imshow, representing the plotted annotation image.
"""
if extent is None:
extent = self._get_xy_range(to_um)
ann_img = (
load_slice_view('ccf_annotation', self.view.plane_type, allen_annotation_res=self.view.resolution)
.plane(self.plane_offset)
)
ann = (
ImageProcFactory(ann_img, 'RGB')
.cvt_gray()
.edge_detection(10, 0).image
)
ann = ann.astype(float)
ann[ann <= 10] = np.nan
if aff_trans is None:
aff_trans = ax.transData
im_ann = ax.imshow(ann, cmap=cmap, extent=extent, alpha=alpha, clip_on=False,
interpolation='none', vmin=0, vmax=255, transform=aff_trans)
return im_ann