import dataclasses
from functools import cached_property
from pathlib import Path
from typing import Optional
import numpy as np
import polars as pl
from scipy.interpolate import interp1d
from typing_extensions import Self
from neuralib.argp import as_argument, argument
from neuralib.atlas.brainrender.core import BrainReconstructor
from neuralib.atlas.util import PLANE_TYPE
from neuralib.atlas.util import roi_points_converter
from neuralib.util.segments import grouped_iter
__all__ = ['ProbeReconstructor']
class ProbeReconstructor(BrainReconstructor):
    """
    Reconstruct probe (shank) tracks from 2dccf-registered label points.

    Each shank is labelled by two points in :attr:`csv_file` (a dorsal and a ventral label).
    :meth:`load` adds the dye-labelled tracks and, unless :attr:`dye_label_only` is set,
    the theoretical tracks cropped to :attr:`implant_depth`.
    """

    DESCRIPTION = 'For probe(s) track reconstruction'

    implant_depth: int = argument('-D', '--depth', required=True, help='implant depth in um')

    dye_label_only: bool = argument('--dye', help='only show the histology dye parts')

    csv_file: Path = as_argument(BrainReconstructor.csv_file).with_options(
        required=True,
        help='csv file after registration using 2dccf pipeline, point numbers equal to probe(s) * 2'
    )
    """
    Example::

        ┌───────────────────────────────────┬─────────┬─────────────┬─────────────┬─────────────┬─────────┐
        │ name                              ┆ acronym ┆ AP_location ┆ DV_location ┆ ML_location ┆ avIndex │
        │ ---                               ┆ ---     ┆ ---         ┆ ---         ┆ ---         ┆ ---     │
        │ str                               ┆ str     ┆ f64         ┆ f64         ┆ f64         ┆ i64     │
        ╞═══════════════════════════════════╪═════════╪═════════════╪═════════════╪═════════════╪═════════╡
        │ Primary visual area layer 6a      ┆ VISp6a  ┆ -3.81       ┆ 1.92        ┆ -3.12       ┆ 191     │
        │ optic radiation                   ┆ or      ┆ -4.08       ┆ 2.33        ┆ -3.12       ┆ 1217    │
        │ Posterolateral visual area layer… ┆ VISpl6a ┆ -4.28       ┆ 2.29        ┆ -3.12       ┆ 198     │
        │ Posterolateral visual area layer… ┆ VISpl5  ┆ -4.52       ┆ 2.17        ┆ -3.12       ┆ 197     │
        │ Subiculum                         ┆ SUB     ┆ -3.93       ┆ 4.36        ┆ -3.3        ┆ 536     │
        │ Entorhinal area medial part dors… ┆ ENTm5   ┆ -4.19       ┆ 4.39        ┆ -3.3        ┆ 515     │
        │ Entorhinal area medial part dors… ┆ ENTm2   ┆ -4.44       ┆ 4.39        ┆ -3.3        ┆ 510     │
        │ Entorhinal area medial part dors… ┆ ENTm1   ┆ -4.66       ┆ 4.29        ┆ -3.3        ┆ 509     │
        └───────────────────────────────────┴─────────┴─────────────┴─────────────┴─────────────┴─────────┘
    """

    plane_type: PLANE_TYPE = argument(
        '--plane-type', '-P',
        default='coronal',
        help='cutting orientation. Assume if multiple shanks were inserted along the AP axis, then do the sagittal '
             'slicing, if inserted along the ML axis, then do the coronal slicing'
    )

    raw: pl.DataFrame
    """raw csv file"""

    data: pl.DataFrame
    """sorted data based on plane_type"""

    def load(self):
        """
        Read :attr:`csv_file`, infer the probe ordering and add the reconstructed tracks.

        The dye-labelled tracks (interpolated between the two labels) are always added.
        Unless :attr:`dye_label_only` is set, the theoretical tracks (extended, then cropped
        by :meth:`ShanksTrack.with_theoretical`) are added as well.
        """
        self.raw = pl.read_csv(self.csv_file)
        self.data = self.infer_probe_index(self.raw)

        # `get_probe_object` / `with_theoretical` read `dye_label_only` to pick the
        # extension mode, so toggle it temporarily and restore the user's choice afterwards
        user_dye_only = self.dye_label_only
        try:
            self.dye_label_only = True
            points = [np.vstack(self.get_probe_object().shanks)]

            if not user_dye_only:
                self.dye_label_only = False
                points.append(np.vstack(self.get_probe_object().with_theoretical().shanks))
        finally:
            self.dye_label_only = user_dye_only

        self.add_points(points)

    @cached_property
    def number_shanks(self) -> int:
        """number of shanks (S); each shank contributes two label rows in :attr:`raw`"""
        return self.raw.shape[0] // 2

    @dataclasses.dataclass
    class ShanksTrack:
        """shank object

        `Dimension parameters`:

            S = number of shanks

            P = number of sample points after interpolated reconstruction
        """
        rst: 'ProbeReconstructor'
        """`ProbeReconstructor`"""

        shanks: list[np.ndarray]
        """length S of Array[float, [P, 3]], with ap, dv, ml"""

        def __post_init__(self):
            n_expected = self.rst.number_shanks
            if len(self.shanks) != n_expected:
                raise ValueError(f'expected {n_expected} shanks, got {len(self.shanks)}')

        def __len__(self) -> int:
            """S"""
            return len(self.shanks)

        def __getitem__(self, idx: int) -> np.ndarray:
            """get a shank array. `Array[float, [P, 3]]`"""
            return self.shanks[idx]

        def with_theoretical(self, interval: int = 250) -> Self:
            """
            theoretical track based on implantation depth / angle

            Shank ``i`` is cropped to ``implant_depth + _calc_shank_length_diff(shank_0, interval * i)``,
            i.e. the length offset grows linearly with the shank index.

            :param interval: distance (um) between neighbouring shanks. e.g., NeuroPixel 2.0 = 250 * x
            :return: :class:`ShanksTrack`
            :raises RuntimeError: if ``dye_label_only`` is set on the reconstructor
            """
            rst = self.rst
            if rst.dye_label_only:
                raise RuntimeError('theoretical track requires dye_label_only to be False')

            first = self[0]
            # NOTE(review): when built by get_probe_object() with ext_depth=(0, 5000) the first
            # sample's DV is 0, so this cut-off evaluates to 0 rather than the brain surface —
            # confirm whether dv_value=None (atlas-based cropping) was intended.
            dv_cutoff = int(first[0, 1])

            ret = [
                rst.crop_outside_brain(
                    shank,
                    rst.implant_depth + _calc_shank_length_diff(first, interval * i),
                    dv_value=dv_cutoff,
                )
                for i, shank in enumerate(self.shanks)
            ]
            return dataclasses.replace(self, shanks=ret)

    def get_probe_object(self) -> ShanksTrack:
        """
        Build the shank tracks from the consecutive label-point pairs in :attr:`data`.

        If :attr:`dye_label_only`, each shank is only interpolated between its two labels;
        otherwise it is extrapolated over DV 0-5000 um.

        :return: :class:`ShanksTrack`
        """
        points = roi_points_converter(self.data)
        ext_depth = None if self.dye_label_only else (0, 5000)

        ret = [
            self._shank_extend(np.array([points[surface_idx], points[tip_idx]]), ext_depth=ext_depth)
            for surface_idx, tip_idx in grouped_iter(np.arange(points.shape[0]), 2)
        ]
        return self.ShanksTrack(self, ret)

    def infer_probe_index(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        probe in correct order index

        :param df: raw csv file
        :return: dataframe with `probe` and `probe_idx` fields, ordered as
            (dorsal, ventral) pairs by ascending `probe_idx`.
        :raises ValueError: if ``df`` does not have exactly ``2 * number_shanks`` rows
        :raises RuntimeError: if :attr:`plane_type` is neither ``'sagittal'`` nor ``'coronal'``

        Example::

            ┌───────────────────────────────────┬─────────┬─────────────┬─────────────┬─────────────┬─────────┬───────────────┬───────────┐
            │ name                              ┆ acronym ┆ AP_location ┆ DV_location ┆ ML_location ┆ avIndex ┆ probe         ┆ probe_idx │
            │ ---                               ┆ ---     ┆ ---         ┆ ---         ┆ ---         ┆ ---     ┆ ---           ┆ ---       │
            │ str                               ┆ str     ┆ f64         ┆ f64         ┆ f64         ┆ i64     ┆ str           ┆ i64       │
            ╞═══════════════════════════════════╪═════════╪═════════════╪═════════════╪═════════════╪═════════╪═══════════════╪═══════════╡
            │ Primary visual area layer 6a      ┆ VISp6a  ┆ -3.81       ┆ 1.92        ┆ -3.12       ┆ 191     ┆ dorsal_label  ┆ 1         │
            │ Subiculum                         ┆ SUB     ┆ -3.93       ┆ 4.36        ┆ -3.3        ┆ 536     ┆ ventral_label ┆ 1         │
            │ optic radiation                   ┆ or      ┆ -4.08       ┆ 2.33        ┆ -3.12       ┆ 1217    ┆ dorsal_label  ┆ 2         │
            │ Entorhinal area medial part dors… ┆ ENTm5   ┆ -4.19       ┆ 4.39        ┆ -3.3        ┆ 515     ┆ ventral_label ┆ 2         │
            │ Posterolateral visual area layer… ┆ VISpl6a ┆ -4.28       ┆ 2.29        ┆ -3.12       ┆ 198     ┆ dorsal_label  ┆ 3         │
            │ Entorhinal area medial part dors… ┆ ENTm2   ┆ -4.44       ┆ 4.39        ┆ -3.3        ┆ 510     ┆ ventral_label ┆ 3         │
            │ Posterolateral visual area layer… ┆ VISpl5  ┆ -4.52       ┆ 2.17        ┆ -3.12       ┆ 197     ┆ dorsal_label  ┆ 4         │
            │ Entorhinal area medial part dors… ┆ ENTm1   ┆ -4.66       ┆ 4.29        ┆ -3.3        ┆ 509     ┆ ventral_label ┆ 4         │
            └───────────────────────────────────┴─────────┴─────────────┴─────────────┴─────────────┴─────────┴───────────────┴───────────┘
        """
        n_shanks = self.number_shanks
        n_rows = df.shape[0]
        if n_rows != 2 * n_shanks:
            raise ValueError(f'expected {2 * n_shanks} label points (2 per shank), got {n_rows}')

        # the shallower half of the points become dorsal labels, the deeper half ventral labels
        # (assumes every dorsal label lies above every ventral label)
        df = (df.sort('DV_location')
              .with_columns(pl.Series(['dorsal_label'] * n_shanks + ['ventral_label'] * n_shanks).alias('probe')))

        if self.plane_type == 'sagittal':
            probe_order = 'AP_location'
        elif self.plane_type == 'coronal':
            probe_order = 'ML_location'
        else:
            raise RuntimeError(f'unsupported plane_type: {self.plane_type!r}')

        # number the shanks by descending AP/ML within each label group; the final sort also keys
        # on `probe` because a plain sort does not guarantee tie order, and get_probe_object()
        # relies on each probe's rows coming as (dorsal, ventral)
        df = (df.sort(by=['probe', probe_order], descending=[False, True])
              .with_columns(pl.Series(list(range(1, 1 + n_shanks)) * 2).alias('probe_idx'))
              .sort(['probe_idx', 'probe']))

        print(df)
        return df

    def isin_brain(self, shank: np.ndarray) -> np.ndarray:
        """
        determine if the probe points are in the brain

        :param shank: `Array[float, [P, 3]]`
        :return: `Array[bool, P]`, True where the atlas structure at the point (in microns) is non-zero.
            Points that raise IndexError in the atlas lookup count as outside.
        """
        brain = self.get_atlas_brain_globe()

        def _inside(point) -> bool:
            try:
                structure = brain.structure_from_coords(point, microns=True)
            except IndexError:
                structure = 0
            return structure != 0

        return np.array([_inside(point) for point in shank], dtype=bool)

    def crop_outside_brain(self, shank: np.ndarray,
                           distance: float,
                           dv_value: Optional[int] = None) -> np.ndarray:
        """
        crop the probe after doing the extension

        :param shank: `Array[float, [P, 3]]`
        :param distance: depth of insertion (might with an angle, in um). mostly records during the implantation.
            Points farther than this (euclidean) from the first kept point are cut off.
        :param dv_value: if None, keep the points inside the brain (:meth:`isin_brain`);
            if an integer, keep the points whose DV is >= this value
        :return: cropped `Array[float, [P', 3]]`; empty if no point survives the cropping
        :raises TypeError: if ``dv_value`` is neither None nor an integer
        """
        shank = shank[shank[:, 1] >= 0]  # drop points with negative DV

        if dv_value is None:
            mask = self.isin_brain(shank)
        elif isinstance(dv_value, (int, np.integer)):
            mask = shank[:, 1] >= dv_value
        else:
            raise TypeError(f'dv_value must be an int or None, got {type(dv_value).__name__}')

        shank = shank[mask]
        if shank.shape[0] == 0:
            return shank  # nothing left to measure from

        # distance of every kept point from the first kept point
        d = np.sqrt(np.sum((shank - shank[0]) ** 2, axis=1))
        return shank[d <= distance]

    @staticmethod
    def _shank_extend(shank: np.ndarray,
                      ext_depth: Optional[tuple[float, float]] = (0, 5000)) -> np.ndarray:
        """
        probe extension using extrapolation and interpolation

        :param shank: `Array[float, [2, 3]]`, 2: start and end points; 3: AP,DV,ML
        :param ext_depth: DV range in um sampled every 10 um; if None, only interpolate
            between the two labelled points
        :return: extended shank `Array[float, [P, 3]]`
        """
        if ext_depth is not None:
            nn = np.arange(*ext_depth, 10)
        else:
            nn = np.arange(shank[0, 1], shank[-1, 1], 10)

        # the straight track is parameterized by DV; values outside the labels are linearly extrapolated
        return interp1d(shank[:, 1], shank, axis=0, bounds_error=False, fill_value='extrapolate')(nn)
# ========= #
def _calc_shank_length_diff(shank: np.ndarray,
                            shank_interval: float) -> float:
    """
    Length offset of a shank located ``shank_interval`` um away, given the tilt of ``shank``.

    The shank direction (first -> last point) is normalized to a unit vector ``v``. With ``n``
    the DV unit vector, ``|n x v| = sin(theta)``, theta being the angle between shank and DV axis.

    :param shank: `Array[float, [P, 3]]`, with ap, dv, ml
    :param shank_interval: distance (um) relative to the specific shank. e.g., NeuroPixel 2.0 = 250 * x
    :return: ``shank_interval * sin(theta)`` in um
    :raises ValueError: if the first and last points coincide (direction undefined)
    """
    v = shank[-1] - shank[0]
    length = np.linalg.norm(v)
    if length == 0:
        raise ValueError('first and last shank points coincide; shank direction is undefined')
    v = v / length  # unit vector

    # n = (0, 1, 0) (DV axis)  =>  n x v = (v_ml, 0, -v_ap), and |n x v| = sin_theta |n| |v| = sin_theta
    # NOTE(review): for parallel shanks whose tips lie perpendicular to the shank axis and a flat
    # horizontal surface, the path-length difference would be interval * tan(theta) — confirm geometry.
    sin_theta = np.linalg.norm(np.array([v[2], 0, -v[0]]))
    return shank_interval * sin_theta
if __name__ == '__main__':
    ProbeReconstructor().main()