From dadc15de8d14d7afc3a7ff17b156e5cf98c65b31 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 31 Oct 2025 11:50:56 +0100 Subject: [PATCH] Add BaseSpikeVectorExtension --- .../core/analyzer_extension_core.py | 134 +++++++++++++++++- .../postprocessing/amplitude_scalings.py | 95 ++----------- .../postprocessing/spike_amplitudes.py | 99 +------------ .../postprocessing/spike_locations.py | 90 +----------- 4 files changed, 152 insertions(+), 266 deletions(-) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index fea3f3618e..24b6dca5c7 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -11,12 +11,14 @@ import warnings import numpy as np +from collections import namedtuple -from .sortinganalyzer import AnalyzerExtension, register_result_extension +from .sortinganalyzer import SortingAnalyzer, AnalyzerExtension, register_result_extension from .waveform_tools import extract_waveforms_to_single_buffer, estimate_templates_with_accumulator from .recording_tools import get_noise_levels from .template import Templates from .sorting_tools import random_spikes_selection +from .job_tools import fix_job_kwargs, split_job_kwargs class ComputeRandomSpikes(AnalyzerExtension): @@ -806,3 +808,133 @@ def _handle_backward_compatibility_on_load(self): register_result_extension(ComputeNoiseLevels) compute_noise_levels = ComputeNoiseLevels.function_factory() + + +class BaseSpikeVectorExtension(AnalyzerExtension): + """ + Base class for spike-vector-based extensions, where the data is a numpy array with the same + length as the spike vector. + """ + + extension_name = None  # to be defined in subclass + need_recording = True + use_nodepipeline = True + need_job_kwargs = True + need_backward_compatibility_on_load = False + nodepipeline_variables = []  # to be defined in subclass + + def _set_params(self, **kwargs): + params = kwargs.copy() + return params + + def _run(self, verbose=False, **job_kwargs): + from spikeinterface.core.node_pipeline import run_node_pipeline + + job_kwargs = fix_job_kwargs(job_kwargs) + nodes = self.get_pipeline_nodes() + data = run_node_pipeline( + self.sorting_analyzer.recording, + nodes, + job_kwargs=job_kwargs, + job_name=self.extension_name, + gather_mode="memory", + verbose=verbose, + ) + if isinstance(data, tuple): + # this logic enables extensions to optionally compute additional data based on params + assert len(data) <= len(self.nodepipeline_variables), "Pipeline produced more outputs than expected" + else: + data = (data,) + if len(self.nodepipeline_variables) > len(data): + data_names = self.nodepipeline_variables[: len(data)] + else: + data_names = self.nodepipeline_variables + for d, name in zip(data, data_names): + self.data[name] = d + + def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None): + """ + Return extension data. If the extension computes more than one `nodepipeline_variables`, + `return_data_name` is used to specify which one to return. + + Parameters + ---------- + outputs : "numpy" | "by_unit", default: "numpy" + Whether to return the data as a single numpy array ("numpy") or as dictionaries split by segment and unit ("by_unit"). + concatenated : bool, default: False + Whether to concatenate the data across segments (only used when `outputs="by_unit"`). + return_data_name : str | None, default: None + The name of the data to return. If None and multiple `nodepipeline_variables` are computed, + the first one is returned.
+ + Returns + ------- + numpy.ndarray | dict + The extension data: a numpy array aligned with the spike vector if `outputs="numpy"`, or a dictionary of the form {segment_index: {unit_id: np.ndarray}} (or {unit_id: np.ndarray} when `concatenated=True`) if `outputs="by_unit"`. + """ + from spikeinterface.core.sorting_tools import spike_vector_to_indices + + if len(self.nodepipeline_variables) == 1: + return_data_name = self.nodepipeline_variables[0] + else: + if return_data_name is None: + return_data_name = self.nodepipeline_variables[0] + else: + assert ( + return_data_name in self.nodepipeline_variables + ), f"return_data_name {return_data_name} not in nodepipeline_variables {self.nodepipeline_variables}" + + all_data = self.data[return_data_name] + if outputs == "numpy": + return all_data + elif outputs == "by_unit": + unit_ids = self.sorting_analyzer.unit_ids + spike_vector = self.sorting_analyzer.sorting.to_spike_vector(concatenated=False) + spike_indices = spike_vector_to_indices(spike_vector, unit_ids, absolute_index=True) + data_by_units = {} + for segment_index in range(self.sorting_analyzer.sorting.get_num_segments()): + data_by_units[segment_index] = {} + for unit_id in unit_ids: + inds = spike_indices[segment_index][unit_id] + data_by_units[segment_index][unit_id] = all_data[inds] + + if concatenated: + data_by_units_concatenated = { + unit_id: np.concatenate([data_in_segment[unit_id] for data_in_segment in data_by_units.values()]) + for unit_id in unit_ids + } + return data_by_units_concatenated + + return data_by_units + else: + raise ValueError(f"Wrong .get_data(outputs={outputs}); possibilities are `numpy` or `by_unit`") + + def _select_extension_data(self, unit_ids): + keep_unit_indices = np.flatnonzero(np.isin(self.sorting_analyzer.unit_ids, unit_ids)) + + spikes = self.sorting_analyzer.sorting.to_spike_vector() + keep_spike_mask = np.isin(spikes["unit_index"], keep_unit_indices) + + new_data = dict() + for data_name in self.nodepipeline_variables: + if self.data.get(data_name) is not None: + new_data[data_name] = self.data[data_name][keep_spike_mask] + + return new_data + + def _merge_extension_data( + self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs + ): + new_data = dict() + for data_name in self.nodepipeline_variables: + if self.data.get(data_name) is not None: + if keep_mask is None: + new_data[data_name] = self.data[data_name].copy() + else: + new_data[data_name] = self.data[data_name][keep_mask] + + return new_data + + def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): + # splitting only reassigns spikes to units, so spike-vector-aligned data is unchanged + return self.data.copy() diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index ce8194f530..8f3ffe0617 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -3,18 +3,14 @@ import numpy as np from spikeinterface.core import ChannelSparsity -from spikeinterface.core.job_tools import ChunkRecordingExecutor, _shared_job_kwargs_doc, ensure_n_jobs, fix_job_kwargs +from spikeinterface.core.template_tools import get_template_extremum_channel, get_dense_templates_array, _get_nbefore +from spikeinterface.core.sortinganalyzer import register_result_extension +from spikeinterface.core.analyzer_extension_core import BaseSpikeVectorExtension -from spikeinterface.core.template_tools import get_template_extremum_channel +from spikeinterface.core.node_pipeline import SpikeRetriever, PipelineNode, find_parent_of_type -from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension -from 
spikeinterface.core.node_pipeline import SpikeRetriever, PipelineNode, run_node_pipeline, find_parent_of_type - -from spikeinterface.core.template_tools import get_dense_templates_array, _get_nbefore - - -class ComputeAmplitudeScalings(AnalyzerExtension): +class ComputeAmplitudeScalings(BaseSpikeVectorExtension): """ Computes the amplitude scalings from a SortingAnalyzer. @@ -55,31 +51,11 @@ class ComputeAmplitudeScalings(AnalyzerExtension): multi-linear regression model (with `sklearn.LinearRegression`). If False, each spike is fitted independently. delta_collision_ms: float, default: 2 The maximum time difference in ms before and after a spike to gather colliding spikes. - load_if_exists : bool, default: False - Whether to load precomputed spike amplitudes, if they already exist. - outputs: "concatenated" | "by_unit", default: "concatenated" - How the output should be returned - {} - - Returns - ------- - amplitude_scalings: np.array or list of dict - The amplitude scalings. - - If "concatenated" all amplitudes for all spikes and all units are concatenated - - If "by_unit", amplitudes are returned as a list (for segments) of dictionaries (for units) """ extension_name = "amplitude_scalings" depend_on = ["templates"] - need_recording = True - use_nodepipeline = True nodepipeline_variables = ["amplitude_scalings", "collision_mask"] - need_job_kwargs = True - - def __init__(self, sorting_analyzer): - AnalyzerExtension.__init__(self, sorting_analyzer) - - self.collisions = None def _set_params( self, @@ -90,7 +66,7 @@ def _set_params( handle_collisions=True, delta_collision_ms=2, ): - params = dict( + return super()._set_params( sparsity=sparsity, max_dense_channels=max_dense_channels, ms_before=ms_before, @@ -98,38 +74,6 @@ def _set_params( handle_collisions=handle_collisions, delta_collision_ms=delta_collision_ms, ) - return params - - def _select_extension_data(self, unit_ids): - keep_unit_indices = np.flatnonzero(np.isin(self.sorting_analyzer.unit_ids, unit_ids)) - - spikes = self.sorting_analyzer.sorting.to_spike_vector() - keep_spike_mask = np.isin(spikes["unit_index"], keep_unit_indices) - - new_data = dict() - new_data["amplitude_scalings"] = self.data["amplitude_scalings"][keep_spike_mask] - if self.params["handle_collisions"]: - new_data["collision_mask"] = self.data["collision_mask"][keep_spike_mask] - return new_data - - def _merge_extension_data( - self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs - ): - new_data = dict() - - if keep_mask is None: - new_data["amplitude_scalings"] = self.data["amplitude_scalings"].copy() - if self.params["handle_collisions"]: - new_data["collision_mask"] = self.data["collision_mask"].copy() - else: - new_data["amplitude_scalings"] = self.data["amplitude_scalings"][keep_mask] - if self.params["handle_collisions"]: - new_data["collision_mask"] = self.data["collision_mask"][keep_mask] - - return new_data - - def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): - return self.data.copy() def _get_pipeline_nodes(self): @@ -141,6 +85,7 @@ def _get_pipeline_nodes(self): all_templates = get_dense_templates_array(self.sorting_analyzer, return_in_uV=return_in_uV) nbefore = _get_nbefore(self.sorting_analyzer) nafter = all_templates.shape[1] - nbefore + templates_ext = self.sorting_analyzer.get_extension("templates") # if ms_before / ms_after are set in params then the original templates are shorten if self.params["ms_before"] is not None: @@ 
-155,7 +100,7 @@ def _get_pipeline_nodes(self): cut_out_after = int(self.params["ms_after"] * self.sorting_analyzer.sampling_frequency / 1000.0) assert ( cut_out_after <= nafter - ), f"`ms_after` must be smaller than `ms_after` used in WaveformExractor: {we._params['ms_after']}" + ), f"`ms_after` must be smaller than `ms_after` used in templates: {templates_ext.params['ms_after']}" else: cut_out_after = nafter @@ -210,30 +155,6 @@ def _get_pipeline_nodes(self): nodes = [spike_retriever_node, amplitude_scalings_node] return nodes - def _run(self, verbose=False, **job_kwargs): - job_kwargs = fix_job_kwargs(job_kwargs) - nodes = self.get_pipeline_nodes() - amp_scalings, collision_mask = run_node_pipeline( - self.sorting_analyzer.recording, - nodes, - job_kwargs=job_kwargs, - job_name="amplitude_scalings", - gather_mode="memory", - verbose=verbose, - ) - self.data["amplitude_scalings"] = amp_scalings - if self.params["handle_collisions"]: - self.data["collision_mask"] = collision_mask - # TODO: make collisions "global" - # for collision in collisions: - # collisions_dict.update(collision) - # self.collisions = collisions_dict - # # Note: collisions are note in _extension_data because they are not pickable. We only store the indices - # self._extension_data["collisions"] = np.array(list(collisions_dict.keys())) - - def _get_data(self): - return self.data[f"amplitude_scalings"] - register_result_extension(ComputeAmplitudeScalings) compute_amplitude_scalings = ComputeAmplitudeScalings.function_factory() diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index 959103d922..4fbeabca88 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -2,18 +2,14 @@ import numpy as np -from spikeinterface.core.job_tools import fix_job_kwargs - +from spikeinterface.core.sortinganalyzer import register_result_extension +from spikeinterface.core.analyzer_extension_core import BaseSpikeVectorExtension from spikeinterface.core.template_tools import get_template_extremum_channel, get_template_extremum_channel_peak_shift - -from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension -from spikeinterface.core.node_pipeline import SpikeRetriever, PipelineNode, run_node_pipeline, find_parent_of_type -from spikeinterface.core.sorting_tools import spike_vector_to_indices +from spikeinterface.core.node_pipeline import SpikeRetriever, PipelineNode, find_parent_of_type -class ComputeSpikeAmplitudes(AnalyzerExtension): +class ComputeSpikeAmplitudes(BaseSpikeVectorExtension): """ - AnalyzerExtension Computes the spike amplitudes. Needs "templates" to be computed first. @@ -25,59 +21,16 @@ class ComputeSpikeAmplitudes(AnalyzerExtension): A SortingAnalyzer object peak_sign : "neg" | "pos" | "both", default: "neg" Sign of the template to compute extremum channel used to retrieve spike amplitudes. 
- - Returns - ------- - spike_amplitudes: np.array - All amplitudes for all spikes and all units are concatenated (along time, like in spike vector) - """ extension_name = "spike_amplitudes" depend_on = ["templates"] - need_recording = True - use_nodepipeline = True nodepipeline_variables = ["amplitudes"] - need_job_kwargs = True - - def __init__(self, sorting_analyzer): - AnalyzerExtension.__init__(self, sorting_analyzer) - - self._all_spikes = None def _set_params(self, peak_sign="neg"): - params = dict(peak_sign=peak_sign) - return params - - def _select_extension_data(self, unit_ids): - keep_unit_indices = np.flatnonzero(np.isin(self.sorting_analyzer.unit_ids, unit_ids)) - - spikes = self.sorting_analyzer.sorting.to_spike_vector() - keep_spike_mask = np.isin(spikes["unit_index"], keep_unit_indices) - - new_data = dict() - new_data["amplitudes"] = self.data["amplitudes"][keep_spike_mask] - - return new_data - - def _merge_extension_data( - self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs - ): - new_data = dict() - - if keep_mask is None: - new_data["amplitudes"] = self.data["amplitudes"].copy() - else: - new_data["amplitudes"] = self.data["amplitudes"][keep_mask] - - return new_data - - def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): - # splitting only changes random spikes assignments - return self.data.copy() + return super()._set_params(peak_sign=peak_sign) def _get_pipeline_nodes(self): - recording = self.sorting_analyzer.recording sorting = self.sorting_analyzer.sorting @@ -102,50 +55,8 @@ def _get_pipeline_nodes(self): nodes = [spike_retriever_node, spike_amplitudes_node] return nodes - def _run(self, verbose=False, **job_kwargs): - job_kwargs = fix_job_kwargs(job_kwargs) - nodes = self.get_pipeline_nodes() - amps = run_node_pipeline( - self.sorting_analyzer.recording, - nodes, - job_kwargs=job_kwargs, - job_name="spike_amplitudes", - gather_mode="memory", - verbose=False, - ) - self.data["amplitudes"] = amps - - def _get_data(self, outputs="numpy", concatenated=False): - all_amplitudes = self.data["amplitudes"] - if outputs == "numpy": - return all_amplitudes - elif outputs == "by_unit": - unit_ids = self.sorting_analyzer.unit_ids - spike_vector = self.sorting_analyzer.sorting.to_spike_vector(concatenated=False) - spike_indices = spike_vector_to_indices(spike_vector, unit_ids, absolute_index=True) - amplitudes_by_units = {} - for segment_index in range(self.sorting_analyzer.sorting.get_num_segments()): - amplitudes_by_units[segment_index] = {} - for unit_id in unit_ids: - inds = spike_indices[segment_index][unit_id] - amplitudes_by_units[segment_index][unit_id] = all_amplitudes[inds] - - if concatenated: - amplitudes_by_units_concatenated = { - unit_id: np.concatenate( - [amps_in_segment[unit_id] for amps_in_segment in amplitudes_by_units.values()] - ) - for unit_id in unit_ids - } - return amplitudes_by_units_concatenated - - return amplitudes_by_units - else: - raise ValueError(f"Wrong .get_data(outputs={outputs}); possibilities are `numpy` or `by_unit`") - register_result_extension(ComputeSpikeAmplitudes) - compute_spike_amplitudes = ComputeSpikeAmplitudes.function_factory() diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index d7c7045f5a..e0111fd6f8 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -2,14 +2,15 
@@ import numpy as np -from spikeinterface.core.job_tools import _shared_job_kwargs_doc, fix_job_kwargs -from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension +from spikeinterface.core.job_tools import _shared_job_kwargs_doc +from spikeinterface.core.sortinganalyzer import register_result_extension from spikeinterface.core.template_tools import get_template_extremum_channel from spikeinterface.core.sorting_tools import spike_vector_to_indices -from spikeinterface.core.node_pipeline import SpikeRetriever, run_node_pipeline +from spikeinterface.core.node_pipeline import SpikeRetriever +from spikeinterface.core.analyzer_extension_core import BaseSpikeVectorExtension -class ComputeSpikeLocations(AnalyzerExtension): +class ComputeSpikeLocations(BaseSpikeVectorExtension): """ Localize spikes in 2D or 3D with several methods given the template. @@ -37,9 +38,6 @@ class ComputeSpikeLocations(AnalyzerExtension): The localization method to use method_kwargs : dict, default: dict() Other kwargs depending on the method. - outputs : "concatenated" | "by_unit", default: "concatenated" - The output format - {} Returns ------- @@ -49,13 +47,7 @@ class ComputeSpikeLocations(AnalyzerExtension): extension_name = "spike_locations" depend_on = ["templates"] - need_recording = True - use_nodepipeline = True nodepipeline_variables = ["spike_locations"] - need_job_kwargs = True - - def __init__(self, sorting_analyzer): - AnalyzerExtension.__init__(self, sorting_analyzer) def _set_params( self, @@ -72,40 +64,13 @@ def _set_params( ) if spike_retriver_kwargs is not None: spike_retriver_kwargs_.update(spike_retriver_kwargs) - params = dict( + return super()._set_params( ms_before=ms_before, ms_after=ms_after, spike_retriver_kwargs=spike_retriver_kwargs_, method=method, method_kwargs=method_kwargs, ) - return params - - def _select_extension_data(self, unit_ids): - old_unit_ids = self.sorting_analyzer.unit_ids - unit_inds = np.flatnonzero(np.isin(old_unit_ids, unit_ids)) - spikes = self.sorting_analyzer.sorting.to_spike_vector() - - spike_mask = np.isin(spikes["unit_index"], unit_inds) - new_spike_locations = self.data["spike_locations"][spike_mask] - return dict(spike_locations=new_spike_locations) - - def _merge_extension_data( - self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs - ): - - if keep_mask is None: - new_spike_locations = self.data["spike_locations"].copy() - else: - new_spike_locations = self.data["spike_locations"][keep_mask] - - ### In theory here, we should recompute the locations since the peak positions - ### in a merged could be different. 
Should be discussed - return dict(spike_locations=new_spike_locations) - - def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): - # splitting only changes random spikes assignments - return self.data.copy() def _get_pipeline_nodes(self): from spikeinterface.sortingcomponents.peak_localization import get_localization_pipeline_nodes @@ -133,49 +98,6 @@ def _get_pipeline_nodes(self): ) return nodes - def _run(self, verbose=False, **job_kwargs): - job_kwargs = fix_job_kwargs(job_kwargs) - nodes = self.get_pipeline_nodes() - spike_locations = run_node_pipeline( - self.sorting_analyzer.recording, - nodes, - job_kwargs=job_kwargs, - job_name="spike_locations", - gather_mode="memory", - verbose=verbose, - ) - self.data["spike_locations"] = spike_locations - - def _get_data(self, outputs="numpy", concatenated=False): - all_spike_locations = self.data["spike_locations"] - if outputs == "numpy": - return all_spike_locations - elif outputs == "by_unit": - unit_ids = self.sorting_analyzer.unit_ids - spike_vector = self.sorting_analyzer.sorting.to_spike_vector(concatenated=False) - spike_indices = spike_vector_to_indices(spike_vector, unit_ids, absolute_index=True) - spike_locations_by_units = {} - for segment_index in range(self.sorting_analyzer.sorting.get_num_segments()): - spike_locations_by_units[segment_index] = {} - for unit_id in unit_ids: - inds = spike_indices[segment_index][unit_id] - spike_locations_by_units[segment_index][unit_id] = all_spike_locations[inds] - - if concatenated: - locations_by_units_concatenated = { - unit_id: np.concatenate( - [locs_in_segment[unit_id] for locs_in_segment in spike_locations_by_units.values()] - ) - for unit_id in unit_ids - } - return locations_by_units_concatenated - - return spike_locations_by_units - else: - raise ValueError(f"Wrong .get_data(outputs={outputs})") - - -ComputeSpikeLocations.__doc__.format(_shared_job_kwargs_doc) register_result_extension(ComputeSpikeLocations) compute_spike_locations = ComputeSpikeLocations.function_factory()
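
Rough usage sketch (not part of the patch): it shows how the refactored extensions expose their data through the shared `get_data()` keywords introduced by `BaseSpikeVectorExtension._get_data()` (`outputs`, `concatenated`, `return_data_name`). The `generate_ground_truth_recording` and `create_sorting_analyzer` helpers are assumed to be the existing spikeinterface ones and are used here only to build a toy analyzer; exact defaults may differ.

import spikeinterface.full as si

# toy recording/sorting pair and analyzer, only for illustration
recording, sorting = si.generate_ground_truth_recording(durations=[10.0], num_units=5, seed=0)
analyzer = si.create_sorting_analyzer(sorting, recording, sparse=True)
analyzer.compute(["random_spikes", "waveforms", "templates"])

# single-variable extension: spike_amplitudes
amps_ext = analyzer.compute("spike_amplitudes", peak_sign="neg")
all_amps = amps_ext.get_data()                                         # np.ndarray, one value per spike in the spike vector
amps_by_unit = amps_ext.get_data(outputs="by_unit")                    # {segment_index: {unit_id: np.ndarray}}
amps_concat = amps_ext.get_data(outputs="by_unit", concatenated=True)  # {unit_id: np.ndarray across segments}

# multi-variable extension: amplitude_scalings with handle_collisions=True also stores "collision_mask",
# which is selected via return_data_name; without it the first nodepipeline_variable is returned
scalings_ext = analyzer.compute("amplitude_scalings", handle_collisions=True)
scalings = scalings_ext.get_data()
collision_mask = scalings_ext.get_data(return_data_name="collision_mask")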