diff --git a/doc/references.rst b/doc/references.rst
index 1179afa509..59aa590093 100644
--- a/doc/references.rst
+++ b/doc/references.rst
@@ -88,7 +88,10 @@ important for your research:
 Curation Module
 ---------------
-If you use the :code:`get_potential_auto_merge` method from the curation module, please cite [Llobet]_
+
+If you use the default "similarity_correlograms" preset in the :code:`compute_merge_unit_groups` method from the curation module, please cite [Llobet]_
+
+If you use the "slay" preset in the :code:`compute_merge_unit_groups` method, please cite [Koukuntla]_
 
 If you use :code:`auto_label_units` or :code:`train_model`, please cite [Jain]_
 
@@ -140,6 +143,8 @@ References
 
 .. [Jia] `High-density extracellular probes reveal dendritic backpropagation and facilitate neuron classification. 2019 `_
 
+.. [Koukuntla] `SLAy-ing oversplitting errors in high-density electrophysiology spike sorting. 2025. `_
+
 .. [Lee] `YASS: Yet another spike sorter. 2017. `_
 
 .. [Lemon] Methods for neuronal recording in conscious animals. IBRO Handbook Series. 1984.
diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py
index 2d078c4d28..bc330dbb98 100644
--- a/src/spikeinterface/curation/auto_merge.py
+++ b/src/spikeinterface/curation/auto_merge.py
@@ -52,6 +52,10 @@
         "knn",
         "quality_score",
     ],
+    "slay": [
+        "template_similarity",
+        "slay_score",
+    ],
 }
 
 _required_extensions = {
@@ -60,6 +64,7 @@
     "snr": ["templates", "noise_levels"],
     "template_similarity": ["templates", "template_similarity"],
     "knn": ["templates", "spike_locations", "spike_amplitudes"],
+    "slay_score": ["correlograms", "template_similarity"],
 }
 
 
@@ -84,6 +89,7 @@
         "censored_period_ms": 0.3,
     },
     "quality_score": {"firing_contamination_balance": 1.5, "refractory_period_ms": 1.0, "censored_period_ms": 0.3},
+    "slay_score": {"k1": 0.25, "k2": 1, "slay_threshold": 0.5},
 }
 
 
@@ -356,6 +362,14 @@ def compute_merge_unit_groups(
             )
             outs["pairs_decreased_score"] = pairs_decreased_score
 
+        elif step == "slay_score":
+
+            M_ij = compute_slay_matrix(
+                sorting_analyzer, params["k1"], params["k2"], templates_diff=outs["templates_diff"], pair_mask=pair_mask
+            )
+
+            pair_mask = pair_mask & (M_ij > params["slay_threshold"])
+
     # FINAL STEP : create the final list from pair_mask boolean matrix
     ind1, ind2 = np.nonzero(pair_mask)
     merge_unit_groups = list(zip(unit_ids[ind1], unit_ids[ind2]))
@@ -1506,3 +1520,251 @@ def estimate_cross_contamination(
     )
 
     return estimation, p_value
