1. Visualization of Spike Data#

神経細胞はスパイクと呼ばれる活動電位によって互いに情報を交換する.神経細胞集団の電気活動を分析する上で中心となるのはスパイク時刻の情報である.多電極アレイでは,各電極についてスパイク時刻の列 (spike train) が取得される.

解析の手始めに,可視化によってなるべく生の現象に近い状態でデータを観察する.

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams['font.size'] = 20
plt.rcParams['figure.dpi'] = 140
datadir = '../datasets/01/'
df_map = pd.read_csv(datadir + 'mapping.csv', index_col=0)
df_sp = pd.read_csv(datadir + 'spikes.csv', index_col=0)

df_mapは各電極の情報(x, y座標)を格納し,df_spは各スパイクの情報(検知された電極,検知された時刻)を格納する.

display(df_map.head())
display(df_sp.head())
channel x y
0 0 875.0 1505.0
1 1 3132.5 1242.5
2 2 647.5 1417.5
3 3 2870.0 1032.5
4 4 700.0 682.5
channel amplitude spiketime
0 342 18.561750 0.00000
1 382 22.953043 0.00005
2 708 36.348030 0.09220
3 824 42.317764 0.09220
4 852 46.456272 0.09225

1.1. Raster Plot#

スパイク時系列 (spike train) は, スパイク時刻と発生電極を両軸に取ったグラフ(ラスタープロットと呼ばれる)により可視化できる.

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 4))

# rasterplot
df_sp.query('10 < spiketime < 20').plot.scatter(ax=ax1, x='spiketime', y='channel', c='k', s=5)
ax1.set_title('Rasterplot')

# visualize electrode mapping
df_map.plot.scatter(ax=ax2, x='x', y='y', marker='s', c='k')
ax2.set_title('Electrode Mapping')
ax2.set_aspect('equal')

plt.show()
../_images/01_visualization_8_0.png
fig, ax = plt.subplots(figsize=(16, 4))
df_sp.query('13.8 < spiketime < 14.0').plot.scatter(ax=ax, x='spiketime', y='channel', c='k', s=5)

plt.locator_params(axis='x', nbins=4)
plt.show()
../_images/01_visualization_9_0.png

1.2. Global Firing Rate#

電極全体での発火頻度(global firing rate)の時間的な推移を,ラスタープロットに重ねて表示したいので,発火時刻についてヒストグラムをとる.

Tip

発火頻度のヒストグラムは,ガウシアンフィルタ等によって平滑化すると,細かな変動による影響を緩和できて扱いやすくなる.のちに刺激に対する応答をみる際や,同期バースト検知の際にも有効である.

from scipy.ndimage import gaussian_filter

def rasterplot(ax: plt.Axes, df: pd.DataFrame, start: float, end: float, x: str='spiketime', y: str='channel', **kwargs):
    df_ = df.query(f'{start} < {x} < {end}')
    return ax.scatter(x=df_[x], y=df_[y], **kwargs)

def spikehist(ax: plt.Axes, df: pd.DataFrame, start: float, end: float, bin_width: float, x: str='spiketime', y: str='channel', smooth=True, **kwargs):
    df_ = df.query(f'{start} < {x} < {end}')
    hist, edges = np.histogram(df_[x], range=(start, end), bins=int((end-start)/bin_width))
    if smooth: hist = gaussian_filter(hist, sigma=[2])  # smoothing with gaussian filter
    return ax .plot(edges[1:], hist, **kwargs)

def rastergram(ax1: plt.Axes, ax2: plt.Axes, df: pd.DataFrame, start: float, end: float, 
               bin_width: float=0.01, x: str='spiketime', y: str='channel', smooth: bool=True):
    p1 = spikehist(ax1, df, start, end, bin_width, x, y, smooth, linewidth=2.0, c='k')
    p2 = rasterplot(ax2, df, start, end, x, y, c='k', s=5)
    return p1, p2
fig, (ax1, ax2) = plt.subplots(2, figsize=(16, 4), gridspec_kw={'height_ratios': [1, 3]})
start, end = 12.0, 14.5
rastergram(ax1=ax1, ax2=ax2, df=df_sp, start=start, end=end, bin_width=0.001)

# optional: touch up the figure layout
ax1.set_xlim(start, end)
ax1.set_xticks([])
ax1.set_ylabel('spikes')
ax1.spines['right'].set_visible(False)
ax1.spines['top'].set_visible(False)
ax1.spines['bottom'].set_visible(False)
ax1.spines['left'].set_linewidth(2)
ax1.tick_params(width=2.0, length=5.0, direction='in')

ax2.spines['right'].set_visible(False)
ax2.spines['top'].set_visible(False)
ax2.spines['left'].set_linewidth(2)
ax2.spines['bottom'].set_linewidth(2)

ax2.set_xlim(start, end)
ax2.set_xlabel('time [s]')

for i, tick in enumerate(ax2.xaxis.get_ticklabels()):
    if i % 2 != 0:
        tick.set_visible(False) 

