"""
Swc morphology plot
====================

**Example for CLI usage**

.. prompt:: bash $

    python -m neuralib.morpho.swc -h

**3D view with radius**

- Press ``Shift+s`` for vedo interactive saving

.. prompt:: bash $

    python -m neuralib.morpho.swc <SWC_FILE> --radius

**2D view with radius**

.. prompt:: bash $

    python -m neuralib.morpho.swc <SWC_FILE> --radius --2d

"""
from pathlib import Path
from typing import NamedTuple, Iterator
import matplotlib.pyplot as plt
import numpy as np
import vedo
from matplotlib.axes import Axes
from matplotlib.patches import Circle
from typing_extensions import Self, overload
from neuralib.argp import AbstractParser, argument
from neuralib.typing import PathLike
# Names exported by a star-import of this module.
__all__ = ['SwcNode', 'SwcFile', 'plot_swc']
# A structure identifier is an integer code; its readable name is a string.
Identifier = int
IdentifierName = str

# Code -> readable name. The position of each name in the tuple is its code.
IDENTIFIER_DICT: dict[Identifier, IdentifierName] = dict(enumerate((
    'undefined',
    'soma',
    'axon',
    'basal',
    'apical',
    'custom',
)))
[docs]
class SwcNode(NamedTuple):
    """One node (one data row) of an swc file."""

    n: int
    """node number"""

    identifier: Identifier
    """structure identifier code, see IDENTIFIER_DICT"""

    x: float
    """position x"""

    y: float
    """position y"""

    z: float
    """position z"""

    r: float
    """radius"""

    parent: int
    """node number of the parent (parent connectivity)"""

    @property
    def identifier_name(self) -> IdentifierName:
        """Readable name of ``identifier``; codes missing from ``IDENTIFIER_DICT`` give ``'custom'``."""
        try:
            return IDENTIFIER_DICT[self.identifier]
        except KeyError:
            return 'custom'

    @property
    def point(self) -> np.ndarray:
        """Position as an array ``[x, y, z]``."""
        return np.array((self.x, self.y, self.z))

    @property
    def is_undefined(self) -> bool:
        """True when the identifier code is 0."""
        return self.identifier == 0

    @property
    def is_soma(self) -> bool:
        """True when the identifier code is 1."""
        return self.identifier == 1

    @property
    def is_axon(self) -> bool:
        """True when the identifier code is 2."""
        return self.identifier == 2

    @property
    def is_basal_dendrite(self) -> bool:
        """True when the identifier code is 3."""
        return self.identifier == 3

    @property
    def is_apical_dendrite(self) -> bool:
        """True when the identifier code is 4."""
        return self.identifier == 4

    @property
    def is_dendrite(self) -> bool:
        """True for a basal (3) or an apical (4) dendrite node."""
        return self.identifier in (3, 4)

    @property
    def is_custom(self) -> bool:
        """True when the identifier code is 5 or larger."""
        return 5 <= self.identifier
