Example notebook: neuralib.model.rastermap (two-photon data)
See also: the rastermap package and its Colab example
[1]:
from pathlib import Path
import attrs
import numpy as np
import rastermap.utils
from rastermap import Rastermap
from neuralib.io.dataset import load_example_rastermap_2p
from neuralib.model.rastermap import *
from neuralib.plot import plot_figure, ax_merge
from neuralib.typing import PathLike
[2]:
%load_ext autoreload
%autoreload
Example pipeline for a two-photon (2P) dataset
Linear treadmill task with cues (green dashed lines in the figure)
Visual stimulation epoch with drifting gratings (pink shaded areas in the figure)
Tracked behavioral variables: position, running velocity, and pupil area
[3]:
# Prepare container for the imaging / behavioral data
@attrs.frozen
class RastermapInput:
    """
    Container for the imaging and behavioral data of one two-photon session.

    `Dimension parameters`:

        N = number of neurons

        T = number of image pulse (imaging frames)

        S = number of stimulation (optional)

    Every 1D time-aligned array (``image_time`` and the optional behavioral fields),
    when provided, must have T samples.
    """

    xy_pos: np.ndarray
    """Soma central position. `Array[float, [2, N]]`"""

    neural_activity: np.ndarray
    """2D Calcium activity. `Array[float, [N, T]]`"""

    image_time: np.ndarray
    """1D Calcium imaging time. `Array[float, T]`"""

    position: np.ndarray | None = attrs.field(default=None)
    """1D animal position. `Array[float, T]`"""

    velocity: np.ndarray | None = attrs.field(default=None)
    """1D animal velocity. `Array[float, T]`"""

    lap_index: np.ndarray | None = attrs.field(default=None)
    """1D trial index (laps in circular env). `Array[float, T]`"""

    pupil_area: np.ndarray | None = attrs.field(default=None)
    """1D animal pupil area. `Array[float, T]`"""

    visual_stim_time: np.ndarray | None = attrs.field(default=None)
    """2D on-off visual stimulation time. `Array[float, [S, 2]]`"""

    def __attrs_post_init__(self):
        """
        Validate that every time-aligned array has T samples.

        Optional fields left as ``None`` are skipped. Calling ``len()`` on them would raise TypeError.

        :raises ValueError: if a provided array does not have ``neural_activity.shape[1]`` samples
        """
        n_frames = self.neural_activity.shape[1]
        time_aligned = {
            'image_time': self.image_time,
            'position': self.position,
            'velocity': self.velocity,
            'lap_index': self.lap_index,
            'pupil_area': self.pupil_area,
        }
        for name, arr in time_aligned.items():
            if arr is not None and len(arr) != n_frames:
                raise ValueError(f'{name} has {len(arr)} samples, expected {n_frames} (neural_activity.shape[1])')

    @property
    def n_neurons(self) -> int:
        """Number of neurons N (rows of ``neural_activity``)."""
        return self.neural_activity.shape[0]

    @property
    def x_pos(self) -> np.ndarray:
        """Soma x positions. `Array[float, N]`"""
        return self.xy_pos[0]

    @property
    def y_pos(self) -> np.ndarray:
        """Soma y positions. `Array[float, N]`"""
        return self.xy_pos[1]

    def find_cues_loc(self,
                      time_mask: np.ndarray | None = None,
                      *,
                      tolerance: float = 0.5,
                      cue_loc: tuple[float, ...] = (50, 100)) -> np.ndarray:
        """
        find cue(s) location indices

        A new lap starts wherever ``lap_index`` increases by more than ``tolerance``
        between consecutive samples. For every lap whose maximum position exceeds a
        cue location, the ``np.searchsorted`` insertion index of that location is
        reported. This assumes position increases monotonically within a lap.

        **Note that the return is relative value to the ``time_mask`` (start from zero)**

        :param time_mask: time mask. `Array[bool, T]`
        :param tolerance: threshold on the sample-to-sample increase of ``lap_index`` used to detect lap boundaries
        :param cue_loc: Cue location(s), in the same units as ``position``
        :return: sorted cue indices. `Array[int, C]` (empty when no lap reaches any cue)
        :raises ValueError: if ``lap_index`` or ``position`` is not provided
        """
        if self.lap_index is None or self.position is None:
            raise ValueError('find_cues_loc requires both `lap_index` and `position`')

        lap_index = np.asarray(self.lap_index)
        pos = np.asarray(self.position)
        if time_mask is not None:
            lap_index = lap_index[time_mask]
            pos = pos[time_mask]

        # index of the first sample of every new lap
        lap_starts = np.flatnonzero(np.diff(lap_index) > tolerance) + 1
        laps = np.split(pos, lap_starts)
        # offset of each lap = index of its first sample (the split points themselves, not split - 1)
        offsets = np.concatenate(([0], lap_starts))

        cue_index = [
            int(np.searchsorted(lap, loc)) + int(offset)
            for loc in cue_loc
            for lap, offset in zip(laps, offsets)
            if lap.size and np.max(lap) > loc  # empty lap (e.g. all-False mask) has no max
        ]
        return np.sort(np.asarray(cue_index, dtype=int))

    @property
    def visual_stim_start(self) -> float:
        """Onset of the first visual stimulus (first row, assuming rows are sorted by onset)."""
        return self.visual_stim_time[0, 0]

    def visual_stim_trange(self, trange: tuple[float, float]) -> np.ndarray:
        """select visual stim time range segments

        Stimuli that only partially overlap ``trange`` are clipped to it. This assumes
        ``visual_stim_time`` holds non-overlapping intervals sorted by onset.
        The stored ``visual_stim_time`` is never modified.

        :param trange: (start, end) time range
        :return: `Array[float, [S', 2]]`
        """
        # work on a float copy: clipping in place would corrupt this (frozen) instance for later calls
        vt = np.array(self.visual_stim_time, dtype=float)
        # (S, 2) -> (S*2,) = [on0, off0, on1, off1, ...]; an odd insertion index lies inside a stimulus
        start_idx, end_idx = np.searchsorted(vt.ravel(), list(trange))
        start_stim_idx = int(start_idx // 2)
        end_stim_idx = int(end_idx // 2)

        if start_idx % 2 != 0:
            # trange[0] falls inside stimulus `start_stim_idx`: clip its onset
            vt[start_stim_idx, 0] = trange[0]
        if end_idx % 2 != 0:
            # trange[1] falls inside stimulus `end_stim_idx`: clip its offset and keep that stimulus
            vt[end_stim_idx, 1] = trange[1]
            end_stim_idx += 1

        return vt[start_stim_idx:end_stim_idx]
[4]:
# Default Rastermap options for the 2P example.
# Each key is forwarded to `rastermap.Rastermap(...)` by `run_rastermap_2p`;
# see the rastermap documentation for the meaning of each parameter.
DEFAULT_2P_RASTERMAP_OPT: RasterOptions = dict(
    n_clusters=50,
    n_PCs=128,
    locality=0.75,
    time_lag_window=5,
    grid_upsample=10,
)
[5]:
def run_rastermap_2p(dat: RastermapInput,
                     suite2p_directory: PathLike,
                     ops: RasterOptions | None = None,
                     neuron_bins: int = 40,
                     **kwargs) -> RasterMapResult:
    """
    Fit Rastermap on ``dat.neural_activity`` and wrap the sorting into a :class:`RasterMapResult`.

    :param dat: input data
    :param suite2p_directory: suite2p output directory (``str`` or ``Path``). It is only used to fill
        ``filename`` (``<dir>/F.npy``) and ``save_path`` of the result
    :param ops: Rastermap options. Missing keys (or ``ops=None``) fall back to ``DEFAULT_2P_RASTERMAP_OPT``
    :param neuron_bins: bin size, in sorted neurons, passed to ``rastermap.utils.bin1d`` to build superneurons
    :param kwargs: extra keyword arguments for ``rastermap.Rastermap``; they override the values taken from ``ops``
    :return: :class:`RasterMapResult`
    """
    # build a fresh dict so the module-level default is never shared with (or mutated through) the result
    ops = {**DEFAULT_2P_RASTERMAP_OPT, **(ops or {})}
    # accept plain strings as well as Path objects (`'dir' / 'F.npy'` would raise TypeError)
    suite2p_directory = Path(suite2p_directory)

    model_kwargs = {
        key: ops[key]
        for key in ('n_clusters', 'n_PCs', 'locality', 'time_lag_window', 'grid_upsample')
    }
    # explicit keyword arguments take precedence instead of raising a duplicate-keyword TypeError
    model_kwargs.update(kwargs)
    model = Rastermap(**model_kwargs).fit(dat.neural_activity)

    isort = model.isort
    super_neurons = rastermap.utils.bin1d(dat.neural_activity[isort], bin_size=neuron_bins, axis=0)

    # For fit gui launch behavior: a placeholder user cluster (ids 0-9)
    _pseudo_cluster: UserCluster = {
        'ids': np.arange(10),
        'slice': slice(0, 10),
        'binsize': neuron_bins,
        'color': np.array([194.59459459, 255., 0., 50.])
    }

    return RasterMapResult(
        filename=str(suite2p_directory / 'F.npy'),
        save_path=str(suite2p_directory),
        isort=isort,
        embedding=model.embedding,
        ops=ops,
        user_clusters=[_pseudo_cluster],
        super_neurons=super_neurons
    )
# ============================== #
# Plot cluster and soma location #
# ============================== #
def _plot_behavior_trace(ax, time: np.ndarray, values: np.ndarray | None, tmask: np.ndarray,
                         title: str, sharex=None):
    """Draw one time-masked behavioral trace on ``ax``; the panel stays blank (title only) when ``values`` is None."""
    if values is not None:
        ax.plot(time, values[tmask], color='k')
    ax.axis('off')
    ax.set_title(title)
    if sharex is not None:
        ax.sharex(sharex)


def plot_rastermap_sort(dat: RastermapInput,
                        raster: RasterMapResult,
                        trange: tuple[float, float],
                        neuron_bins: int = 40):
    """
    Plot behavioral traces above the Rastermap-sorted superneuron activity within ``trange``.

    Left column (shared x axis): position, running speed, pupil area, then superneurons.
    Cue locations are drawn as green dashed lines and visual stimulation epochs as pink spans.
    Behavioral panels, cues and stimuli whose data is missing (None) are skipped.

    :param dat: input data
    :param raster: computed result
    :param trange: (start, end) time range for visualization, in the units of ``dat.image_time``
    :param neuron_bins: kept for backward compatibility. The superneuron count is now read from
        ``raster.super_neurons`` so the y extent always matches the plotted rows
    :raises ValueError: if no imaging frame falls inside ``trange``
    """
    tmask = np.logical_and(trange[0] <= dat.image_time, dat.image_time <= trange[1])
    time = dat.image_time[tmask]
    if time.size == 0:
        raise ValueError(f'no imaging frame falls within trange={trange}')

    with plot_figure(None, 10, 20, gridspec_kw={'wspace': 1, 'hspace': 1}, tight_layout=False) as _ax:
        # behavioral traces: rows 0-2, every column but the last
        ax1 = ax_merge(_ax)[0, :-1]
        _plot_behavior_trace(ax1, time, dat.position, tmask, 'position')
        ax2 = ax_merge(_ax)[1, :-1]
        _plot_behavior_trace(ax2, time, dat.velocity, tmask, 'running speed', sharex=ax1)
        ax3 = ax_merge(_ax)[2, :-1]
        _plot_behavior_trace(ax3, time, dat.pupil_area, tmask, 'pupil area', sharex=ax1)

        # superneuron activity: remaining rows
        ax4 = ax_merge(_ax)[3:, :-1]
        ax4.sharex(ax1)
        n_super = raster.super_neurons.shape[0]  # actual number of plotted rows
        ax4.imshow(raster.super_neurons[:, tmask],
                   cmap="gray_r",
                   vmin=0,
                   vmax=0.8,
                   aspect="auto",
                   extent=(time[0], time[-1], n_super, 0))  # aligned with the traces' time axis
        ax4.set(xlabel="time(s)", ylabel='superneurons')

        # cue location; indices from find_cues_loc are relative to `tmask`, hence `time[c]`
        if dat.lap_index is not None and dat.position is not None:
            for c in dat.find_cues_loc(tmask):
                ax4.axvline(time[c], color='g', linestyle='--', alpha=0.4)

        # visual stim
        if dat.visual_stim_time is not None and dat.visual_stim_start <= trange[1]:
            for on, off in dat.visual_stim_trange(trange=trange):
                ax4.axvspan(on, off, color='mistyrose', alpha=0.6)

        # disable the unused top-right area
        ax5 = ax_merge(_ax)[:3, -1]
        ax5.axis('off')

        # color bar
        ax6 = ax_merge(_ax)[3:, -1]
        ax6.imshow(np.arange(0, raster.n_super)[:, np.newaxis], cmap="gist_ncar", aspect="auto")
        ax6.axis("off")
def plot_rastermap_2p_soma(dat: RastermapInput,
                           raster: RasterMapResult,
                           output: PathLike | None = None):
    """
    Scatter the soma positions, colored by their Rastermap embedding.

    :param dat: input data (``x_pos`` / ``y_pos`` give the soma coordinates)
    :param raster: computed result (``embedding`` gives the point colors)
    :param output: forwarded to ``plot_figure`` (presumably a save path; ``None`` only displays)
    """
    with plot_figure(output) as ax:
        ax.scatter(dat.x_pos, dat.y_pos,
                   s=8,
                   c=raster.embedding,
                   cmap="gist_ncar",
                   alpha=0.25)
        # flip y, presumably so the layout matches the imaging field of view
        ax.invert_yaxis()
        ax.set(xlabel='X position(mm)', ylabel='Y position(mm)')
        ax.set_aspect('equal')
[6]:
# Load the bundled example dataset. To analyse your own recording, build RastermapInput(...) directly.
cache = load_example_rastermap_2p()
raster_input = RastermapInput(**{
    field: cache[field]
    for field in ('xy_pos', 'neural_activity', 'image_time', 'position',
                  'velocity', 'lap_index', 'pupil_area', 'visual_stim_time')
})
[7]:
# the directory is only recorded in the result's filename / save_path
res = run_rastermap_2p(raster_input, suite2p_directory=Path(''))
2024-09-28 11:21:09,833 [INFO] normalizing data across axis=1
2024-09-28 11:21:10,349 [INFO] projecting out mean along axis=0
2024-09-28 11:21:10,689 [INFO] data normalized, 0.86sec
2024-09-28 11:21:10,710 [INFO] sorting activity: 1118 valid samples by 108128 timepoints
2024-09-28 11:21:18,339 [INFO] n_PCs = 128 computed, 8.51sec
2024-09-28 11:21:18,555 [INFO] 33 clusters computed, time 8.72sec
2024-09-28 11:21:18,946 [INFO] clusters sorted, time 9.11sec
2024-09-28 11:21:18,988 [INFO] clusters upsampled, time 9.16sec
2024-09-28 11:21:19,352 [INFO] rastermap complete, time 9.52sec
[8]:
plot_rastermap_sort(dat=raster_input, raster=res, trange=(850, 1000))  # zoom into the 850-1000 window of image_time
[9]:
plot_rastermap_2p_soma(dat=raster_input, raster=res)