+
+
+def compute_slay_matrix(
+    sorting_analyzer: SortingAnalyzer,
+    k1: float,
+    k2: float,
+    templates_diff: np.ndarray | None,
+    pair_mask: np.ndarray | None = None,
+):
+    """
+    Computes the "merge decision metric" from the SLAy method, which combines a
+    template similarity measure, a cross-correlation significance measure and a
+    sliding refractory period violation measure. A large M_ij suggests that two
+    units should be merged.
+
+    Parameters
+    ----------
+    sorting_analyzer : SortingAnalyzer
+        The sorting analyzer object containing the spike sorting data
+    k1 : float
+        Coefficient determining the importance of the cross-correlation significance
+    k2 : float
+        Coefficient determining the importance of the sliding refractory period violation
+    templates_diff : np.ndarray | None
+        Pre-computed template similarity difference matrix. If None, the template similarity is retrieved from the sorting_analyzer.
+    pair_mask : None | np.ndarray, default: None
+        A boolean matrix describing which pairs are possible merges based on previous steps
+
+    References
+    ----------
+    Based on the computation originally implemented in SLAy [Koukuntla]_.
+
+    This implementation follows the original implementation by Sai Koukuntla,
+    found at https://github.com/saikoukunt/SLAy.
+    """
+
+    num_units = sorting_analyzer.get_num_units()
+
+    if pair_mask is None:
+        pair_mask = np.triu(np.arange(num_units), 1) > 0
+
+    if templates_diff is not None:
+        sigma_ij = 1 - templates_diff
+    else:
+        sigma_ij = sorting_analyzer.get_extension("template_similarity").get_data()
+
+    rho_ij, eta_ij = compute_xcorr_and_rp(sorting_analyzer, pair_mask)
+
+    M_ij = sigma_ij + k1 * rho_ij - k2 * eta_ij
+
+    return M_ij
+
+
+def compute_xcorr_and_rp(sorting_analyzer: SortingAnalyzer, pair_mask: np.ndarray):
+    """
+    Computes a cross-correlation significance measure and a sliding refractory period violation
+    measure for all units in the `sorting_analyzer`.
+
+    Parameters
+    ----------
+    sorting_analyzer : SortingAnalyzer
+        The sorting analyzer object containing the spike sorting data
+    pair_mask : np.ndarray
+        A boolean matrix describing which pairs are possible merges based on previous steps
+    """
+
+    correlograms_extension = sorting_analyzer.get_extension("correlograms")
+    ccgs, _ = correlograms_extension.get_data()
+
+    # bin size in ms; converted to seconds below where the SLAy functions expect seconds
+    bin_size_ms = correlograms_extension.params["bin_ms"]
+
+    rho_ij = np.zeros([len(sorting_analyzer.unit_ids), len(sorting_analyzer.unit_ids)])
+    eta_ij = np.zeros([len(sorting_analyzer.unit_ids), len(sorting_analyzer.unit_ids)])
+
+    for unit_index_1, _ in enumerate(sorting_analyzer.unit_ids):
+        for unit_index_2, _ in enumerate(sorting_analyzer.unit_ids):
+
+            # skip pairs that are not candidate merges based on the previous steps
+            if not pair_mask[unit_index_1, unit_index_2]:
+                continue
+
+            xgram = ccgs[unit_index_1, unit_index_2, :]
+
+            rho_ij[unit_index_1, unit_index_2] = _compute_xcorr_pair(
+                xgram, bin_size_s=bin_size_ms / 1000, min_xcorr_rate=0
+            )
+            eta_ij[unit_index_1, unit_index_2] = _sliding_RP_viol_pair(xgram, bin_size_ms=bin_size_ms)
+
+    return rho_ij, eta_ij
+
+
+def _compute_xcorr_pair(
+    xgram,
+    bin_size_s: float,
+    min_xcorr_rate: float,
+) -> float:
+    """
+    Calculates a cross-correlation significance metric for a cluster pair.
+
+    Uses the Wasserstein distance between an observed cross-correlogram and a null
+    distribution as an estimate of how significant the dependence between
+    two neurons is. Low spike count cross-correlograms have large Wasserstein
+    distances from null by chance, so we first try to expand the window size. If
+    that fails to yield enough spikes, we apply a penalty to the metric.
+
+    Ported from https://github.com/saikoukunt/SLAy.
+
+    Parameters
+    ----------
+    xgram : np.array
+        The raw cross-correlogram for the cluster pair.
+    bin_size_s : float
+        The bin size of the input correlograms, in seconds.
+    min_xcorr_rate : float
+        The minimum ccg firing rate in Hz.
+
+    Returns
+    -------
+    sig : float
+        The calculated cross-correlation significance metric.
+ """ + + from scipy.signal import butter, find_peaks_cwt, sosfiltfilt + from scipy.stats import wasserstein_distance + + # calculate low-pass filtered second derivative of ccg + fs = 1 / bin_size_s + cutoff_freq = 100 + nyqist = fs / 2 + cutoff = cutoff_freq / nyqist + peak_width = 0.002 / bin_size_s + + xgram_2d = np.diff(xgram, 2) + sos = butter(4, cutoff, output="sos") + xgram_2d = sosfiltfilt(sos, xgram_2d) + + if xgram.sum() == 0: + return 0 + + # find negative peaks of second derivative of ccg, these are the edges of dips in ccg + peaks = find_peaks_cwt(-xgram_2d, peak_width, noise_perc=90) + 1 + # if no peaks are found, return a very low significance + if peaks.shape[0] == 0: + return -4 + peaks = np.abs(peaks - xgram.shape[0] / 2) + peaks = peaks[peaks > 0.5 * peak_width] + min_peaks = np.sort(peaks) + + # start with peaks closest to 0 and move to the next set of peaks if the event count is too low + window_width = min_peaks * 1.5 + starts = np.maximum(xgram.shape[0] / 2 - window_width, 0) + ends = np.minimum(xgram.shape[0] / 2 + window_width, xgram.shape[0] - 1) + ind = 0 + xgram_window = xgram[int(starts[0]) : int(ends[0] + 1)] + xgram_sum = xgram_window.sum() + window_size = xgram_window.shape[0] * bin_size_s + while (xgram_sum < (min_xcorr_rate * window_size * 10)) and (ind < starts.shape[0]): + xgram_window = xgram[int(starts[ind]) : int(ends[ind] + 1)] + xgram_sum = xgram_window.sum() + window_size = xgram_window.shape[0] * bin_size_s + ind += 1 + # use the whole ccg if peak finding fails + if ind == starts.shape[0]: + xgram_window = xgram + + # TODO: was getting error messges when xgram_window was all zero. Why was this happening? + if np.abs(xgram_window).sum() == 0: + return 0 + + sig = ( + wasserstein_distance( + np.arange(xgram_window.shape[0]) / xgram_window.shape[0], + np.arange(xgram_window.shape[0]) / xgram_window.shape[0], + xgram_window, + np.ones_like(xgram_window), + ) + * 4 + ) + + if xgram_window.sum() < (min_xcorr_rate * window_size): + sig *= (xgram_window.sum() / (min_xcorr_rate * window_size)) ** 2 + + # if sig < 0.04 and xgram_window.sum() < (min_xcorr_rate * window_size): + if xgram_window.sum() < (min_xcorr_rate / 4 * window_size): + sig = -4 # don't merge if the event count is way too low + + return sig + + +def _sliding_RP_viol_pair( + correlogram, + bin_size_ms: float, + accept_threshold: float = 0.15, +) -> float: + """ + Calculate the sliding refractory period violation confidence for a cluster. + + Ported from https://github.com/saikoukunt/SLAy. + + Parameters + ---------- + correlogram : np.array + The auto-correlogram of the cluster. + bin_size_ms : float + The width in ms of the bin size of the input ccgs. + accept_threshold : float, default: 0.15 + The minimum ccg firing rate in Hz. + + Returns + ------- + sig : float + The refractory period violation confidence for the cluster. 
+ """ + from scipy.signal import butter, sosfiltfilt + from scipy.stats import poisson + + # create various refractory periods sizes to test (between 0 and 20x bin size) + all_refractory_periods = np.arange(0, 21 * bin_size_ms, bin_size_ms) / 1000 + test_refractory_period_indices = np.array([1, 2, 4, 6, 8, 12, 16, 20], dtype="int8") + test_refractory_periods = [ + all_refractory_periods[test_rp_index] for test_rp_index in test_refractory_period_indices + ] + + # calculate and avg halves of acg to ensure symmetry + # keep only second half of acg, refractory period violations are compared from the center of acg + half_len = int(correlogram.shape[0] / 2) + correlogram = (correlogram[half_len:] + correlogram[:half_len][::-1]) / 2 + + acg_cumsum = np.cumsum(correlogram) + sum_res = acg_cumsum[test_refractory_period_indices - 1] # -1 bc 0th bin corresponds to 0-bin_size ms + + # low-pass filter acg and use max as baseline event rate + order = 4 # Hz + cutoff_freq = 250 # Hz + fs = 1 / bin_size_ms * 1000 + nyqist = fs / 2 + cutoff = cutoff_freq / nyqist + sos = butter(order, cutoff, btype="low", output="sos") + smoothed_acg = sosfiltfilt(sos, correlogram) + + bin_rate_max = np.max(smoothed_acg) + max_conts_max = np.array(test_refractory_periods) / bin_size_ms * 1000 * (bin_rate_max * accept_threshold) + # compute confidence of less than acceptThresh contamination at each refractory period + confs = 1 - poisson.cdf(sum_res, max_conts_max) + rp_viol = 1 - confs.max() + + return rp_viol