[docs]
class SwcFile:
    """SWC File: an ordered list of :class:`SwcNode`.

    Indexing accepts two kinds of key:

    - ``swc[int]`` gives the node whose node number ``n`` equals the key.
    - ``swc[str]`` gives a new ``SwcFile`` holding only the nodes of one group:
      ``'soma'``, ``'axon'``, ``'dendrite'`` (basal + apical), ``'basal'``,
      ``'apical'``, ``'custom'`` or ``'undefined'``.
    """
    node: list[SwcNode]

    # group key accepted by ``swc[str]`` -> name of the ``SwcNode`` flag that selects the group
    _GROUP_FLAG = {
        'soma': 'is_soma',
        'axon': 'is_axon',
        'dendrite': 'is_dendrite',
        'basal': 'is_basal_dendrite',
        'apical': 'is_apical_dendrite',
        'custom': 'is_custom',
        'undefined': 'is_undefined',
    }

    def __init__(self, node: list[SwcNode]):
        """
        :param node: list of nodes, kept in the given order
        """
        self.node = node

    @classmethod
    def load(cls, file: PathLike) -> Self:
        """
        Read an swc file.

        Blank lines and lines starting with ``#`` are skipped. For every other line
        the first seven whitespace-separated fields are read as
        ``n, identifier, x, y, z, radius, parent``.

        :param file: swc filepath
        :return: ``SwcFile``
        """
        node = []
        # NOTE(review): the file is decoded as Big5, presumably so that Big5 text in
        # comment lines can be read -- confirm against the data this is used with.
        with Path(file).open('r', encoding='Big5') as f:
            for line in f:
                line = line.strip()
                if not line or line.startswith('#'):
                    continue

                part = line.split()
                node.append(SwcNode(
                    int(part[0]),
                    int(part[1]),
                    float(part[2]),
                    float(part[3]),
                    float(part[4]),
                    float(part[5]),
                    int(part[6]),
                ))

        return cls(node)

    def __str__(self):
        """One line per node, using the string form of each node."""
        return '\n'.join(str(node) for node in self.node)

    @overload
    def __getitem__(self, item: int) -> SwcNode:
        ...

    @overload
    def __getitem__(self, item: IdentifierName) -> Self:
        ...

    def __getitem__(self, item: int | str) -> SwcNode | Self:
        """
        :param item: node number (``int``) or group name (``str``)
        :return: the matching ``SwcNode`` for an ``int`` key,
            a ``SwcFile`` with the selected nodes for a ``str`` key
        :raises KeyError: no node carries the requested node number
        :raises ValueError: unknown group name
        :raises TypeError: ``item`` is neither ``int`` nor ``str``
        """
        if isinstance(item, int):
            # Fast path: node number k stored at position k - 1.
            # The explicit bounds check keeps item <= 0 from wrapping around to the list end.
            pos = item - 1
            if 0 <= pos < len(self.node) and self.node[pos].n == item:
                return self.node[pos]

            # Fallback for lists that are not numbered 1..N in order,
            # e.g. the sub-files produced by ``swc['axon']``.
            for node in self.node:
                if node.n == item:
                    return node

            raise KeyError(f'item not found: {item}, might be loss parent connection')

        if isinstance(item, str):
            try:
                flag = self._GROUP_FLAG[item]
            except KeyError:
                expected = ', '.join(self._GROUP_FLAG)
                raise ValueError(f'unknown identifier name: {item!r}, expected one of: {expected}') from None

            return SwcFile([n for n in self.foreach_node() if getattr(n, flag)])

        raise TypeError(f'item must be int or str: {type(item)}')

    @property
    def points(self) -> np.ndarray:
        """Array of ``[x, y, z]`` rows, one per node, in node order."""
        return np.array([[n.x, n.y, n.z] for n in self.foreach_node()])

    @property
    def radii(self) -> np.ndarray:
        """Array with the radius of every node, in node order."""
        return np.array([n.r for n in self.foreach_node()])

    @property
    def parents(self) -> np.ndarray:
        """Array with the parent node number of every node, in node order."""
        return np.array([n.parent for n in self.foreach_node()])

    @property
    def unique_identifier(self) -> list[IdentifierName]:
        """Names of the identifier codes present, ordered by code, each name once.

        Codes missing from ``IDENTIFIER_DICT`` are reported as ``'custom'``.
        """
        idfs = np.unique([n.identifier for n in self.foreach_node()])
        names = (IDENTIFIER_DICT.get(idf, 'custom') for idf in idfs)
        # several codes can map to 'custom'; dict.fromkeys drops the repeats and keeps the order
        return list(dict.fromkeys(names))

    def foreach_identifier(self, as_dict: bool) -> list[Self] | dict[str, Self]:
        """
        Split the nodes by identifier name.

        :param as_dict: return ``{identifier name: SwcFile}`` instead of a list
        :return: one ``SwcFile`` per name in ``unique_identifier``
        """
        if as_dict:
            return {idf: self[idf] for idf in self.unique_identifier}
        else:
            return [self[idf] for idf in self.unique_identifier]

    def foreach_node(self) -> Iterator[SwcNode]:
        """Iterate over the nodes in stored order."""
        yield from self.node

    def foreach_line(self) -> Iterator[tuple[SwcNode, SwcNode]]:
        """
        Iterate over ``(node, parent node)`` pairs.

        Nodes whose parent number is not positive are skipped.

        :raises KeyError: a parent number does not belong to any stored node
        """
        for node in self.node:
            if node.parent > 0:
                yield node, self[node.parent]
