7 changes: 6 additions & 1 deletion doc/references.rst
@@ -88,7 +88,10 @@ important for your research:

Curation Module
---------------
If you use the :code:`get_potential_auto_merge` method from the curation module, please cite [Llobet]_

If you use the default "similarity_correlograms" preset in the :code:`compute_merge_unit_groups` method from the curation module, please cite [Llobet]_

If you use the "slay" preset in the :code:`compute_merge_unit_groups` method, please cite [Koukuntla]_
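
For example, a minimal call with this preset might look like the sketch below (the import path and the existing ``sorting_analyzer`` are assumptions here):

.. code-block:: python

    from spikeinterface.curation import compute_merge_unit_groups

    # "sorting_analyzer" is a SortingAnalyzer with the extensions required by
    # the preset (e.g. "correlograms" and "template_similarity") already computed
    merge_unit_groups = compute_merge_unit_groups(sorting_analyzer, preset="slay")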

If you use :code:`auto_label_units` or :code:`train_model`, please cite [Jain]_

@@ -140,6 +143,8 @@ References

.. [Jia] `High-density extracellular probes reveal dendritic backpropagation and facilitate neuron classification. 2019. <https://journals.physiology.org/doi/full/10.1152/jn.00680.2018>`_

.. [Koukuntla] `SLAy-ing oversplitting errors in high-density electrophysiology spike sorting. 2025. <https://www.biorxiv.org/content/10.1101/2025.06.20.660590v1>`_

.. [Lee] `YASS: Yet another spike sorter. 2017. <https://www.biorxiv.org/content/10.1101/151928v1>`_

.. [Lemon] Methods for neuronal recording in conscious animals. IBRO Handbook Series. 1984.
262 changes: 262 additions & 0 deletions src/spikeinterface/curation/auto_merge.py
@@ -52,6 +52,10 @@
"knn",
"quality_score",
],
"slay": [
"template_similarity",
"slay_score",
],
}

_required_extensions = {
@@ -60,6 +64,7 @@
"snr": ["templates", "noise_levels"],
"template_similarity": ["templates", "template_similarity"],
"knn": ["templates", "spike_locations", "spike_amplitudes"],
"slay_score": ["correlograms", "template_similarity"],
}


@@ -84,6 +89,7 @@
"censored_period_ms": 0.3,
},
"quality_score": {"firing_contamination_balance": 1.5, "refractory_period_ms": 1.0, "censored_period_ms": 0.3},
"slay_score": {"k1": 0.25, "k2": 1, "slay_threshold": 0.5},
}


@@ -356,6 +362,14 @@ def compute_merge_unit_groups(
)
outs["pairs_decreased_score"] = pairs_decreased_score

elif step == "slay_score":

M_ij = compute_slay_matrix(
sorting_analyzer, params["k1"], params["k2"], templates_diff=outs["templates_diff"], pair_mask=pair_mask
)

pair_mask = pair_mask & (M_ij > params["slay_threshold"])

# FINAL STEP : create the final list from pair_mask boolean matrix
ind1, ind2 = np.nonzero(pair_mask)
merge_unit_groups = list(zip(unit_ids[ind1], unit_ids[ind2]))
@@ -1506,3 +1520,251 @@ def estimate_cross_contamination(
)

return estimation, p_value


def compute_slay_matrix(
sorting_analyzer: SortingAnalyzer,
k1: float,
k2: float,
templates_diff: np.ndarray | None,
pair_mask: np.ndarray | None = None,
):
"""
Computes the "merge decision metric" from the SLAy method, formed by combining
a template similarity measure, a cross-correlation significance measure and a
sliding refractory period violation measure as M_ij = sigma_ij + k1 * rho_ij - k2 * eta_ij.
A large M_ij suggests that units i and j should be merged.

Parameters
----------
sorting_analyzer : SortingAnalyzer
The sorting analyzer object containing the spike sorting data
k1 : float
Coefficient determining the importance of the cross-correlation significance
k2 : float
Coefficient determining the importance of the sliding rp violation
templates_diff : np.ndarray | None
Pre-computed template similarity difference matrix. If None, it will be retrieved from the sorting_analyzer.
pair_mask : None | np.ndarray, default: None
A bool matrix describing which pairs are possible merges based on previous steps
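
Returns
-------
M_ij : np.ndarray
The merge decision matrix for all unit pairs; larger values indicate stronger evidence that the pair should be merged.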


References
----------
Based on computation originally implemented in SLAy [Koukuntla]_.

Implementation is based on one of the original implementations written by Sai Koukuntla,
found at https://github.com/saikoukunt/SLAy.
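
Examples
--------
A minimal sketch, assuming ``sorting_analyzer`` already has the "correlograms"
and "template_similarity" extensions computed:

>>> M_ij = compute_slay_matrix(sorting_analyzer, k1=0.25, k2=1.0, templates_diff=None)
>>> pair_mask = M_ij > 0.5  # same threshold as the default "slay_score" parameters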
"""

num_units = sorting_analyzer.get_num_units()

if pair_mask is None:
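# by default, consider each unordered pair (i, j) with j > i exactly once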
pair_mask = np.triu(np.arange(num_units), 1) > 0

if templates_diff is not None:
sigma_ij = 1 - templates_diff
else:
sigma_ij = sorting_analyzer.get_extension("template_similarity").get_data()
rho_ij, eta_ij = compute_xcorr_and_rp(sorting_analyzer, pair_mask)

M_ij = sigma_ij + k1 * rho_ij - k2 * eta_ij

return M_ij


def compute_xcorr_and_rp(sorting_analyzer: SortingAnalyzer, pair_mask: np.ndarray):
"""
Computes a cross-correlation significance measure and a sliding refractory period violation
measure for all units in the `sorting_analyzer`.

Parameters
----------
sorting_analyzer : SortingAnalyzer
The sorting analyzer object containing the spike sorting data
pair_mask : np.ndarray
A bool matrix describing which pairs are possible merges based on previous steps
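
Returns
-------
rho_ij : np.ndarray
Matrix of cross-correlation significance values for each unit pair.
eta_ij : np.ndarray
Matrix of sliding refractory period violation values for each unit pair.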
"""

correlograms_extension = sorting_analyzer.get_extension("correlograms")
ccgs, _ = correlograms_extension.get_data()

# bin size in ms; converted to seconds where the SLAy helper functions need it
bin_size_ms = correlograms_extension.params["bin_ms"]

rho_ij = np.zeros([len(sorting_analyzer.unit_ids), len(sorting_analyzer.unit_ids)])
eta_ij = np.zeros([len(sorting_analyzer.unit_ids), len(sorting_analyzer.unit_ids)])

for unit_index_1, _ in enumerate(sorting_analyzer.unit_ids):
for unit_index_2, _ in enumerate(sorting_analyzer.unit_ids):

# Don't waste time computing the metrics if the units are not candidate merges
if not pair_mask[unit_index_1, unit_index_2]:
continue

xgram = ccgs[unit_index_1, unit_index_2, :]

rho_ij[unit_index_1, unit_index_2] = _compute_xcorr_pair(
xgram, bin_size_s=bin_size_ms / 1000, min_xcorr_rate=0
)
eta_ij[unit_index_1, unit_index_2] = _sliding_RP_viol_pair(xgram, bin_size_ms=bin_size_ms)

return rho_ij, eta_ij


def _compute_xcorr_pair(
xgram,
bin_size_s: float,
min_xcorr_rate: float,
) -> float:
"""
Calculates a cross-correlation significance metric for a cluster pair.

Uses the Wasserstein distance between an observed cross-correlogram and a
null (uniform) distribution as an estimate of how significant the dependence
between two neurons is. Low-spike-count cross-correlograms have large
Wasserstein distances from null by chance, so we first try to expand the
window size. If that fails to yield enough spikes, we apply a penalty to the metric.

Ported from https://github.com/saikoukunt/SLAy.

Parameters
----------
xgram : np.array
The raw cross-correlogram for the cluster pair.
bin_size_s : float
The bin width of the input cross-correlogram, in seconds.
min_xcorr_rate : float
The minimum ccg firing rate in Hz.

Returns
-------
sig : float
The calculated cross-correlation significance metric.
"""

from scipy.signal import butter, find_peaks_cwt, sosfiltfilt
from scipy.stats import wasserstein_distance

# calculate low-pass filtered second derivative of ccg
fs = 1 / bin_size_s
cutoff_freq = 100
nyquist = fs / 2
cutoff = cutoff_freq / nyquist
peak_width = 0.002 / bin_size_s

xgram_2d = np.diff(xgram, 2)
sos = butter(4, cutoff, output="sos")
xgram_2d = sosfiltfilt(sos, xgram_2d)

if xgram.sum() == 0:
return 0

# find negative peaks of the second derivative of the ccg; these are the edges of dips in the ccg
peaks = find_peaks_cwt(-xgram_2d, peak_width, noise_perc=90) + 1
# if no peaks are found, return a very low significance
if peaks.shape[0] == 0:
return -4
peaks = np.abs(peaks - xgram.shape[0] / 2)
peaks = peaks[peaks > 0.5 * peak_width]
min_peaks = np.sort(peaks)

# start with peaks closest to 0 and move to the next set of peaks if the event count is too low
window_width = min_peaks * 1.5
starts = np.maximum(xgram.shape[0] / 2 - window_width, 0)
ends = np.minimum(xgram.shape[0] / 2 + window_width, xgram.shape[0] - 1)
ind = 0
xgram_window = xgram[int(starts[0]) : int(ends[0] + 1)]
xgram_sum = xgram_window.sum()
window_size = xgram_window.shape[0] * bin_size_s
while (xgram_sum < (min_xcorr_rate * window_size * 10)) and (ind < starts.shape[0]):
xgram_window = xgram[int(starts[ind]) : int(ends[ind] + 1)]
xgram_sum = xgram_window.sum()
window_size = xgram_window.shape[0] * bin_size_s
ind += 1
# use the whole ccg if peak finding fails
if ind == starts.shape[0]:
xgram_window = xgram

# TODO: was getting error messages when xgram_window was all zero. Why was this happening?
if np.abs(xgram_window).sum() == 0:
return 0

sig = (
wasserstein_distance(
np.arange(xgram_window.shape[0]) / xgram_window.shape[0],
np.arange(xgram_window.shape[0]) / xgram_window.shape[0],
xgram_window,
np.ones_like(xgram_window),
)
* 4
)

if xgram_window.sum() < (min_xcorr_rate * window_size):
sig *= (xgram_window.sum() / (min_xcorr_rate * window_size)) ** 2

# if sig < 0.04 and xgram_window.sum() < (min_xcorr_rate * window_size):
if xgram_window.sum() < (min_xcorr_rate / 4 * window_size):
sig = -4 # don't merge if the event count is way too low

return sig


def _sliding_RP_viol_pair(
correlogram,
bin_size_ms: float,
accept_threshold: float = 0.15,
) -> float:
"""
Calculate the sliding refractory period violation confidence for a cluster.

Ported from https://github.com/saikoukunt/SLAy.

Parameters
----------
correlogram : np.array
The auto-correlogram of the cluster.
bin_size_ms : float
The bin width of the input auto-correlogram, in milliseconds.
accept_threshold : float, default: 0.15
The acceptable level of refractory period contamination, expressed as a fraction of the baseline (maximum) correlogram rate.

Returns
-------
rp_viol : float
The refractory period violation confidence for the cluster.
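
Examples
--------
A minimal sketch using a synthetic auto-correlogram whose central bins are
empty (i.e. a clean refractory period):

>>> import numpy as np
>>> acg = np.full(100, 20.0)
>>> acg[45:55] = 0.0  # empty bins around zero lag
>>> rp_viol = _sliding_RP_viol_pair(acg, bin_size_ms=1.0)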
"""
from scipy.signal import butter, sosfiltfilt
from scipy.stats import poisson

# create various refractory period sizes to test (between 0 and 20x bin size)
all_refractory_periods = np.arange(0, 21 * bin_size_ms, bin_size_ms) / 1000
test_refractory_period_indices = np.array([1, 2, 4, 6, 8, 12, 16, 20], dtype="int8")
test_refractory_periods = [
all_refractory_periods[test_rp_index] for test_rp_index in test_refractory_period_indices
]

# average the two halves of the acg to ensure symmetry
# keep only the second half of the acg; refractory period violations are measured from its center
half_len = int(correlogram.shape[0] / 2)
correlogram = (correlogram[half_len:] + correlogram[:half_len][::-1]) / 2

acg_cumsum = np.cumsum(correlogram)
sum_res = acg_cumsum[test_refractory_period_indices - 1]  # -1 because the 0th bin corresponds to 0 to bin_size ms

# low-pass filter acg and use max as baseline event rate
order = 4  # filter order
cutoff_freq = 250 # Hz
fs = 1 / bin_size_ms * 1000
nyquist = fs / 2
cutoff = cutoff_freq / nyquist
sos = butter(order, cutoff, btype="low", output="sos")
smoothed_acg = sosfiltfilt(sos, correlogram)

bin_rate_max = np.max(smoothed_acg)
max_conts_max = np.array(test_refractory_periods) / bin_size_ms * 1000 * (bin_rate_max * accept_threshold)
# compute confidence of less than accept_threshold contamination at each refractory period
confs = 1 - poisson.cdf(sum_res, max_conts_max)
rp_viol = 1 - confs.max()

return rp_viol