ax2.set_ylim(0, 1024)
ax2.set_ylabel('channel #')
ax2.set_yticks([0, 250, 500, 750, 1000])
for i, tick in enumerate(ax2.yaxis.get_ticklabels()):
    if i % 2 != 0:
        tick.set_visible(False) 

ax2.set_facecolor('whitesmoke')
ax2.tick_params(width=2.0, length=5.0, direction='in')

plt.subplots_adjust(hspace=0.2)
plt.show()
../_images/01_visualization_14_0.png

1.3. Electrode Mapping#

次に,MEA上の各電極について神経活動に関する統計量(発火率,平均振幅)を取得し,ヒートマップにより可視化する.

def channel_stats(df_sp: pd.DataFrame, df_map: pd.DataFrame):
    duration = df_sp.spiketime.max() - df_sp.spiketime.min()
    groups = df_sp[['channel', 'amplitude']].groupby('channel')
    
    df_fr = pd.DataFrame(groups.size() / duration, columns=['firing_rate'])  # firing rate for each channel
    df_amp = groups.mean()  # mean spike amplitude for each channel
    
    df_stat = pd.concat([df_map.set_index('channel'), df_fr, df_amp], axis=1, join='inner')
    return df_stat
df_stat = channel_stats(df_sp, df_map)
display(df_stat.head())
x y firing_rate amplitude
channel
0 875.0 1505.0 2.303333 68.872199
1 3132.5 1242.5 0.976667 153.558120
2 647.5 1417.5 2.800000 332.676633
3 2870.0 1032.5 0.450000 37.887843
4 700.0 682.5 1.723333 299.170993
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 4))

# mean amplitude
df_stat.plot.scatter(ax=ax1, x='x', y='y', marker='s', s=1, c='amplitude', cmap='viridis', vmin=0.0, vmax=600.0)
ax1.set_aspect('equal')
ax1.set_facecolor('k')
ax1.grid()

# mean firing rate
df_stat.plot.scatter(ax=ax2, x='x', y='y', marker='s', s=1, c='firing_rate', cmap='magma', vmin=0.0, vmax=5.0)
ax2.set_aspect('equal')
ax2.set_facecolor('k')
ax2.grid()

plt.tight_layout()
plt.show()
../_images/01_visualization_19_0.png

上では見づらいので,一部を拡大してみよう.ついでにchannel idもアノテーションしておく.

fig, ax = plt.subplots(figsize=(8, 6))

df_ = df_stat.query('500 < x < 800 and 1800 < y < 2000')
df_.plot.scatter(ax=ax, x='x', y='y', marker='s', s=600, c='amplitude', cmap='viridis', vmin=0.0, vmax=600.0)

for index, row in df_.iterrows():
    ax.annotate(text=str(index), xy=(row.x, row.y), fontsize=12, color='w', ha='center', va='center')
    
ax.set_aspect('equal')
ax.set_facecolor('k')    

plt.locator_params(axis='y', nbins=4)
plt.show() 
../_images/01_visualization_21_0.png

1.4. Propagation#

matplotlib animationを用いて,spikeが電極上で伝播していく様子を可視化するコードの例を示す.

Note

基本的な描画は通常のグラフと同じであるが,それらを一定のfpsでフレームとして結合することにより,アニメーションを作る.

Tip

以下のコードでは,特定の区間において,窓(幅20ms)を1msずつスライドしながら,相対的な発火時刻(spiketime)に基づいてカラー散布図を描画する.spikeのプロットが一定時間残るため,伝播の方向を視認しやすい.

# animation
from matplotlib.animation import ArtistAnimation
from IPython import display 
start, end = 13.880, 14.000
df_ = df_sp.query('@start <= spiketime <= @end')

# slides settings
frames = []
bin_width = 0.020
slide_width = 0.001
n_slides = int(((end - start) - bin_width) / slide_width) + 1
print(f'number of frames: {n_slides}')

# plot initialization
fig, ax = plt.subplots(figsize=(8, 5))
s = 10
ax.set_facecolor('black')
ax.set_xlim(0.0, 17.5*220)
ax.set_ylim(0.0, 17.5*120)
ax.set_xticks([])
ax.set_yticks([])
ax.set_aspect('equal')

# animated parts
for i in range(n_slides):
    # extract only spikes within each sliding window
    start_ = start + i * slide_width
    end_ = start_ + bin_width
    df = pd.merge(df_.query('@start_ <= spiketime <= @end_'), df_map, on='channel')

    # scatter plot based on relative spiketime
    vmin, vmax = start_, end_
    sc = ax.scatter(x=df.x, y=df.y, vmin=vmin, vmax=vmax, marker='s', c=df.spiketime, cmap=plt.get_cmap('viridis'), s=s, edgecolor='k')
    title = ax.text(0.5, 1.01, 'time: {:.3f} [s]'.format(vmax), ha='center', va='bottom', transform=ax.transAxes, fontsize=25)
    frames.append([sc, title])

# drawing
fps = 15
ani = ArtistAnimation(fig, frames, interval=int(1000/fps))
html = display.HTML(ani.to_jshtml())
display.display(html)
plt.close()
number of frames: 100