# ============== #
# Plot Functions #
# ============== #
# Coordinate tuples used in the plotting helper signatures.
Point3D = tuple[float, float, float]
Point2D = tuple[float, float]

# Color code per group name, used when the caller passes no color dict.
DEFAULT_COLOR: dict[IdentifierName, str] = dict(
    soma='b',
    axon='r',
    dendrite='k',
    undefined='k',
    custom='k',
)
def projection_2d(p: Point3D) -> Point2D:
    """Default projection: keep the x and y values of a 3d point and drop z.

    :param p: 3d point ``(x, y, z)``
    :return: 2d point ``(x, y)``
    """
    x = p[0]
    y = p[1]
    return x, y
def smooth_line_radius(ax: Axes,
                       p1: Point2D,
                       p2: Point2D,
                       r1: float,
                       r2: float,
                       num: int = 2,
                       **kwargs):
    """
    Draw the straight line from ``p1`` to ``p2`` as ``num`` consecutive pieces
    whose line width goes linearly from ``r1`` (first piece) to ``r2`` (last piece).

    :param ax: ``Axes``
    :param p1: Point 1
    :param p2: Point 2
    :param r1: Radius 1, used as line width at the ``p1`` end
    :param r2: Radius 2, used as line width at the ``p2`` end
    :param num: Number of segments
    :param kwargs: Additional arguments pass to ``ax.plot()``
    :return:
    """
    # num pieces need num + 1 end points but only num widths
    xs = np.linspace(p1[0], p2[0], num + 1)
    ys = np.linspace(p1[1], p2[1], num + 1)
    widths = np.linspace(r1, r2, num)

    for start, width in enumerate(widths):
        stop = start + 2
        ax.plot(xs[start:stop], ys[start:stop], lw=width, **kwargs)
[docs]
def plot_swc(swc: SwcFile,
             radius: bool = True,
             color: dict[str, str] | None = None,
             as_2d: bool = False):
    """
    Plot swc file, either as a 2d matplotlib figure or as a 3d vedo scene

    :param swc: ``SwcFile``
    :param radius: Plot with radius.
    :param color: Color dict. With {identifier name: color coded}.
        ``DEFAULT_COLOR`` is used if None
    :param as_2d: Plot as 2d, otherwise plot as 3d
    """
    palette = DEFAULT_COLOR if color is None else color
    draw = _plot_swc_2d if as_2d else _plot_swc_3d
    draw(swc, radius, palette)
def _plot_swc_2d(swc, radius, color):
    """
    Draw every ``(node, parent)`` connection of the swc on a new matplotlib figure,
    using the x-y projection of ``projection_2d``.

    NOTE(review): the figure is created but not shown or returned here.

    :param swc: ``SwcFile``
    :param radius: use the node radii as line widths and draw soma nodes as circles
    :param color: Color dict. With {identifier name: color coded}
    """

    def line_color(node):
        # Dendrite nodes are named 'basal' / 'apical', so the shared 'dendrite' entry
        # has to be looked up explicitly; an exact name entry still wins.
        name = node.identifier_name
        if name in color:
            return color[name]
        if node.is_dendrite and 'dendrite' in color:
            return color['dendrite']
        return 'k'

    fig, ax = plt.subplots()

    for n1, n2 in swc.foreach_line():
        c = line_color(n1)
        p1 = projection_2d((n1.x, n1.y, n1.z))
        p2 = projection_2d((n2.x, n2.y, n2.z))

        if radius:
            if n2.is_soma:
                # same 'k' fallback as the line colors, so a partial color dict does not fail
                ax.add_artist(Circle(p2, n2.r, color=color.get('soma', 'k')))
                if not n1.is_soma:
                    # constant width: the line keeps the child radius up to the soma circle
                    smooth_line_radius(ax, p1, p2, n1.r, n1.r, color=c, solid_capstyle='round')
            else:
                smooth_line_radius(ax, p1, p2, n1.r, n2.r, color=c, solid_capstyle='round')
        else:
            px = p1[0], p2[0]
            py = p1[1], p2[1]
            ax.plot(px, py, color=c, solid_capstyle='round')

    ax.axis('off')
    ax.set_xticklabels([])
    ax.set_yticklabels([])
