From 6441f50fe53ab79a8411f5564d555f35aefb7c76 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Fri, 31 Oct 2025 09:07:55 +0000 Subject: [PATCH 1/7] add initial slay structure in compute_merge_unit_groups --- src/spikeinterface/curation/auto_merge.py | 29 +++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 2d078c4d28..e30149e602 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -52,6 +52,10 @@ "knn", "quality_score", ], + "slay": [ + "template_similarity", + "slay_score", + ], } _required_extensions = { @@ -60,6 +64,7 @@ "snr": ["templates", "noise_levels"], "template_similarity": ["templates", "template_similarity"], "knn": ["templates", "spike_locations", "spike_amplitudes"], + "slay_score": ["correlograms"], } @@ -84,6 +89,7 @@ "censored_period_ms": 0.3, }, "quality_score": {"firing_contamination_balance": 1.5, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, + "slay_score": {"k1": 0.25, "k2": 1, "slay_threshold": 0.5}, } @@ -356,6 +362,16 @@ def compute_merge_unit_groups( ) outs["pairs_decreased_score"] = pairs_decreased_score + elif step == "slay_score": + + M_ij = compute_slay_matrix( + sorting_analyzer, + params["k1"], + params["k2"], + ) + + pair_mask = M_ij > params["slay_threshold"] + # FINAL STEP : create the final list from pair_mask boolean matrix ind1, ind2 = np.nonzero(pair_mask) merge_unit_groups = list(zip(unit_ids[ind1], unit_ids[ind2])) @@ -1506,3 +1522,16 @@ def estimate_cross_contamination( ) return estimation, p_value + + +def compute_slay_matrix(sorting_analyzer: SortingAnalyzer, k1: float, k2: float): + + from numpy import random + + sigma_ij = random.rand(len(sorting_analyzer.unit_ids), len(sorting_analyzer.unit_ids)) + rho_ij = random.rand(len(sorting_analyzer.unit_ids), len(sorting_analyzer.unit_ids)) + eta_ij = random.rand(len(sorting_analyzer.unit_ids), len(sorting_analyzer.unit_ids)) + + M_ij = sigma_ij + k1 * rho_ij - k2 * eta_ij + + return M_ij From 9ff3ef01d94fb93186179bf07307a45141e35b19 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Fri, 31 Oct 2025 10:29:43 +0000 Subject: [PATCH 2/7] add all scores --- src/spikeinterface/curation/auto_merge.py | 186 ++++++++++++++++++++-- 1 file changed, 174 insertions(+), 12 deletions(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index e30149e602..b533c42126 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -89,7 +89,7 @@ "censored_period_ms": 0.3, }, "quality_score": {"firing_contamination_balance": 1.5, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, - "slay_score": {"k1": 0.25, "k2": 1, "slay_threshold": 0.5}, + "slay_score": {"k1": 0.25, "k2": 1, "slay_threshold": 0.5, "template_diff_thresh": 0.25}, } @@ -364,11 +364,7 @@ def compute_merge_unit_groups( elif step == "slay_score": - M_ij = compute_slay_matrix( - sorting_analyzer, - params["k1"], - params["k2"], - ) + M_ij = compute_slay_matrix(sorting_analyzer, params["k1"], params["k2"], params["template_diff_thresh"]) pair_mask = M_ij > params["slay_threshold"] @@ -1524,14 +1520,180 @@ def estimate_cross_contamination( return estimation, p_value -def compute_slay_matrix(sorting_analyzer: SortingAnalyzer, k1: float, k2: float): +def compute_slay_matrix(sorting_analyzer: SortingAnalyzer, k1: float, k2: float, template_diff_thresh: float): - from numpy import random - - 
sigma_ij = random.rand(len(sorting_analyzer.unit_ids), len(sorting_analyzer.unit_ids)) - rho_ij = random.rand(len(sorting_analyzer.unit_ids), len(sorting_analyzer.unit_ids)) - eta_ij = random.rand(len(sorting_analyzer.unit_ids), len(sorting_analyzer.unit_ids)) + sigma_ij = sorting_analyzer.get_extension("template_similarity").get_data() + rho_ij, eta_ij = compute_xcorr_and_rp(sorting_analyzer, template_diff_thresh) M_ij = sigma_ij + k1 * rho_ij - k2 * eta_ij return M_ij + + +def compute_xcorr_and_rp(sorting_analyzer, template_diff_thresh): + + correlograms_extension = sorting_analyzer.get_extension("correlograms") + template_similarity = sorting_analyzer.get_extension("template_similarity").get_data() + + ccgs, _ = correlograms_extension.get_data() + xcorr_bin_width = correlograms_extension.params["bin_ms"] / 1000 + + rho_ij = np.zeros([len(sorting_analyzer.unit_ids), len(sorting_analyzer.unit_ids)]) + eta_ij = np.zeros([len(sorting_analyzer.unit_ids), len(sorting_analyzer.unit_ids)]) + + for unit_index_1, _ in enumerate(sorting_analyzer.unit_ids): + for unit_index_2, _ in enumerate(sorting_analyzer.unit_ids): + + # Don't waste time computing the other metrics if we fail the template similarity check + if template_similarity[unit_index_1, unit_index_2] < 1 - template_diff_thresh: + continue + + xgram = ccgs[unit_index_1, unit_index_2, :] + + rho_ij[unit_index_1, unit_index_2] = _compute_xcorr_pair( + xgram, xcorr_bin_width=xcorr_bin_width, min_xcorr_rate=0 + ) + eta_ij[unit_index_1, unit_index_2] = _sliding_RP_viol_pair(xgram, bin_size=xcorr_bin_width) + + return rho_ij, eta_ij + + +def _compute_xcorr_pair( + xgram, + xcorr_bin_width: float, + min_xcorr_rate: float, +) -> float: + """ + Calculates a cross-correlation significance metric for a cluster pair. + + Uses the wasserstein distance between an observed cross-correlogram and a null + distribution as an estimate of how significant the dependence between + two neurons is. Low spike count cross-correlograms have large wasserstein + distances from null by chance, so we first try to expand the window size. If + that fails to yield enough spikes, we apply a penalty to the metric. + + Args: + xgram (NDArray): The raw cross-correlogram for the cluster pair. + xcorr_bin_width (float): The width in seconds of the bin size of the + input ccgs. + max_window (float): The largest allowed window size during window + expansion. + min_xcorr_rate (float): The minimum ccg firing rate in Hz. + + Returns: + sig (float): The calculated cross-correlation significance metric. 
+ """ + + from scipy.signal import butter, find_peaks_cwt, sosfiltfilt + from scipy.stats import wasserstein_distance + + # calculate low-pass filtered second derivative of ccg + fs = 1 / xcorr_bin_width + cutoff_freq = 100 + nyqist = fs / 2 + cutoff = cutoff_freq / nyqist + peak_width = 0.002 / xcorr_bin_width + + xgram_2d = np.diff(xgram, 2) + sos = butter(4, cutoff, output="sos") + xgram_2d = sosfiltfilt(sos, xgram_2d) + + if xgram.sum() == 0: + return 0 + + # find negative peaks of second derivative of ccg, these are the edges of dips in ccg + peaks = find_peaks_cwt(-xgram_2d, peak_width, noise_perc=90) + 1 + # if no peaks are found, return a very low significance + if peaks.shape[0] == 0: + return -4 + peaks = np.abs(peaks - xgram.shape[0] / 2) + peaks = peaks[peaks > 0.5 * peak_width] + min_peaks = np.sort(peaks) + + # start with peaks closest to 0 and move to the next set of peaks if the event count is too low + window_width = min_peaks * 1.5 + starts = np.maximum(xgram.shape[0] / 2 - window_width, 0) + ends = np.minimum(xgram.shape[0] / 2 + window_width, xgram.shape[0] - 1) + ind = 0 + xgram_window = xgram[int(starts[0]) : int(ends[0] + 1)] + xgram_sum = xgram_window.sum() + window_size = xgram_window.shape[0] * xcorr_bin_width + while (xgram_sum < (min_xcorr_rate * window_size * 10)) and (ind < starts.shape[0]): + xgram_window = xgram[int(starts[ind]) : int(ends[ind] + 1)] + xgram_sum = xgram_window.sum() + window_size = xgram_window.shape[0] * xcorr_bin_width + ind += 1 + # use the whole ccg if peak finding fails + if ind == starts.shape[0]: + xgram_window = xgram + + # TODO: was getting error messges when xgram_window was all zero. Why was this happening? + if np.abs(xgram_window).sum() == 0: + return 0 + + sig = ( + wasserstein_distance( + np.arange(xgram_window.shape[0]) / xgram_window.shape[0], + np.arange(xgram_window.shape[0]) / xgram_window.shape[0], + xgram_window, + np.ones_like(xgram_window), + ) + * 4 + ) + + if xgram_window.sum() < (min_xcorr_rate * window_size): + sig *= (xgram_window.sum() / (min_xcorr_rate * window_size)) ** 2 + + # if sig < 0.04 and xgram_window.sum() < (min_xcorr_rate * window_size): + if xgram_window.sum() < (min_xcorr_rate / 4 * window_size): + sig = -4 # don't merge if the event count is way too low + + return sig + + +def _sliding_RP_viol_pair( + correlogram, + bin_size: float = 1, + acceptThresh: float = 0.15, +) -> float: + """ + Calculate the sliding refractory period violation confidence for a cluster. + Args: + correlogram (NDArray): The auto-correlogram of the cluster. + bin_size (float, optional): The size of each bin in ms. Defaults to 0.25. + acceptThresh (float, optional): The threshold for accepting refractory period violations. Defaults to 0.1. + Returns: + float: The refractory period violation confidence for the cluster. 
+ """ + from scipy.signal import butter, sosfiltfilt + from scipy.stats import poisson + + # create various refractory periods sizes to test (between 0 and 20x bin size) + b = np.arange(0, 21 * bin_size, bin_size) / 1000 + bTestIdx = np.array([1, 2, 4, 6, 8, 12, 16, 20], dtype="int8") + bTest = [b[i] for i in bTestIdx] + + # calculate and avg halves of acg to ensure symmetry + # keep only second half of acg, refractory period violations are compared from the center of acg + half_len = int(correlogram.shape[0] / 2) + correlogram = (correlogram[half_len:] + correlogram[:half_len][::-1]) / 2 + + acg_cumsum = np.cumsum(correlogram) + sum_res = acg_cumsum[bTestIdx - 1] # -1 bc 0th bin corresponds to 0-bin_size ms + + # low-pass filter acg and use max as baseline event rate + order = 4 # Hz + cutoff_freq = 250 # Hz + fs = 1 / bin_size * 1000 + nyqist = fs / 2 + cutoff = cutoff_freq / nyqist + sos = butter(order, cutoff, btype="low", output="sos") + smoothed_acg = sosfiltfilt(sos, correlogram) + + bin_rate_max = np.max(smoothed_acg) + max_conts_max = np.array(bTest) / bin_size * 1000 * (bin_rate_max * acceptThresh) + # compute confidence of less than acceptThresh contamination at each refractory period + confs = 1 - poisson.cdf(sum_res, max_conts_max) + rp_viol = 1 - confs.max() + + return rp_viol From 8ed4d811159bdbdcddea66ba69b55989933f8531 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Fri, 31 Oct 2025 10:45:58 +0000 Subject: [PATCH 3/7] add docs --- doc/references.rst | 7 +- src/spikeinterface/curation/auto_merge.py | 95 +++++++++++++++++------ 2 files changed, 79 insertions(+), 23 deletions(-) diff --git a/doc/references.rst b/doc/references.rst index 1179afa509..59aa590093 100644 --- a/doc/references.rst +++ b/doc/references.rst @@ -88,7 +88,10 @@ important for your research: Curation Module --------------- -If you use the :code:`get_potential_auto_merge` method from the curation module, please cite [Llobet]_ + +If you use the default "similarity_correlograms" preset in the :code:`compute_merge_unit_groups` method from the curation module, please cite [Llobet]_ + +If you use the "slay" preset in the :code:`compute_merge_unit_groups` method, please cite [Koukuntla]_ If you use :code:`auto_label_units` or :code:`train_model`, please cite [Jain]_ @@ -140,6 +143,8 @@ References .. [Jia] `High-density extracellular probes reveal dendritic backpropagation and facilitate neuron classification. 2019 `_ +.. [Koukuntla] `SLAy-ing oversplitting errors in high-density electrophysiology spike sorting. 2025. `_ + .. [Lee] `YASS: Yet another spike sorter. 2017. `_ .. [Lemon] Methods for neuronal recording in conscious animals. IBRO Handbook Series. 1984. diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index b533c42126..65660fa6b8 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -1521,6 +1521,31 @@ def estimate_cross_contamination( def compute_slay_matrix(sorting_analyzer: SortingAnalyzer, k1: float, k2: float, template_diff_thresh: float): + """ + Computes the "merge decision metric" from the SLAy method, made from combining + a template similarity measure, a cross-correlation significance measure and a + sliding refractory period violation measure. A large M suggests that two + units should be merged. 
+
+    Parameters
+    ---------
+    sorting_analyzer : SortingAnalyzer
+        The sorting analyzer object containing the spike sorting data
+    k1 : float
+        Coefficient determining the importance of the cross-correlation significance
+    k2 : float
+        Coefficient determining the importance of the sliding rp violation
+    template_diff_thresh : float
+        Threshold for how different template similarities can be to be considered for merging
+
+
+    References
+    ----------
+    Based on computation originally implemented in SLAy [Koukuntla]_.
+
+    Implementation is based on one of the original implementations written by Sai Koukuntla,
+    found at https://github.com/saikoukunt/SLAy.
+    """
 
     sigma_ij = sorting_analyzer.get_extension("template_similarity").get_data()
     rho_ij, eta_ij = compute_xcorr_and_rp(sorting_analyzer, template_diff_thresh)
@@ -1530,13 +1555,26 @@ def compute_slay_matrix(sorting_analyzer: SortingAnalyzer, k1: float, k2: float,
     return M_ij
 
 
-def compute_xcorr_and_rp(sorting_analyzer, template_diff_thresh):
+def compute_xcorr_and_rp(sorting_analyzer: SortingAnalyzer, template_diff_thresh: float):
+    """
+    Computes a cross-correlation significance measure and a sliding refractory period violation
+    measure for all units in the `sorting_analyzer`.
+
+    Parameters
+    ---------
+    sorting_analyzer : SortingAnalyzer
+        The sorting analyzer object containing the spike sorting data
+    template_diff_thresh : float
+        Threshold for how different template similarities can be to be considered for merging
+    """
 
     correlograms_extension = sorting_analyzer.get_extension("correlograms")
     template_similarity = sorting_analyzer.get_extension("template_similarity").get_data()
 
     ccgs, _ = correlograms_extension.get_data()
-    xcorr_bin_width = correlograms_extension.params["bin_ms"] / 1000
+
+    # convert to seconds for SLAy functions
+    bin_s = correlograms_extension.params["bin_ms"] / 1000
 
     rho_ij = np.zeros([len(sorting_analyzer.unit_ids), len(sorting_analyzer.unit_ids)])
     eta_ij = np.zeros([len(sorting_analyzer.unit_ids), len(sorting_analyzer.unit_ids)])
@@ -1550,10 +1588,8 @@ def compute_xcorr_and_rp(sorting_analyzer, template_diff_thresh):
 
         xgram = ccgs[unit_index_1, unit_index_2, :]
 
-        rho_ij[unit_index_1, unit_index_2] = _compute_xcorr_pair(
-            xgram, xcorr_bin_width=xcorr_bin_width, min_xcorr_rate=0
-        )
-        eta_ij[unit_index_1, unit_index_2] = _sliding_RP_viol_pair(xgram, bin_size=xcorr_bin_width)
+        rho_ij[unit_index_1, unit_index_2] = _compute_xcorr_pair(xgram, xcorr_bin_width=bin_s, min_xcorr_rate=0)
+        eta_ij[unit_index_1, unit_index_2] = _sliding_RP_viol_pair(xgram, bin_size=bin_s)
 
     return rho_ij, eta_ij
 
@@ -1572,16 +1608,21 @@ def _compute_xcorr_pair(
     distances from null by chance, so we first try to expand the window size. If
     that fails to yield enough spikes, we apply a penalty to the metric.
 
-    Args:
-        xgram (NDArray): The raw cross-correlogram for the cluster pair.
-        xcorr_bin_width (float): The width in seconds of the bin size of the
-            input ccgs.
-        max_window (float): The largest allowed window size during window
-            expansion.
-        min_xcorr_rate (float): The minimum ccg firing rate in Hz.
+    Ported from https://github.com/saikoukunt/SLAy.
+
+    Parameters
+    ----------
+    xgram : np.array
+        The raw cross-correlogram for the cluster pair.
+    xcorr_bin_width : float
+        The width in seconds of the bin size of the input ccgs.
+    min_xcorr_rate : float
+        The minimum ccg firing rate in Hz.
 
-    Returns:
-        sig (float): The calculated cross-correlation significance metric.
+ Returns + ------- + sig : float + The calculated cross-correlation significance metric. """ from scipy.signal import butter, find_peaks_cwt, sosfiltfilt @@ -1653,17 +1694,27 @@ def _compute_xcorr_pair( def _sliding_RP_viol_pair( correlogram, - bin_size: float = 1, + bin_size: float, acceptThresh: float = 0.15, ) -> float: """ Calculate the sliding refractory period violation confidence for a cluster. - Args: - correlogram (NDArray): The auto-correlogram of the cluster. - bin_size (float, optional): The size of each bin in ms. Defaults to 0.25. - acceptThresh (float, optional): The threshold for accepting refractory period violations. Defaults to 0.1. - Returns: - float: The refractory period violation confidence for the cluster. + + Ported from https://github.com/saikoukunt/SLAy. + + Parameters + ---------- + correlogram : np.array + The auto-correlogram of the cluster. + bin_size : float + The width in ms of the bin size of the input ccgs. + acceptThresh : float, default: 0.15 + The minimum ccg firing rate in Hz. + + Returns + ------- + sig : float + The refractory period violation confidence for the cluster. """ from scipy.signal import butter, sosfiltfilt from scipy.stats import poisson From 45656ce94e735b969b69d726d02a274fc1650b79 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Fri, 31 Oct 2025 11:02:51 +0000 Subject: [PATCH 4/7] add template_similarity to slay_score requirements --- src/spikeinterface/curation/auto_merge.py | 32 ++++++++++++----------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 65660fa6b8..ba7a744524 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -64,7 +64,7 @@ "snr": ["templates", "noise_levels"], "template_similarity": ["templates", "template_similarity"], "knn": ["templates", "spike_locations", "spike_amplitudes"], - "slay_score": ["correlograms"], + "slay_score": ["correlograms", "template_similarity"], } @@ -1574,7 +1574,7 @@ def compute_xcorr_and_rp(sorting_analyzer: SortingAnalyzer, template_diff_thresh ccgs, _ = correlograms_extension.get_data() # convert to seconds for SLAy functions - bin_s = correlograms_extension.params["bin_ms"] / 1000 + bin_size_ms = correlograms_extension.params["bin_ms"] rho_ij = np.zeros([len(sorting_analyzer.unit_ids), len(sorting_analyzer.unit_ids)]) eta_ij = np.zeros([len(sorting_analyzer.unit_ids), len(sorting_analyzer.unit_ids)]) @@ -1588,15 +1588,17 @@ def compute_xcorr_and_rp(sorting_analyzer: SortingAnalyzer, template_diff_thresh xgram = ccgs[unit_index_1, unit_index_2, :] - rho_ij[unit_index_1, unit_index_2] = _compute_xcorr_pair(xgram, xcorr_bin_width=bin_s, min_xcorr_rate=0) - eta_ij[unit_index_1, unit_index_2] = _sliding_RP_viol_pair(xgram, bin_size=bin_s) + rho_ij[unit_index_1, unit_index_2] = _compute_xcorr_pair( + xgram, bin_size_s=bin_size_ms / 1000, min_xcorr_rate=0 + ) + eta_ij[unit_index_1, unit_index_2] = _sliding_RP_viol_pair(xgram, bin_size_ms=bin_size_ms) return rho_ij, eta_ij def _compute_xcorr_pair( xgram, - xcorr_bin_width: float, + bin_size_s: float, min_xcorr_rate: float, ) -> float: """ @@ -1614,7 +1616,7 @@ def _compute_xcorr_pair( ---------- xgram : np.array The raw cross-correlogram for the cluster pair. - xcorr_bin_width : float + bin_size_s : float The width in seconds of the bin size of the input ccgs. min_xcorr_rate : float The minimum ccg firing rate in Hz. 
@@ -1629,11 +1631,11 @@ def _compute_xcorr_pair( from scipy.stats import wasserstein_distance # calculate low-pass filtered second derivative of ccg - fs = 1 / xcorr_bin_width + fs = 1 / bin_size_s cutoff_freq = 100 nyqist = fs / 2 cutoff = cutoff_freq / nyqist - peak_width = 0.002 / xcorr_bin_width + peak_width = 0.002 / bin_size_s xgram_2d = np.diff(xgram, 2) sos = butter(4, cutoff, output="sos") @@ -1658,11 +1660,11 @@ def _compute_xcorr_pair( ind = 0 xgram_window = xgram[int(starts[0]) : int(ends[0] + 1)] xgram_sum = xgram_window.sum() - window_size = xgram_window.shape[0] * xcorr_bin_width + window_size = xgram_window.shape[0] * bin_size_s while (xgram_sum < (min_xcorr_rate * window_size * 10)) and (ind < starts.shape[0]): xgram_window = xgram[int(starts[ind]) : int(ends[ind] + 1)] xgram_sum = xgram_window.sum() - window_size = xgram_window.shape[0] * xcorr_bin_width + window_size = xgram_window.shape[0] * bin_size_s ind += 1 # use the whole ccg if peak finding fails if ind == starts.shape[0]: @@ -1694,7 +1696,7 @@ def _compute_xcorr_pair( def _sliding_RP_viol_pair( correlogram, - bin_size: float, + bin_size_ms: float, acceptThresh: float = 0.15, ) -> float: """ @@ -1706,7 +1708,7 @@ def _sliding_RP_viol_pair( ---------- correlogram : np.array The auto-correlogram of the cluster. - bin_size : float + bin_size_ms : float The width in ms of the bin size of the input ccgs. acceptThresh : float, default: 0.15 The minimum ccg firing rate in Hz. @@ -1720,7 +1722,7 @@ def _sliding_RP_viol_pair( from scipy.stats import poisson # create various refractory periods sizes to test (between 0 and 20x bin size) - b = np.arange(0, 21 * bin_size, bin_size) / 1000 + b = np.arange(0, 21 * bin_size_ms, bin_size_ms) / 1000 bTestIdx = np.array([1, 2, 4, 6, 8, 12, 16, 20], dtype="int8") bTest = [b[i] for i in bTestIdx] @@ -1735,14 +1737,14 @@ def _sliding_RP_viol_pair( # low-pass filter acg and use max as baseline event rate order = 4 # Hz cutoff_freq = 250 # Hz - fs = 1 / bin_size * 1000 + fs = 1 / bin_size_ms * 1000 nyqist = fs / 2 cutoff = cutoff_freq / nyqist sos = butter(order, cutoff, btype="low", output="sos") smoothed_acg = sosfiltfilt(sos, correlogram) bin_rate_max = np.max(smoothed_acg) - max_conts_max = np.array(bTest) / bin_size * 1000 * (bin_rate_max * acceptThresh) + max_conts_max = np.array(bTest) / bin_size_ms * 1000 * (bin_rate_max * acceptThresh) # compute confidence of less than acceptThresh contamination at each refractory period confs = 1 - poisson.cdf(sum_res, max_conts_max) rp_viol = 1 - confs.max() From 3994547eb30382afea208c84190357b6cdf00935 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Fri, 31 Oct 2025 12:55:33 +0000 Subject: [PATCH 5/7] respond to alessio --- src/spikeinterface/curation/auto_merge.py | 45 +++++++++++++---------- 1 file changed, 25 insertions(+), 20 deletions(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index ba7a744524..a1b7beffb4 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -364,9 +364,9 @@ def compute_merge_unit_groups( elif step == "slay_score": - M_ij = compute_slay_matrix(sorting_analyzer, params["k1"], params["k2"], params["template_diff_thresh"]) + M_ij = compute_slay_matrix(sorting_analyzer, params["k1"], params["k2"], pair_mask) - pair_mask = M_ij > params["slay_threshold"] + pair_mask = pair_mask & (M_ij > params["slay_threshold"]) # FINAL STEP : create the final list from pair_mask boolean matrix ind1, ind2 = np.nonzero(pair_mask) 
@@ -1520,7 +1520,7 @@ def estimate_cross_contamination( return estimation, p_value -def compute_slay_matrix(sorting_analyzer: SortingAnalyzer, k1: float, k2: float, template_diff_thresh: float): +def compute_slay_matrix(sorting_analyzer: SortingAnalyzer, k1: float, k2: float, pair_mask=None): """ Computes the "merge decision metric" from the SLAy method, made from combining a template similarity measure, a cross-correlation significance measure and a @@ -1535,8 +1535,8 @@ def compute_slay_matrix(sorting_analyzer: SortingAnalyzer, k1: float, k2: float, Coefficient determining the importance of the cross-correlation significance k2 : float Coefficient determining the importance of the sliding rp violation - template_diff_thresh : float - Threshold for how different template similarities can be to be considered for merging + pair_mask : None | np.ndarray, default: None + A bool matrix describing which pairs are possible merges based on previous steps References @@ -1547,15 +1547,20 @@ def compute_slay_matrix(sorting_analyzer: SortingAnalyzer, k1: float, k2: float, found at https://github.com/saikoukunt/SLAy. """ + num_units = sorting_analyzer.get_num_units() + + if pair_mask is None: + pair_mask = np.ones((num_units, num_units), dtype="bool") + sigma_ij = sorting_analyzer.get_extension("template_similarity").get_data() - rho_ij, eta_ij = compute_xcorr_and_rp(sorting_analyzer, template_diff_thresh) + rho_ij, eta_ij = compute_xcorr_and_rp(sorting_analyzer, pair_mask) M_ij = sigma_ij + k1 * rho_ij - k2 * eta_ij return M_ij -def compute_xcorr_and_rp(sorting_analyzer: SortingAnalyzer, template_diff_thresh: float): +def compute_xcorr_and_rp(sorting_analyzer: SortingAnalyzer, pair_mask: np.ndarray): """ Computes a cross-correlation significance measure and a sliding refractory period violation measure for all units in the `sorting_analyzer`. @@ -1564,13 +1569,11 @@ def compute_xcorr_and_rp(sorting_analyzer: SortingAnalyzer, template_diff_thresh --------- sorting_analyzer : SortingAnalyzer The sorting analyzer object containing the spike sorting data - template_diff_thresh : float - Threshold for how different template similarities can be to be considered for merging + pair_mask : np.ndarray + A bool matrix describing which pairs are possible merges based on previous steps """ correlograms_extension = sorting_analyzer.get_extension("correlograms") - template_similarity = sorting_analyzer.get_extension("template_similarity").get_data() - ccgs, _ = correlograms_extension.get_data() # convert to seconds for SLAy functions @@ -1582,8 +1585,8 @@ def compute_xcorr_and_rp(sorting_analyzer: SortingAnalyzer, template_diff_thresh for unit_index_1, _ in enumerate(sorting_analyzer.unit_ids): for unit_index_2, _ in enumerate(sorting_analyzer.unit_ids): - # Don't waste time computing the other metrics if we fail the template similarity check - if template_similarity[unit_index_1, unit_index_2] < 1 - template_diff_thresh: + # Don't waste time computing the other metrics if units not candidates merges + if not pair_mask[unit_index_1, unit_index_2]: continue xgram = ccgs[unit_index_1, unit_index_2, :] @@ -1697,7 +1700,7 @@ def _compute_xcorr_pair( def _sliding_RP_viol_pair( correlogram, bin_size_ms: float, - acceptThresh: float = 0.15, + accept_threshold: float = 0.15, ) -> float: """ Calculate the sliding refractory period violation confidence for a cluster. @@ -1710,7 +1713,7 @@ def _sliding_RP_viol_pair( The auto-correlogram of the cluster. bin_size_ms : float The width in ms of the bin size of the input ccgs. 
- acceptThresh : float, default: 0.15 + accept_threshold : float, default: 0.15 The minimum ccg firing rate in Hz. Returns @@ -1722,9 +1725,11 @@ def _sliding_RP_viol_pair( from scipy.stats import poisson # create various refractory periods sizes to test (between 0 and 20x bin size) - b = np.arange(0, 21 * bin_size_ms, bin_size_ms) / 1000 - bTestIdx = np.array([1, 2, 4, 6, 8, 12, 16, 20], dtype="int8") - bTest = [b[i] for i in bTestIdx] + all_refractory_periods = np.arange(0, 21 * bin_size_ms, bin_size_ms) / 1000 + test_refractory_period_indices = np.array([1, 2, 4, 6, 8, 12, 16, 20], dtype="int8") + test_refractory_periods = [ + all_refractory_periods[test_rp_index] for test_rp_index in test_refractory_period_indices + ] # calculate and avg halves of acg to ensure symmetry # keep only second half of acg, refractory period violations are compared from the center of acg @@ -1732,7 +1737,7 @@ def _sliding_RP_viol_pair( correlogram = (correlogram[half_len:] + correlogram[:half_len][::-1]) / 2 acg_cumsum = np.cumsum(correlogram) - sum_res = acg_cumsum[bTestIdx - 1] # -1 bc 0th bin corresponds to 0-bin_size ms + sum_res = acg_cumsum[test_refractory_period_indices - 1] # -1 bc 0th bin corresponds to 0-bin_size ms # low-pass filter acg and use max as baseline event rate order = 4 # Hz @@ -1744,7 +1749,7 @@ def _sliding_RP_viol_pair( smoothed_acg = sosfiltfilt(sos, correlogram) bin_rate_max = np.max(smoothed_acg) - max_conts_max = np.array(bTest) / bin_size_ms * 1000 * (bin_rate_max * acceptThresh) + max_conts_max = np.array(test_refractory_periods) / bin_size_ms * 1000 * (bin_rate_max * accept_threshold) # compute confidence of less than acceptThresh contamination at each refractory period confs = 1 - poisson.cdf(sum_res, max_conts_max) rp_viol = 1 - confs.max() From 75718226520deba55f0bff3f2955c18ca83f0ee5 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Mon, 3 Nov 2025 08:38:34 +0000 Subject: [PATCH 6/7] make default pair mask triu --- src/spikeinterface/curation/auto_merge.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index a1b7beffb4..166fe85ac0 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -1550,7 +1550,7 @@ def compute_slay_matrix(sorting_analyzer: SortingAnalyzer, k1: float, k2: float, num_units = sorting_analyzer.get_num_units() if pair_mask is None: - pair_mask = np.ones((num_units, num_units), dtype="bool") + pair_mask = np.triu(np.arange(num_units), 1) > 0 sigma_ij = sorting_analyzer.get_extension("template_similarity").get_data() rho_ij, eta_ij = compute_xcorr_and_rp(sorting_analyzer, pair_mask) From 9c9dac5cc8a7114beb05744f76badb1e86f2c84e Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 6 Nov 2025 11:32:02 +0100 Subject: [PATCH 7/7] Propagate precomputed pairwise similarity to slay --- src/spikeinterface/curation/auto_merge.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 166fe85ac0..bc330dbb98 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -89,7 +89,7 @@ "censored_period_ms": 0.3, }, "quality_score": {"firing_contamination_balance": 1.5, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, - "slay_score": {"k1": 0.25, "k2": 1, "slay_threshold": 0.5, "template_diff_thresh": 0.25}, + "slay_score": {"k1": 0.25, 
"k2": 1, "slay_threshold": 0.5}, } @@ -364,7 +364,9 @@ def compute_merge_unit_groups( elif step == "slay_score": - M_ij = compute_slay_matrix(sorting_analyzer, params["k1"], params["k2"], pair_mask) + M_ij = compute_slay_matrix( + sorting_analyzer, params["k1"], params["k2"], templates_diff=outs["templates_diff"], pair_mask=pair_mask + ) pair_mask = pair_mask & (M_ij > params["slay_threshold"]) @@ -1520,7 +1522,13 @@ def estimate_cross_contamination( return estimation, p_value -def compute_slay_matrix(sorting_analyzer: SortingAnalyzer, k1: float, k2: float, pair_mask=None): +def compute_slay_matrix( + sorting_analyzer: SortingAnalyzer, + k1: float, + k2: float, + templates_diff: np.ndarray | None, + pair_mask: np.ndarray | None = None, +): """ Computes the "merge decision metric" from the SLAy method, made from combining a template similarity measure, a cross-correlation significance measure and a @@ -1535,6 +1543,8 @@ def compute_slay_matrix(sorting_analyzer: SortingAnalyzer, k1: float, k2: float, Coefficient determining the importance of the cross-correlation significance k2 : float Coefficient determining the importance of the sliding rp violation + templates_diff : np.ndarray | None + Pre-computed template similarity difference matrix. If None, it will be retrieved from the sorting_analyzer. pair_mask : None | np.ndarray, default: None A bool matrix describing which pairs are possible merges based on previous steps @@ -1552,7 +1562,10 @@ def compute_slay_matrix(sorting_analyzer: SortingAnalyzer, k1: float, k2: float, if pair_mask is None: pair_mask = np.triu(np.arange(num_units), 1) > 0 - sigma_ij = sorting_analyzer.get_extension("template_similarity").get_data() + if templates_diff is not None: + sigma_ij = 1 - templates_diff + else: + sigma_ij = sorting_analyzer.get_extension("template_similarity").get_data() rho_ij, eta_ij = compute_xcorr_and_rp(sorting_analyzer, pair_mask) M_ij = sigma_ij + k1 * rho_ij - k2 * eta_ij