def _plot_swc_3d(swc: SwcFile,
                 radius,
                 color,
                 spheres_size: float = 3,
                 lw: float = 5):
    """
    Show the swc in a vedo plotter: one sphere per node and one line per
    node-parent connection, colored per group (soma, dendrite, axon, other).

    Line end points are taken from ``swc.points`` by position, so this assumes
    that node number ``k`` is stored at position ``k - 1``.

    :param swc: ``SwcFile``
    :param radius: scale the spheres by node radius (times ``spheres_size``), otherwise use 5
    :param color: Color dict. With {identifier name: color coded}. Missing groups fall back to ``'k'``
    :param spheres_size: factor applied to the node radius when ``radius`` is set
    :param lw: line width of the connections
    """
    plotter = vedo.Plotter()
    points = swc.points  # built once; the property creates a new array on every access

    # drawing order of the groups; undefined and custom nodes share the 'custom' group
    draw_order = ('soma', 'dendrite', 'axon', 'custom')
    centers = {group: [] for group in draw_order}
    segments = {group: [] for group in draw_order}
    radii = {group: [] for group in draw_order}

    for i, n in enumerate(swc.foreach_node()):
        # No parent to connect to. Same rule as SwcFile.foreach_line (parent > 0):
        # parent 0 would otherwise index position -1, i.e. the last point.
        if n.parent <= 0:
            continue

        if n.is_axon:
            group = 'axon'
        elif n.is_dendrite:
            group = 'dendrite'
        elif n.is_soma:
            group = 'soma'
        elif n.is_undefined or n.is_custom:
            group = 'custom'
        else:
            continue  # identifier fits no group (e.g. negative code)

        if group == 'soma':
            r = 10  # fix value
        elif radius:
            r = n.r * spheres_size
        else:
            r = 5

        centers[group].append([n.x, n.y, n.z])
        segments[group].append([n.parent - 1, i])  # [parent position, own position]
        radii[group].append(r)

    for group in draw_order:
        # Decide by collected nodes, not by identifier names: dendrites are named
        # 'basal' / 'apical', and empty groups must not be handed to vedo.
        if not centers[group]:
            continue

        c = color.get(group, 'k')
        plotter += vedo.Spheres(centers[group], r=radii[group], c=c)
        plotter += vedo.Lines(points[segments[group]], c=c, lw=lw)

    plotter.show()
# ======== #
# Plot CLI #
# ======== #
class SwcPlotOptions(AbstractParser):
    # Command line options for plotting one swc file; ``run`` loads the file and plots it.

    # swc file to plot
    file: str = argument(
        metavar='FILE',
        help='filepath of the swc file',
    )

    # forwarded to ``plot_swc(radius=...)``
    radius: bool = argument(
        '-R', '--radius',
        help='Whether plot with radius',
    )

    # forwarded to ``plot_swc(as_2d=...)``
    as_2d: bool = argument(
        '--2d',
        help='Whether plot with 2d, otherwise, plot as 3d',
    )

    def run(self):
        """Load the swc file named by ``file`` and plot it with the chosen options."""
        plot_swc(
            SwcFile.load(self.file),
            radius=self.radius,
            as_2d=self.as_2d,
        )
# Script entry point (``python -m ...``): hand over to the option parser's ``main()``,
# which presumably parses the command line and then calls ``run`` -- confirm in AbstractParser.
if __name__ == '__main__':
    SwcPlotOptions().main()