From 04ca5833e101bcd360dd18291bbf2f0912fda92c Mon Sep 17 00:00:00 2001
From: Pierre Yger
Date: Thu, 17 Jul 2025 09:37:31 +0200
Subject: [PATCH 01/14] WIP

---
 src/spikeinterface/preprocessing/silence_periods.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py
index db4533b659..1750f62601 100644
--- a/src/spikeinterface/preprocessing/silence_periods.py
+++ b/src/spikeinterface/preprocessing/silence_periods.py
@@ -5,8 +5,9 @@
 from spikeinterface.core.core_tools import define_function_handling_dict_from_class
 from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment

-from spikeinterface.core import get_random_data_chunks, get_noise_levels
+from spikeinterface.core import get_noise_levels
 from spikeinterface.core.generate import NoiseGeneratorRecording
+from spikeinterface.core.job_tools import split_job_kwargs


 class SilencedPeriodsRecording(BasePreprocessor):
@@ -48,6 +49,8 @@ def __init__(self, recording, list_periods, mode="zeros", noise_levels=None, see
         available_modes = ("zeros", "noise")
         num_seg = recording.get_num_segments()

+        random_chunk_kwargs, job_kwargs = split_job_kwargs(random_chunk_kwargs)
+
         if num_seg == 1:
             if isinstance(list_periods, (list, np.ndarray)) and np.array(list_periods).ndim == 2:
                 # when unique segment accept list instead of list of list/arrays
@@ -74,7 +77,7 @@ def __init__(self, recording, list_periods, mode="zeros", noise_levels=None, see
                 random_slices_kwargs = random_chunk_kwargs.copy()
                 random_slices_kwargs["seed"] = seed
                 noise_levels = get_noise_levels(
-                    recording, return_scaled=False, random_slices_kwargs=random_slices_kwargs
+                    recording, return_scaled=False, random_slices_kwargs=random_slices_kwargs, **job_kwargs
                 )
                 noise_generator = NoiseGeneratorRecording(
                     num_channels=recording.get_num_channels(),
@@ -97,7 +100,7 @@ def __init__(self, recording, list_periods, mode="zeros", noise_levels=None, see
             rec_segment = SilencedPeriodsRecordingSegment(parent_segment, periods, mode, noise_generator, seg_index)
             self.add_recording_segment(rec_segment)

-        self._kwargs = dict(recording=recording, list_periods=list_periods, mode=mode, seed=seed)
+        self._kwargs = dict(recording=recording, list_periods=list_periods, mode=mode, seed=seed, noise_levels=noise_levels)
         self._kwargs.update(random_chunk_kwargs)
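Patch 01 routes every extra keyword through `split_job_kwargs`, which partitions a mixed kwargs dict into method-specific options and job (parallelization) options, so that `get_noise_levels` can also receive the job kwargs. A minimal sketch of the expected behaviour, with illustrative values (the exact set of recognized job keys is defined in `spikeinterface.core.job_tools`):

    from spikeinterface.core.job_tools import split_job_kwargs

    # mix method-specific options with job options in one dict
    mixed_kwargs = dict(num_chunks_per_segment=20, seed=0, n_jobs=4, progress_bar=True)
    specific_kwargs, job_kwargs = split_job_kwargs(mixed_kwargs)
    # specific_kwargs -> {"num_chunks_per_segment": 20, "seed": 0}
    # job_kwargs      -> {"n_jobs": 4, "progress_bar": True}
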
From 983bf3d1d69b5ff8cc8c7a291c1d28ab4daf134b Mon Sep 17 00:00:00 2001
From: Pierre Yger
Date: Thu, 17 Jul 2025 09:40:37 +0200
Subject: [PATCH 02/14] Fix

---
 src/spikeinterface/preprocessing/silence_periods.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py
index 1750f62601..524d965ca1 100644
--- a/src/spikeinterface/preprocessing/silence_periods.py
+++ b/src/spikeinterface/preprocessing/silence_periods.py
@@ -37,6 +37,12 @@ class SilencedPeriodsRecording(BasePreprocessor):
         - "noise": The periods are filled with a gaussian noise that has the same
           variance as the one in the recordings, on a per channel basis

+    job_kwargs : dict
+        Keyword arguments for parallel processing. If you want to use
+        `job_kwargs`, you need to pass them as a dictionary through this argument.
+        For example, `job_kwargs={"n_jobs": 4}`.
+        Note that this is not used for the `get_noise_levels` function, which has its own
+        `random_slices_kwargs` argument.
     **random_chunk_kwargs : Keyword arguments for `spikeinterface.core.get_random_data_chunks()` function

     Returns
@@ -45,12 +51,10 @@ class SilencedPeriodsRecording(BasePreprocessor):
         The recording extractor after silencing some periods
     """

-    def __init__(self, recording, list_periods, mode="zeros", noise_levels=None, seed=None, **random_chunk_kwargs):
+    def __init__(self, recording, list_periods, mode="zeros", noise_levels=None, seed=None, job_kwargs=dict(), **random_chunk_kwargs):
         available_modes = ("zeros", "noise")
         num_seg = recording.get_num_segments()

-        random_chunk_kwargs, job_kwargs = split_job_kwargs(random_chunk_kwargs)
-
         if num_seg == 1:
             if isinstance(list_periods, (list, np.ndarray)) and np.array(list_periods).ndim == 2:
                 # when unique segment accept list instead of list of list/arrays

From f841e376ac42e2343cdd7440375e777a5e942bcc Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 17 Jul 2025 07:44:37 +0000
Subject: [PATCH 03/14] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../preprocessing/silence_periods.py | 15 +++++++++++++--
 1 file changed, 13 insertions(+), 2 deletions(-)

diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py
index 524d965ca1..0fd608bfaf 100644
--- a/src/spikeinterface/preprocessing/silence_periods.py
+++ b/src/spikeinterface/preprocessing/silence_periods.py
@@ -51,7 +51,16 @@ class SilencedPeriodsRecording(BasePreprocessor):
         The recording extractor after silencing some periods
     """

-    def __init__(self, recording, list_periods, mode="zeros", noise_levels=None, seed=None, job_kwargs=dict(), **random_chunk_kwargs):
+    def __init__(
+        self,
+        recording,
+        list_periods,
+        mode="zeros",
+        noise_levels=None,
+        seed=None,
+        job_kwargs=dict(),
+        **random_chunk_kwargs,
+    ):
         available_modes = ("zeros", "noise")
         num_seg = recording.get_num_segments()

@@ -104,7 +113,9 @@ def __init__(self, recording, list_periods, mode="zeros", noise_levels=None, see
             rec_segment = SilencedPeriodsRecordingSegment(parent_segment, periods, mode, noise_generator, seg_index)
             self.add_recording_segment(rec_segment)

-        self._kwargs = dict(recording=recording, list_periods=list_periods, mode=mode, seed=seed, noise_levels=noise_levels)
+        self._kwargs = dict(
+            recording=recording, list_periods=list_periods, mode=mode, seed=seed, noise_levels=noise_levels
+        )
         self._kwargs.update(random_chunk_kwargs)

From e03cd618d21b86ac67d8b57860d8a4b2f04c8b9c Mon Sep 17 00:00:00 2001
From: Pierre Yger
Date: Fri, 18 Jul 2025 09:07:14 +0200
Subject: [PATCH 04/14] Fixes also the noise levels in detect_peaks

---
 .../sortingcomponents/peak_detection.py | 20 +++++++++++--------
 1 file changed, 12 insertions(+), 8 deletions(-)

diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py
index 8e10f45624..bec2246898 100644
--- a/src/spikeinterface/sortingcomponents/peak_detection.py
+++ b/src/spikeinterface/sortingcomponents/peak_detection.py
@@ -111,6 +111,10 @@ def detect_peaks(
     method_kwargs, job_kwargs = split_job_kwargs(kwargs)
     job_kwargs["mp_context"] = method_class.preferred_mp_context

+    if method_class.need_noise_levels:
+        random_chunk_kwargs = method_kwargs.pop("random_chunk_kwargs", {})
+        method_kwargs["noise_levels"] = get_noise_levels(recording, return_scaled=False, **random_chunk_kwargs, **job_kwargs)
+
     node0 = method_class(recording, **method_kwargs)
     nodes
= [node0] @@ -384,6 +388,7 @@ class DetectPeakByChannel(PeakDetectorWrapper): name = "by_channel" engine = "numpy" + need_noise_levels = True preferred_mp_context = None params_doc = """ peak_sign: "neg" | "pos" | "both", default: "neg" @@ -410,12 +415,11 @@ def check_params( detect_threshold=5, exclude_sweep_ms=0.1, noise_levels=None, + random_chunk_kwargs={}, ): assert peak_sign in ("both", "neg", "pos") - if noise_levels is None: - noise_levels = get_noise_levels(recording, return_scaled=False, **random_chunk_kwargs) abs_thresholds = noise_levels * detect_threshold exclude_sweep_size = int(exclude_sweep_ms * recording.get_sampling_frequency() / 1000.0) @@ -466,6 +470,7 @@ class DetectPeakByChannelTorch(PeakDetectorWrapper): name = "by_channel_torch" engine = "torch" + need_noise_levels = True preferred_mp_context = "spawn" params_doc = """ peak_sign: "neg" | "pos" | "both", default: "neg" @@ -510,8 +515,6 @@ def check_params( if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" - if noise_levels is None: - noise_levels = get_noise_levels(recording, return_scaled=False, **random_chunk_kwargs) abs_thresholds = noise_levels * detect_threshold exclude_sweep_size = int(exclude_sweep_ms * recording.get_sampling_frequency() / 1000.0) @@ -538,6 +541,7 @@ class DetectPeakLocallyExclusive(PeakDetectorWrapper): name = "locally_exclusive" engine = "numba" + need_noise_levels = True preferred_mp_context = None params_doc = ( DetectPeakByChannel.params_doc @@ -571,8 +575,6 @@ def check_params( # ) assert peak_sign in ("both", "neg", "pos") - if noise_levels is None: - noise_levels = get_noise_levels(recording, return_scaled=False, **random_chunk_kwargs) abs_thresholds = noise_levels * detect_threshold exclude_sweep_size = int(exclude_sweep_ms * recording.get_sampling_frequency() / 1000.0) @@ -633,6 +635,7 @@ class DetectPeakMatchedFiltering(PeakDetector): name = "matched_filtering" engine = "numba" + need_noise_levels = False preferred_mp_context = None params_doc = ( DetectPeakByChannel.params_doc @@ -780,6 +783,7 @@ class DetectPeakLocallyExclusiveTorch(PeakDetectorWrapper): name = "locally_exclusive_torch" engine = "torch" + need_noise_levels = True preferred_mp_context = "spawn" params_doc = ( DetectPeakByChannel.params_doc @@ -1069,6 +1073,7 @@ def _torch_detect_peaks(traces, peak_sign, abs_thresholds, exclude_sweep_size=5, class DetectPeakLocallyExclusiveOpenCL(PeakDetectorWrapper): name = "locally_exclusive_cl" engine = "opencl" + need_noise_levels = True preferred_mp_context = None params_doc = ( DetectPeakLocallyExclusive.params_doc @@ -1091,8 +1096,7 @@ def check_params( ): # TODO refactor with other classes assert peak_sign in ("both", "neg", "pos") - if noise_levels is None: - noise_levels = get_noise_levels(recording, return_scaled=False, **random_chunk_kwargs) + abs_thresholds = noise_levels * detect_threshold exclude_sweep_size = int(exclude_sweep_ms * recording.get_sampling_frequency() / 1000.0) channel_distance = get_channel_distances(recording) From 75c398158289681db533260b9a9a3ec5a5698669 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 18 Jul 2025 07:07:48 +0000 Subject: [PATCH 05/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sortingcomponents/peak_detection.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py 
b/src/spikeinterface/sortingcomponents/peak_detection.py index bec2246898..331e6f676d 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection.py @@ -113,7 +113,9 @@ def detect_peaks( if method_class.need_noise_levels: random_chunk_kwargs = method_kwargs.pop("random_chunk_kwargs", {}) - method_kwargs["noise_levels"] = get_noise_levels(recording, return_scaled=False, **random_chunk_kwargs, **job_kwargs) + method_kwargs["noise_levels"] = get_noise_levels( + recording, return_scaled=False, **random_chunk_kwargs, **job_kwargs + ) node0 = method_class(recording, **method_kwargs) nodes = [node0] @@ -415,7 +417,6 @@ def check_params( detect_threshold=5, exclude_sweep_ms=0.1, noise_levels=None, - random_chunk_kwargs={}, ): assert peak_sign in ("both", "neg", "pos") @@ -1096,7 +1097,7 @@ def check_params( ): # TODO refactor with other classes assert peak_sign in ("both", "neg", "pos") - + abs_thresholds = noise_levels * detect_threshold exclude_sweep_size = int(exclude_sweep_ms * recording.get_sampling_frequency() / 1000.0) channel_distance = get_channel_distances(recording) From 19380f7f5e2c07f8d64598032bbd0a2a577963bc Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 18 Jul 2025 09:26:42 +0200 Subject: [PATCH 06/14] WIP --- src/spikeinterface/preprocessing/silence_periods.py | 2 +- src/spikeinterface/sortingcomponents/peak_detection.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index 0fd608bfaf..aff9b2d43d 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -90,7 +90,7 @@ def __init__( random_slices_kwargs = random_chunk_kwargs.copy() random_slices_kwargs["seed"] = seed noise_levels = get_noise_levels( - recording, return_scaled=False, random_slices_kwargs=random_slices_kwargs, **job_kwargs + recording, return_in_uV=False, random_slices_kwargs=random_slices_kwargs, **job_kwargs ) noise_generator = NoiseGeneratorRecording( num_channels=recording.get_num_channels(), diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py index 331e6f676d..c0d602c1d5 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection.py @@ -113,9 +113,7 @@ def detect_peaks( if method_class.need_noise_levels: random_chunk_kwargs = method_kwargs.pop("random_chunk_kwargs", {}) - method_kwargs["noise_levels"] = get_noise_levels( - recording, return_scaled=False, **random_chunk_kwargs, **job_kwargs - ) + method_kwargs["noise_levels"] = get_noise_levels(recording, return_in_uV=False, **random_chunk_kwargs, **job_kwargs) node0 = method_class(recording, **method_kwargs) nodes = [node0] From 94300dd162156167c694389af59d44cf94a8bf56 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 18 Jul 2025 07:27:07 +0000 Subject: [PATCH 07/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sortingcomponents/peak_detection.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py index c0d602c1d5..5296cdda69 100644 --- 
a/src/spikeinterface/sortingcomponents/peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection.py @@ -113,7 +113,9 @@ def detect_peaks( if method_class.need_noise_levels: random_chunk_kwargs = method_kwargs.pop("random_chunk_kwargs", {}) - method_kwargs["noise_levels"] = get_noise_levels(recording, return_in_uV=False, **random_chunk_kwargs, **job_kwargs) + method_kwargs["noise_levels"] = get_noise_levels( + recording, return_in_uV=False, **random_chunk_kwargs, **job_kwargs + ) node0 = method_class(recording, **method_kwargs) nodes = [node0] From b6097004d28b1d059e8a2ef3c788af52ad242f45 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 18 Jul 2025 09:35:39 +0200 Subject: [PATCH 08/14] Keep local computations of noise_levels if no pipeline --- .../sortingcomponents/peak_detection.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py index c0d602c1d5..22eb1158a8 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection.py @@ -419,6 +419,8 @@ def check_params( ): assert peak_sign in ("both", "neg", "pos") + if noise_levels is None: + noise_levels = get_noise_levels(recording, return_in_uV=False, **random_chunk_kwargs) abs_thresholds = noise_levels * detect_threshold exclude_sweep_size = int(exclude_sweep_ms * recording.get_sampling_frequency() / 1000.0) @@ -513,7 +515,8 @@ def check_params( assert peak_sign in ("both", "neg", "pos") if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" - + if noise_levels is None: + noise_levels = get_noise_levels(recording, return_in_uV=False, **random_chunk_kwargs) abs_thresholds = noise_levels * detect_threshold exclude_sweep_size = int(exclude_sweep_ms * recording.get_sampling_frequency() / 1000.0) @@ -574,6 +577,8 @@ def check_params( # ) assert peak_sign in ("both", "neg", "pos") + if noise_levels is None: + noise_levels = get_noise_levels(recording, return_in_uV=False, **random_chunk_kwargs) abs_thresholds = noise_levels * detect_threshold exclude_sweep_size = int(exclude_sweep_ms * recording.get_sampling_frequency() / 1000.0) @@ -661,7 +666,6 @@ def __init__( detect_threshold=5, exclude_sweep_ms=0.1, radius_um=50, - noise_levels=None, random_chunk_kwargs={"num_chunks_per_segment": 5}, weight_method={}, ): @@ -1095,7 +1099,8 @@ def check_params( ): # TODO refactor with other classes assert peak_sign in ("both", "neg", "pos") - + if noise_levels is None: + noise_levels = get_noise_levels(recording, return_in_uV=False, **random_chunk_kwargs) abs_thresholds = noise_levels * detect_threshold exclude_sweep_size = int(exclude_sweep_ms * recording.get_sampling_frequency() / 1000.0) channel_distance = get_channel_distances(recording) From 05f10d14f2f40bea20db1d9d214cc1362e8152ef Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 18 Jul 2025 09:44:45 +0200 Subject: [PATCH 09/14] Oups --- src/spikeinterface/sortingcomponents/peak_detection.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py index 64ba49d2f5..c4690fa973 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection.py @@ -668,6 +668,7 @@ def __init__( detect_threshold=5, exclude_sweep_ms=0.1, radius_um=50, + noise_levels=None, random_chunk_kwargs={"num_chunks_per_segment": 
5},
         weight_method={},
     ):

From c7934dd430739c393be8f043d161fc7255a9a778 Mon Sep 17 00:00:00 2001
From: Sebastien
Date: Thu, 31 Jul 2025 10:14:12 +0200
Subject: [PATCH 10/14] Docs

---
 src/spikeinterface/preprocessing/silence_periods.py | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py
index e6c0c329f0..7116f03e42 100644
--- a/src/spikeinterface/preprocessing/silence_periods.py
+++ b/src/spikeinterface/preprocessing/silence_periods.py
@@ -37,12 +37,6 @@ class SilencedPeriodsRecording(BasePreprocessor):
         - "noise": The periods are filled with a gaussian noise that has the same
           variance as the one in the recordings, on a per channel basis

-    job_kwargs : dict
-        Keyword arguments for parallel processing. If you want to use
-        `job_kwargs`, you need to pass them as a dictionary through this argument.
-        For example, `job_kwargs={"n_jobs": 4}`.
-        Note that this is not used for the `get_noise_levels` function, which has its own
-        `random_slices_kwargs` argument.
     **random_chunk_kwargs : Keyword arguments for `spikeinterface.core.get_random_data_chunks()` function

     Returns

From d818bfb8873781b6fa6c8207a8a27926dcb33674 Mon Sep 17 00:00:00 2001
From: Sebastien
Date: Thu, 31 Jul 2025 10:21:59 +0200
Subject: [PATCH 11/14] More consistent

---
 src/spikeinterface/preprocessing/silence_periods.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py
index 7116f03e42..5259158e60 100644
--- a/src/spikeinterface/preprocessing/silence_periods.py
+++ b/src/spikeinterface/preprocessing/silence_periods.py
@@ -37,7 +37,7 @@ class SilencedPeriodsRecording(BasePreprocessor):
         - "noise": The periods are filled with a gaussian noise that has the same
           variance as the one in the recordings, on a per channel basis

-    **random_chunk_kwargs : Keyword arguments for `spikeinterface.core.get_random_data_chunks()` function
+    **noise_levels_kwargs : Keyword arguments for `spikeinterface.core.get_noise_levels()` function

     Returns
@@ -52,7 +52,7 @@ def __init__(
         mode="zeros",
         noise_levels=None,
         seed=None,
-        **random_chunk_kwargs,
+        **noise_levels_kwargs,
     ):
         available_modes = ("zeros", "noise")
         num_seg = recording.get_num_segments()
@@ -80,10 +80,10 @@ def __init__(
         if mode in ["noise"]:
             if noise_levels is None:
-                random_slices_kwargs, job_kwargs = split_job_kwargs(random_chunk_kwargs)
-                random_slices_kwargs["seed"] = seed
+                noise_levels_kwargs["return_in_uV"] = False
+                noise_levels_kwargs["seed"] = seed
                 noise_levels = get_noise_levels(
-                    recording, return_in_uV=False, random_slices_kwargs=random_slices_kwargs, **job_kwargs
+                    recording, **noise_levels_kwargs
                 )
                 noise_generator = NoiseGeneratorRecording(
                     num_channels=recording.get_num_channels(),
@@ -109,7 +109,7 @@ def __init__(
         self._kwargs = dict(
             recording=recording, list_periods=list_periods, mode=mode, seed=seed, noise_levels=noise_levels
         )
-        self._kwargs.update(random_chunk_kwargs)
+        self._kwargs.update(noise_levels_kwargs)

From e0a34a7bd317ff0dbbbaaccf04e509753e83d51e Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 31 Jul 2025 08:24:02 +0000
Subject: [PATCH 12/14] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
src/spikeinterface/preprocessing/silence_periods.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index 5259158e60..f598cb1949 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -82,9 +82,7 @@ def __init__( if noise_levels is None: noise_levels_kwargs["return_in_uV"] = False noise_levels_kwargs["seed"] = seed - noise_levels = get_noise_levels( - recording, **noise_levels_kwargs - ) + noise_levels = get_noise_levels(recording, **noise_levels_kwargs) noise_generator = NoiseGeneratorRecording( num_channels=recording.get_num_channels(), sampling_frequency=recording.sampling_frequency, From 96754693297ea9f47f2f41192ec57192422fdd5f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 8 Sep 2025 08:56:37 +0000 Subject: [PATCH 13/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/extractors/phykilosortextractors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py index 8209d93543..46a8e4cecb 100644 --- a/src/spikeinterface/extractors/phykilosortextractors.py +++ b/src/spikeinterface/extractors/phykilosortextractors.py @@ -40,7 +40,7 @@ class BasePhyKilosortSortingExtractor(BaseSorting): The cluster_id column is used as the merge key to combine properties from multiple files. All loaded properties are added to the sorting extractor as unit properties, with some renamed for SpikeInterface conventions: 'group' becomes 'quality', 'cluster_id' - becomes 'original_cluster_id'. These properties can be accessed via ``sorting.get_property()`` + becomes 'original_cluster_id'. These properties can be accessed via ``sorting.get_property()`` function. 
""" From 73bef41a69f1a2b3fcaa22685dd649196179ec8c Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 12 Nov 2025 10:44:14 +0100 Subject: [PATCH 14/14] WIP --- .../sortingcomponents/peak_detection.py | 1338 ----------------- 1 file changed, 1338 deletions(-) delete mode 100644 src/spikeinterface/sortingcomponents/peak_detection.py diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py deleted file mode 100644 index b2b5bcc154..0000000000 --- a/src/spikeinterface/sortingcomponents/peak_detection.py +++ /dev/null @@ -1,1338 +0,0 @@ -"""Sorting components: peak detection.""" - -from __future__ import annotations - - -import copy -from typing import Tuple, List, Optional -import importlib.util - -import numpy as np - -from spikeinterface.core.job_tools import ( - _shared_job_kwargs_doc, - split_job_kwargs, -) -from spikeinterface.core.recording_tools import get_noise_levels, get_channel_distances, get_random_data_chunks - -from spikeinterface.core.baserecording import BaseRecording -from spikeinterface.core.node_pipeline import ( - PeakDetector, - WaveformsNode, - ExtractSparseWaveforms, - run_node_pipeline, - base_peak_dtype, -) - -from spikeinterface.postprocessing.localization_tools import get_convolution_weights - -from .tools import make_multi_method_doc - -numba_spec = importlib.util.find_spec("numba") -if numba_spec is not None: - HAVE_NUMBA = True -else: - HAVE_NUMBA = False - -torch_spec = importlib.util.find_spec("torch") -if torch_spec is not None: - torch_nn_functional_spec = importlib.util.find_spec("torch.nn") - if torch_nn_functional_spec is not None: - HAVE_TORCH = True - else: - HAVE_TORCH = False -else: - HAVE_TORCH = False - -""" -TODO: - * remove the wrapper class and move all implementation to instance -""" - - -def detect_peaks( - recording, - method="locally_exclusive", - pipeline_nodes=None, - gather_mode="memory", - folder=None, - names=None, - skip_after_n_peaks=None, - recording_slices=None, - **kwargs, -): - """Peak detection based on threshold crossing in term of k x MAD. - - In "by_channel" : peak are detected in each channel independently - In "locally_exclusive" : a single best peak is taken from a set of neighboring channels - - Parameters - ---------- - recording : RecordingExtractor - The recording extractor object. - pipeline_nodes : None or list[PipelineNode] - Optional additional PipelineNode need to computed just after detection time. - This avoid reading the recording multiple times. - gather_mode : str - How to gather the results: - * "memory": results are returned as in-memory numpy arrays - * "npy": results are stored to .npy files in `folder` - - folder : str or Path - If gather_mode is "npy", the folder where the files are created. - names : list - List of strings with file stems associated with returns. - skip_after_n_peaks : None | int - Skip the computation after n_peaks. - This is not an exact because internally this skip is done per worker in average. - recording_slices : None | list[tuple] - Optionaly give a list of slices to run the pipeline only on some chunks of the recording. - It must be a list of (segment_index, frame_start, frame_stop). - If None (default), the function iterates over the entire duration of the recording. - - {method_doc} - {job_doc} - - Returns - ------- - peaks: array - Detected peaks. - - Notes - ----- - This peak detection ported from tridesclous into spikeinterface. 
- - """ - - assert method in detect_peak_methods - - method_class = detect_peak_methods[method] - - method_kwargs, job_kwargs = split_job_kwargs(kwargs) - job_kwargs["mp_context"] = method_class.preferred_mp_context - - if method_class.need_noise_levels: - random_chunk_kwargs = method_kwargs.pop("random_chunk_kwargs", {}) - method_kwargs["noise_levels"] = get_noise_levels( - recording, return_in_uV=False, **random_chunk_kwargs, **job_kwargs - ) - - node0 = method_class(recording, **method_kwargs) - nodes = [node0] - - job_name = f"detect peaks using {method}" - if pipeline_nodes is None: - squeeze_output = True - else: - squeeze_output = False - if len(pipeline_nodes) == 1: - plural = "" - else: - plural = "s" - job_name += f" + {len(pipeline_nodes)} node{plural}" - - # because node are modified inplace (insert parent) they need to copy incase - # the same pipeline is run several times - pipeline_nodes = copy.deepcopy(pipeline_nodes) - for node in pipeline_nodes: - if node.parents is None: - node.parents = [node0] - else: - node.parents = [node0] + node.parents - nodes.append(node) - - outs = run_node_pipeline( - recording, - nodes, - job_kwargs, - job_name=job_name, - gather_mode=gather_mode, - squeeze_output=squeeze_output, - folder=folder, - names=names, - skip_after_n_peaks=skip_after_n_peaks, - recording_slices=recording_slices, - ) - return outs - - -expanded_base_peak_dtype = np.dtype(base_peak_dtype + [("iteration", "int8")]) - - -class IterativePeakDetector(PeakDetector): - """ - A class that iteratively detects peaks in the recording by applying a peak detector, waveform extraction, - and waveform denoising node. The algorithm runs for a specified number of iterations or until no peaks are found. - """ - - def __init__( - self, - recording: BaseRecording, - peak_detector_node: PeakDetector, - waveform_extraction_node: WaveformsNode, - waveform_denoising_node, - num_iterations: int = 2, - return_output: bool = True, - tresholds: Optional[List[float]] = None, - ): - """ - Initialize the iterative peak detector. - - Parameters - ---------- - recording : BaseRecording - The recording to process - peak_detector_node : PeakDetector - The peak detector node to use - waveform_extraction_node : WaveformsNode - The waveform extraction node to use - waveform_denoising_node - The waveform denoising node to use - num_iterations : int, default: 2 - The number of iterations to run the algorithm - return_output : bool, default: True - Whether to return the output of the algorithm - """ - PeakDetector.__init__(self, recording, return_output=return_output) - self.peak_detector_node = peak_detector_node - self.waveform_extraction_node = waveform_extraction_node - self.waveform_denoising_node = waveform_denoising_node - self.num_iterations = num_iterations - self.tresholds = tresholds - - def get_trace_margin(self) -> int: - """ - Calculate the maximum trace margin from the internal pipeline. - Using the strategy as use by the Node pipeline - - - Returns - ------- - int - The maximum trace margin. - """ - internal_pipeline = (self.peak_detector_node, self.waveform_extraction_node, self.waveform_denoising_node) - pipeline_margin = (node.get_trace_margin() for node in internal_pipeline if hasattr(node, "get_trace_margin")) - return max(pipeline_margin) - - def compute(self, traces_chunk, start_frame, end_frame, segment_index, max_margin) -> Tuple[np.ndarray, np.ndarray]: - """ - Perform the iterative peak detection, waveform extraction, and denoising. 
- - Parameters - ---------- - traces_chunk : array-like - The chunk of traces to process. - start_frame : int - The starting frame for the chunk. - end_frame : int - The ending frame for the chunk. - segment_index : int - The segment index. - max_margin : int - The maximum margin for the traces. - - Returns - ------- - tuple of ndarray - A tuple containing a single ndarray with the detected peaks. - """ - - traces_chunk = np.array(traces_chunk, copy=True, dtype="float32") - local_peaks_list = [] - all_waveforms = [] - - for iteration in range(self.num_iterations): - # Hack because of lack of either attribute or named references - # I welcome suggestions on how to improve this but I think it is an architectural issue - if self.tresholds is not None: - old_args = self.peak_detector_node.args - old_detect_treshold = self.peak_detector_node.params["detect_threshold"] - old_abs_treshold = old_args[1] - new_abs_treshold = old_abs_treshold * self.tresholds[iteration] / old_detect_treshold - - new_args = tuple(val if index != 1 else new_abs_treshold for index, val in enumerate(old_args)) - self.peak_detector_node.args = new_args - - (local_peaks,) = self.peak_detector_node.compute( - traces=traces_chunk, - start_frame=start_frame, - end_frame=end_frame, - segment_index=segment_index, - max_margin=max_margin, - ) - - local_peaks = self.add_iteration_to_peaks_dtype(local_peaks=local_peaks, iteration=iteration) - local_peaks_list.append(local_peaks) - - # End algorith if no peak is found - if local_peaks.size == 0: - break - - waveforms = self.waveform_extraction_node.compute(traces=traces_chunk, peaks=local_peaks) - denoised_waveforms = self.waveform_denoising_node.compute( - traces=traces_chunk, peaks=local_peaks, waveforms=waveforms - ) - - self.substract_waveforms_from_traces( - local_peaks=local_peaks, - traces_chunk=traces_chunk, - waveforms=denoised_waveforms, - ) - - all_waveforms.append(waveforms) - all_local_peaks = np.concatenate(local_peaks_list, axis=0) - all_waveforms = np.concatenate(all_waveforms, axis=0) if len(all_waveforms) != 0 else np.empty((0, 0, 0)) - - # Sort as iterative method implies peaks might not be discovered ordered in time - sorting_indices = np.argsort(all_local_peaks["sample_index"]) - all_local_peaks = all_local_peaks[sorting_indices] - all_waveforms = all_waveforms[sorting_indices] - - return (all_local_peaks, all_waveforms) - - def substract_waveforms_from_traces( - self, - local_peaks: np.ndarray, - traces_chunk: np.ndarray, - waveforms: np.ndarray, - ): - """ - Substract inplace the cleaned waveforms from the traces_chunk. - - Parameters - ---------- - sample_indices : ndarray - The indices where the waveforms are maximum (peaks["sample_index"]). - traces_chunk : ndarray - A chunk of the traces. - waveforms : ndarray - The waveforms extracted from the traces. 
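        Notes
        -----
        The subtraction happens in place on `traces_chunk`; with sparse waveform
        extraction, only the channels selected by `neighbours_mask` for each
        peak's channel are modified.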
- """ - - nbefore = self.waveform_extraction_node.nbefore - nafter = self.waveform_extraction_node.nafter - if isinstance(self.waveform_extraction_node, ExtractSparseWaveforms): - neighbours_mask = self.waveform_extraction_node.neighbours_mask - else: - neighbours_mask = None - - for peak_index, peak in enumerate(local_peaks): - center_sample = peak["sample_index"] - first_sample = center_sample - nbefore - last_sample = center_sample + nafter - if neighbours_mask is None: - traces_chunk[first_sample:last_sample, :] -= waveforms[peak_index, :, :] - else: - (channels,) = np.nonzero(neighbours_mask[peak["channel_index"]]) - traces_chunk[first_sample:last_sample, channels] -= waveforms[peak_index, :, : len(channels)] - - def add_iteration_to_peaks_dtype(self, local_peaks, iteration) -> np.ndarray: - """ - Add the iteration number to the peaks dtype. - - Parameters - ---------- - local_peaks : ndarray - The array of local peaks. - iteration : int - The iteration number. - - Returns - ------- - ndarray - An array of local peaks with the iteration number added. - """ - # Expand dtype to also contain an iteration field - local_peaks_expanded = np.zeros_like(local_peaks, dtype=expanded_base_peak_dtype) - fields_in_base_type = np.dtype(base_peak_dtype).names - for field in fields_in_base_type: - local_peaks_expanded[field] = local_peaks[field] - local_peaks_expanded["iteration"] = iteration - - return local_peaks_expanded - - -class PeakDetectorWrapper(PeakDetector): - # transitory class to maintain instance based and class method based - # TODO later when in main: refactor in every old detector class: - # * check_params - # * get_method_margin - # and move the logic in the init - # but keep the class method "detect_peaks()" because it is convinient in template matching - def __init__(self, recording, **params): - PeakDetector.__init__(self, recording, return_output=True) - - self.params = params - self.args = self.check_params(recording, **params) - - def get_trace_margin(self): - return self.get_method_margin(*self.args) - - def compute(self, traces, start_frame, end_frame, segment_index, max_margin): - peak_sample_ind, peak_chan_ind = self.detect_peaks(traces, *self.args) - if peak_sample_ind.size == 0 or peak_chan_ind.size == 0: - return (np.zeros(0, dtype=base_peak_dtype),) - - peak_amplitude = traces[peak_sample_ind, peak_chan_ind] - local_peaks = np.zeros(peak_sample_ind.size, dtype=base_peak_dtype) - local_peaks["sample_index"] = peak_sample_ind - local_peaks["channel_index"] = peak_chan_ind - local_peaks["amplitude"] = peak_amplitude - local_peaks["segment_index"] = segment_index - - # return is always a tuple - return (local_peaks,) - - -class DetectPeakByChannel(PeakDetectorWrapper): - """Detect peaks using the "by channel" method.""" - - name = "by_channel" - engine = "numpy" - need_noise_levels = True - preferred_mp_context = None - params_doc = """ - peak_sign: "neg" | "pos" | "both", default: "neg" - Sign of the peak - detect_threshold: float, default: 5 - Threshold, in median absolute deviations (MAD), to use to detect peaks - exclude_sweep_ms: float, default: 0.1 - Time, in ms, during which the peak is isolated. 
Exclusive param with exclude_sweep_size - For example, if `exclude_sweep_ms` is 0.1, a peak is detected if a sample crosses the threshold, - and no larger peaks are located during the 0.1ms preceding and following the peak - noise_levels: array or None, default: None - Estimated noise levels to use, if already computed - If not provide then it is estimated from a random snippet of the data - random_chunk_kwargs: dict, default: dict() - A dict that contain option to randomize chunk for get_noise_levels(). - Only used if noise_levels is None - """ - - @classmethod - def check_params( - cls, - recording, - peak_sign="neg", - detect_threshold=5, - exclude_sweep_ms=0.1, - noise_levels=None, - random_chunk_kwargs={}, - ): - assert peak_sign in ("both", "neg", "pos") - - if noise_levels is None: - noise_levels = get_noise_levels(recording, return_in_uV=False, **random_chunk_kwargs) - abs_thresholds = noise_levels * detect_threshold - exclude_sweep_size = int(exclude_sweep_ms * recording.get_sampling_frequency() / 1000.0) - - return (peak_sign, abs_thresholds, exclude_sweep_size) - - @classmethod - def get_method_margin(cls, *args): - exclude_sweep_size = args[2] - return exclude_sweep_size - - @classmethod - def detect_peaks(cls, traces, peak_sign, abs_thresholds, exclude_sweep_size): - traces_center = traces[exclude_sweep_size:-exclude_sweep_size, :] - length = traces_center.shape[0] - - if peak_sign in ("pos", "both"): - peak_mask = traces_center > abs_thresholds[None, :] - for i in range(exclude_sweep_size): - peak_mask &= traces_center > traces[i : i + length, :] - peak_mask &= ( - traces_center >= traces[exclude_sweep_size + i + 1 : exclude_sweep_size + i + 1 + length, :] - ) - - if peak_sign in ("neg", "both"): - if peak_sign == "both": - peak_mask_pos = peak_mask.copy() - - peak_mask = traces_center < -abs_thresholds[None, :] - for i in range(exclude_sweep_size): - peak_mask &= traces_center < traces[i : i + length, :] - peak_mask &= ( - traces_center <= traces[exclude_sweep_size + i + 1 : exclude_sweep_size + i + 1 + length, :] - ) - - if peak_sign == "both": - peak_mask = peak_mask | peak_mask_pos - - # find peaks - peak_sample_ind, peak_chan_ind = np.nonzero(peak_mask) - # correct for time shift - peak_sample_ind += exclude_sweep_size - - return peak_sample_ind, peak_chan_ind - - -class DetectPeakByChannelTorch(PeakDetectorWrapper): - """Detect peaks using the "by channel" method with pytorch.""" - - name = "by_channel_torch" - engine = "torch" - need_noise_levels = True - preferred_mp_context = "spawn" - params_doc = """ - peak_sign: "neg" | "pos" | "both", default: "neg" - Sign of the peak - detect_threshold: float, default: 5 - Threshold, in median absolute deviations (MAD), to use to detect peaks - exclude_sweep_ms: float, default: 0.1 - Time, in ms, during which the peak is isolated. Exclusive param with exclude_sweep_size - For example, if `exclude_sweep_ms` is 0.1, a peak is detected if a sample crosses the threshold, - and no larger peaks are located during the 0.1ms preceding and following the peak - noise_levels: array or None, default: None - Estimated noise levels to use, if already computed. - If not provide then it is estimated from a random snippet of the data - device : str or None, default: None - "cpu", "cuda", or None. 
If None and cuda is available, "cuda" is selected - return_tensor : bool, default: False - If True, the output is returned as a tensor, otherwise as a numpy array - random_chunk_kwargs: dict, default: dict() - A dict that contain option to randomize chunk for get_noise_levels(). - Only used if noise_levels is None. - """ - - @classmethod - def check_params( - cls, - recording, - peak_sign="neg", - detect_threshold=5, - exclude_sweep_ms=0.1, - noise_levels=None, - device=None, - return_tensor=False, - random_chunk_kwargs={}, - ): - - if not HAVE_TORCH: - raise ModuleNotFoundError('"by_channel_torch" needs torch which is not installed') - - import torch.cuda - - assert peak_sign in ("both", "neg", "pos") - if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" - if noise_levels is None: - noise_levels = get_noise_levels(recording, return_in_uV=False, **random_chunk_kwargs) - abs_thresholds = noise_levels * detect_threshold - exclude_sweep_size = int(exclude_sweep_ms * recording.get_sampling_frequency() / 1000.0) - - return (peak_sign, abs_thresholds, exclude_sweep_size, device, return_tensor) - - @classmethod - def get_method_margin(cls, *args): - exclude_sweep_size = args[2] - return exclude_sweep_size - - @classmethod - def detect_peaks(cls, traces, peak_sign, abs_thresholds, exclude_sweep_size, device, return_tensor): - sample_inds, chan_inds = _torch_detect_peaks( - traces, peak_sign, abs_thresholds, exclude_sweep_size, None, device - ) - if not return_tensor: - sample_inds = np.array(sample_inds.cpu()) - chan_inds = np.array(chan_inds.cpu()) - return sample_inds, chan_inds - - -class DetectPeakLocallyExclusive(PeakDetectorWrapper): - """Detect peaks using the "locally exclusive" method.""" - - name = "locally_exclusive" - engine = "numba" - need_noise_levels = True - preferred_mp_context = None - params_doc = ( - DetectPeakByChannel.params_doc - + """ - radius_um: float - The radius to use to select neighbour channels for locally exclusive detection. 
- """ - ) - - @classmethod - def check_params( - cls, - recording, - peak_sign="neg", - detect_threshold=5, - exclude_sweep_ms=0.1, - radius_um=50, - noise_levels=None, - random_chunk_kwargs={}, - ): - if not HAVE_NUMBA: - raise ModuleNotFoundError('"locally_exclusive" needs numba which is not installed') - - # args = DetectPeakByChannel.check_params( - # recording, - # peak_sign=peak_sign, - # detect_threshold=detect_threshold, - # exclude_sweep_ms=exclude_sweep_ms, - # noise_levels=noise_levels, - # random_chunk_kwargs=random_chunk_kwargs, - # ) - - assert peak_sign in ("both", "neg", "pos") - if noise_levels is None: - noise_levels = get_noise_levels(recording, return_in_uV=False, **random_chunk_kwargs) - abs_thresholds = noise_levels * detect_threshold - exclude_sweep_size = int(exclude_sweep_ms * recording.get_sampling_frequency() / 1000.0) - - # if remove_median: - - # chunks = get_random_data_chunks(recording, return_in_uV=False, concatenated=True, **random_chunk_kwargs) - # medians = np.median(chunks, axis=0) - # medians = medians[None, :] - # print('medians', medians, noise_levels) - # else: - # medians = None - - channel_distance = get_channel_distances(recording) - neighbours_mask = channel_distance <= radius_um - return (peak_sign, abs_thresholds, exclude_sweep_size, neighbours_mask) - - @classmethod - def get_method_margin(cls, *args): - exclude_sweep_size = args[2] - return exclude_sweep_size - - @classmethod - def detect_peaks(cls, traces, peak_sign, abs_thresholds, exclude_sweep_size, neighbours_mask): - assert HAVE_NUMBA, "You need to install numba" - - # if medians is not None: - # traces = traces - medians - - traces_center = traces[exclude_sweep_size:-exclude_sweep_size, :] - - if peak_sign in ("pos", "both"): - peak_mask = traces_center > abs_thresholds[None, :] - peak_mask = _numba_detect_peak_pos( - traces, traces_center, peak_mask, exclude_sweep_size, abs_thresholds, peak_sign, neighbours_mask - ) - - if peak_sign in ("neg", "both"): - if peak_sign == "both": - peak_mask_pos = peak_mask.copy() - - peak_mask = traces_center < -abs_thresholds[None, :] - peak_mask = _numba_detect_peak_neg( - traces, traces_center, peak_mask, exclude_sweep_size, abs_thresholds, peak_sign, neighbours_mask - ) - - if peak_sign == "both": - peak_mask = peak_mask | peak_mask_pos - - # Find peaks and correct for time shift - peak_sample_ind, peak_chan_ind = np.nonzero(peak_mask) - peak_sample_ind += exclude_sweep_size - - return peak_sample_ind, peak_chan_ind - - -class DetectPeakMatchedFiltering(PeakDetector): - """Detect peaks using the 'matched_filtering' method.""" - - name = "matched_filtering" - engine = "numba" - need_noise_levels = False - preferred_mp_context = None - params_doc = ( - DetectPeakByChannel.params_doc - + """ - radius_um : float - The radius to use to select neighbour channels for locally exclusive detection. - prototype : array - The canonical waveform of action potentials - ms_before : float - The time in ms before the maximial value of the absolute prototype - weight_method : dict - Parameter that should be provided to the get_convolution_weights() function - in order to know how to estimate the positions. 
One argument is mode that could - be either gaussian_2d (KS like) or exponential_3d (default) - """ - ) - - def __init__( - self, - recording, - prototype, - ms_before, - peak_sign="neg", - detect_threshold=5, - exclude_sweep_ms=0.1, - radius_um=50, - noise_levels=None, - random_chunk_kwargs={"num_chunks_per_segment": 5}, - weight_method={}, - ): - PeakDetector.__init__(self, recording, return_output=True) - from scipy.sparse import csr_matrix - - if not HAVE_NUMBA: - raise ModuleNotFoundError('matched_filtering" needs numba which is not installed') - - self.exclude_sweep_size = int(exclude_sweep_ms * recording.get_sampling_frequency() / 1000.0) - channel_distance = get_channel_distances(recording) - self.neighbours_mask = channel_distance <= radius_um - - self.conv_margin = prototype.shape[0] - - assert peak_sign in ("both", "neg", "pos") - self.nbefore = int(ms_before * recording.sampling_frequency / 1000) - if peak_sign == "neg": - assert prototype[self.nbefore] < 0, "Prototype should have a negative peak" - peak_sign = "pos" - elif peak_sign == "pos": - assert prototype[self.nbefore] > 0, "Prototype should have a positive peak" - - self.peak_sign = peak_sign - self.prototype = np.flip(prototype) / np.linalg.norm(prototype) - - contact_locations = recording.get_channel_locations() - dist = np.linalg.norm(contact_locations[:, np.newaxis] - contact_locations[np.newaxis, :], axis=2) - self.weights, self.z_factors = get_convolution_weights(dist, **weight_method) - self.num_z_factors = len(self.z_factors) - self.num_channels = recording.get_num_channels() - self.num_templates = self.num_channels - if peak_sign == "both": - self.weights = np.hstack((self.weights, self.weights)) - self.weights[:, self.num_templates :, :] *= -1 - self.num_templates *= 2 - - self.weights = self.weights.reshape(self.num_templates * self.num_z_factors, -1) - self.weights = csr_matrix(self.weights) - random_data = get_random_data_chunks(recording, return_in_uV=False, **random_chunk_kwargs) - conv_random_data = self.get_convolved_traces(random_data) - medians = np.median(conv_random_data, axis=1) - self.medians = medians[:, None] - noise_levels = np.median(np.abs(conv_random_data - self.medians), axis=1) / 0.6744897501960817 - self.abs_thresholds = noise_levels * detect_threshold - self._dtype = np.dtype(base_peak_dtype + [("z", "float32")]) - - def get_dtype(self): - return self._dtype - - def get_trace_margin(self): - return self.exclude_sweep_size + self.conv_margin - - def compute(self, traces, start_frame, end_frame, segment_index, max_margin): - - assert HAVE_NUMBA, "You need to install numba" - conv_traces = self.get_convolved_traces(traces) - # conv_traces -= self.medians - conv_traces /= self.abs_thresholds[:, None] - conv_traces = conv_traces[:, self.conv_margin : -self.conv_margin] - traces_center = conv_traces[:, self.exclude_sweep_size : -self.exclude_sweep_size] - - traces_center = traces_center.reshape(self.num_z_factors, self.num_templates, traces_center.shape[1]) - conv_traces = conv_traces.reshape(self.num_z_factors, self.num_templates, conv_traces.shape[1]) - peak_mask = traces_center > 1 - - peak_mask = _numba_detect_peak_matched_filtering( - conv_traces, - traces_center, - peak_mask, - self.exclude_sweep_size, - self.abs_thresholds, - self.peak_sign, - self.neighbours_mask, - self.num_channels, - ) - - # Find peaks and correct for time shift - z_ind, peak_chan_ind, peak_sample_ind = np.nonzero(peak_mask) - if self.peak_sign == "both": - peak_chan_ind = peak_chan_ind % self.num_channels - - # If 
we want to estimate z - # peak_chan_ind = peak_chan_ind % num_channels - # z = np.zeros(len(peak_sample_ind), dtype=np.float32) - # for count in range(len(peak_chan_ind)): - # channel = peak_chan_ind[count] - # peak = peak_sample_ind[count] - # data = traces[channel::num_channels, peak] - # z[count] = np.dot(data, z_factors)/data.sum() - - if peak_sample_ind.size == 0 or peak_chan_ind.size == 0: - return (np.zeros(0, dtype=self._dtype),) - - peak_sample_ind += self.exclude_sweep_size + self.conv_margin + self.nbefore - peak_amplitude = traces[peak_sample_ind, peak_chan_ind] - - local_peaks = np.zeros(peak_sample_ind.size, dtype=self._dtype) - local_peaks["sample_index"] = peak_sample_ind - local_peaks["channel_index"] = peak_chan_ind - local_peaks["amplitude"] = peak_amplitude - local_peaks["segment_index"] = segment_index - local_peaks["z"] = z_ind - - # return is always a tuple - return (local_peaks,) - - def get_convolved_traces(self, traces): - from scipy.signal import oaconvolve - - tmp = oaconvolve(self.prototype[None, :], traces.T, axes=1, mode="valid") - scalar_products = self.weights.dot(tmp) - return scalar_products - - -class DetectPeakLocallyExclusiveTorch(PeakDetectorWrapper): - """Detect peaks using the "locally exclusive" method with pytorch.""" - - name = "locally_exclusive_torch" - engine = "torch" - need_noise_levels = True - preferred_mp_context = "spawn" - params_doc = ( - DetectPeakByChannel.params_doc - + """ - radius_um: float - The radius to use to select neighbour channels for locally exclusive detection. - """ - ) - - @classmethod - def check_params( - cls, - recording, - peak_sign="neg", - detect_threshold=5, - exclude_sweep_ms=0.1, - noise_levels=None, - device=None, - radius_um=50, - return_tensor=False, - random_chunk_kwargs={}, - ): - if not HAVE_TORCH: - raise ModuleNotFoundError('"by_channel_torch" needs torch which is not installed') - args = DetectPeakByChannelTorch.check_params( - recording, - peak_sign=peak_sign, - detect_threshold=detect_threshold, - exclude_sweep_ms=exclude_sweep_ms, - noise_levels=noise_levels, - device=device, - return_tensor=return_tensor, - random_chunk_kwargs=random_chunk_kwargs, - ) - - channel_distance = get_channel_distances(recording) - neighbour_indices_by_chan = [] - num_channels = recording.get_num_channels() - for chan in range(num_channels): - neighbour_indices_by_chan.append(np.nonzero(channel_distance[chan] <= radius_um)[0]) - max_neighbs = np.max([len(neigh) for neigh in neighbour_indices_by_chan]) - neighbours_idxs = num_channels * np.ones((num_channels, max_neighbs), dtype=int) - for i, neigh in enumerate(neighbour_indices_by_chan): - neighbours_idxs[i, : len(neigh)] = neigh - return args + (neighbours_idxs,) - - @classmethod - def get_method_margin(cls, *args): - exclude_sweep_size = args[2] - return exclude_sweep_size - - @classmethod - def detect_peaks(cls, traces, peak_sign, abs_thresholds, exclude_sweep_size, device, return_tensor, neighbor_idxs): - sample_inds, chan_inds = _torch_detect_peaks( - traces, peak_sign, abs_thresholds, exclude_sweep_size, neighbor_idxs, device - ) - if not return_tensor and isinstance(sample_inds, torch.Tensor) and isinstance(chan_inds, torch.Tensor): - sample_inds = np.array(sample_inds.cpu()) - chan_inds = np.array(chan_inds.cpu()) - return sample_inds, chan_inds - - -if HAVE_NUMBA: - import numba - - @numba.jit(nopython=True, parallel=False) - def _numba_detect_peak_pos( - traces, traces_center, peak_mask, exclude_sweep_size, abs_thresholds, peak_sign, neighbours_mask - ): - 
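        # Local-exclusivity scan: a sample flagged in peak_mask survives only if,
        # over the exclusion sweep, it is >= its simultaneous neighbours (channels
        # within radius_um, itself included), strictly greater than the preceding
        # exclude_sweep_size samples and >= the following ones on every neighbour.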
num_chans = traces_center.shape[1] - for chan_ind in range(num_chans): - for s in range(peak_mask.shape[0]): - if not peak_mask[s, chan_ind]: - continue - for neighbour in range(num_chans): - if not neighbours_mask[chan_ind, neighbour]: - continue - for i in range(exclude_sweep_size): - if chan_ind != neighbour: - peak_mask[s, chan_ind] &= traces_center[s, chan_ind] >= traces_center[s, neighbour] - peak_mask[s, chan_ind] &= traces_center[s, chan_ind] > traces[s + i, neighbour] - peak_mask[s, chan_ind] &= ( - traces_center[s, chan_ind] >= traces[exclude_sweep_size + s + i + 1, neighbour] - ) - if not peak_mask[s, chan_ind]: - break - if not peak_mask[s, chan_ind]: - break - return peak_mask - - @numba.jit(nopython=True, parallel=False) - def _numba_detect_peak_neg( - traces, traces_center, peak_mask, exclude_sweep_size, abs_thresholds, peak_sign, neighbours_mask - ): - num_chans = traces_center.shape[1] - for chan_ind in range(num_chans): - for s in range(peak_mask.shape[0]): - if not peak_mask[s, chan_ind]: - continue - for neighbour in range(num_chans): - if not neighbours_mask[chan_ind, neighbour]: - continue - for i in range(exclude_sweep_size): - if chan_ind != neighbour: - peak_mask[s, chan_ind] &= traces_center[s, chan_ind] <= traces_center[s, neighbour] - peak_mask[s, chan_ind] &= traces_center[s, chan_ind] < traces[s + i, neighbour] - peak_mask[s, chan_ind] &= ( - traces_center[s, chan_ind] <= traces[exclude_sweep_size + s + i + 1, neighbour] - ) - if not peak_mask[s, chan_ind]: - break - if not peak_mask[s, chan_ind]: - break - return peak_mask - - @numba.jit(nopython=True, parallel=False) - def _numba_detect_peak_matched_filtering( - traces, traces_center, peak_mask, exclude_sweep_size, abs_thresholds, peak_sign, neighbours_mask, num_channels - ): - num_z = traces_center.shape[0] - num_templates = traces_center.shape[1] - for template_ind in range(num_templates): - for z in range(num_z): - for s in range(peak_mask.shape[2]): - if not peak_mask[z, template_ind, s]: - continue - for neighbour in range(num_templates): - for j in range(num_z): - if not neighbours_mask[template_ind % num_channels, neighbour % num_channels]: - continue - for i in range(exclude_sweep_size): - if template_ind >= neighbour and z >= j: - peak_mask[z, template_ind, s] &= ( - traces_center[z, template_ind, s] >= traces_center[j, neighbour, s] - ) - else: - peak_mask[z, template_ind, s] &= ( - traces_center[z, template_ind, s] > traces_center[j, neighbour, s] - ) - peak_mask[z, template_ind, s] &= ( - traces_center[z, template_ind, s] > traces[j, neighbour, s + i] - ) - peak_mask[z, template_ind, s] &= ( - traces_center[z, template_ind, s] - >= traces[j, neighbour, exclude_sweep_size + s + i + 1] - ) - if not peak_mask[z, template_ind, s]: - break - if not peak_mask[z, template_ind, s]: - break - if not peak_mask[z, template_ind, s]: - break - - return peak_mask - - -if HAVE_TORCH: - import torch - import torch.nn.functional as F - - @torch.no_grad() - def _torch_detect_peaks(traces, peak_sign, abs_thresholds, exclude_sweep_size=5, neighbours_mask=None, device=None): - """ - Voltage thresholding detection and deduplication with torch. 
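        Candidates are the samples that win a max-pool over a
        (2 * exclude_sweep_size + 1)-sample window applied to threshold-normalized
        traces; winners above threshold are then deduplicated in time and across
        neighbouring channels (see the deduplication block below).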
- Implementation from Charlie Windolf: - https://github.com/cwindolf/spike-psvae/blob/ba0a985a075776af892f09adfd453b8d9db168b9/spike_psvae/detect.py#L350 - Parameters - ---------- - traces : np.array - Chunk of traces - abs_thresholds : np.array - Absolute thresholds by channel - peak_sign : "neg" | "pos" | "both", default: "neg" - The sign of the peak to detect peaks - exclude_sweep_size : int, default: 5 - How many temporal neighbors to compare with during argrelmin - Called `order` in original the implementation. The `max_window` parameter, used - for deduplication, is now set as 2* exclude_sweep_size - neighbor_mask : np.array or None, default: None - If given, a matrix with shape (num_channels, num_neighbours) with - neighbour indices for each channel. The matrix needs to be rectangular and - padded to num_channels - device : str or None, default: None - "cpu", "cuda", or None. If None and cuda is available, "cuda" is selected - - Returns - ------- - sample_inds, chan_inds - 1D numpy arrays - """ - # TODO handle GPU-memory at chunk executor level - # for now we keep the same batching mechanism from spike_psvae - # this will be adjusted based on: num jobs, num gpus, num neighbors - MAXCOPY = 8 - - # center traces by excluding the sweep size - traces = traces[exclude_sweep_size:-exclude_sweep_size, :] - num_samples, num_channels = traces.shape - dtype = torch.float32 - empty_return_value = (torch.tensor([], dtype=dtype), torch.tensor([], dtype=dtype)) - - # The function uses maxpooling to look for maximum - if peak_sign == "neg": - traces = -traces - elif peak_sign == "pos": - traces = traces - elif peak_sign == "both": - traces = np.abs(traces) - - traces_tensor = torch.as_tensor(traces, device=device, dtype=torch.float) - thresholds_torch = torch.as_tensor(abs_thresholds, device=device, dtype=torch.float) - normalized_traces = traces_tensor / thresholds_torch - - max_amplitudes, indices = F.max_pool2d_with_indices( - input=normalized_traces[None, None], - kernel_size=[2 * exclude_sweep_size + 1, 1], - stride=1, - padding=[exclude_sweep_size, 0], - ) - max_amplitudes = max_amplitudes[0, 0] - indices = indices[0, 0] - # torch `indices` gives loc of argmax at each position - # find those which actually *were* the max - unique_indices = indices.unique() - window_max_indices = unique_indices[indices.view(-1)[unique_indices] == unique_indices] - - # voltage threshold - max_amplitudes_at_indices = max_amplitudes.view(-1)[window_max_indices] - crossings = torch.nonzero(max_amplitudes_at_indices > 1).squeeze() - if not crossings.numel(): - return empty_return_value - - # -- unravel the spike index - # (right now the indices are into flattened recording) - peak_indices = window_max_indices[crossings] - sample_indices = torch.div(peak_indices, num_channels, rounding_mode="floor") - channel_indices = peak_indices % num_channels - amplitudes = max_amplitudes_at_indices[crossings] - - # we need this due to the padding in convolution - valid_indices = torch.nonzero((0 < sample_indices) & (sample_indices < traces.shape[0] - 1)).squeeze() - if not valid_indices.numel(): - return empty_return_value - sample_indices = sample_indices[valid_indices] - channel_indices = channel_indices[valid_indices] - amplitudes = amplitudes[valid_indices] - - # -- deduplication - # We deduplicate if the channel index is provided. 
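        # Deduplication sketch: peak amplitudes are written into a sparse
        # (sample, channel) map, max-pooled over a +/- (2 * exclude_sweep_size)
        # temporal window, then max-pooled spatially over each channel's padded
        # neighbour indices; a peak is kept only when its own amplitude matches
        # the pooled maximum (within 1e-8).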
-        if neighbours_mask is not None:
-            neighbours_mask = torch.tensor(neighbours_mask, device=device, dtype=torch.long)
-
-            # -- temporal max pool
-            # still not sure why we can't just use `max_amplitudes` instead of making
-            # this sparsely populated array, but it leads to a different result.
-            max_amplitudes[:] = 0
-            max_amplitudes[sample_indices, channel_indices] = amplitudes
-            max_window = 2 * exclude_sweep_size
-            max_amplitudes = F.max_pool2d(
-                max_amplitudes[None, None],
-                kernel_size=[2 * max_window + 1, 1],
-                stride=1,
-                padding=[max_window, 0],
-            )[0, 0]
-
-            # -- spatial max pool with channel index
-            # batch size heuristic, see __doc__
-            max_neighbs = neighbours_mask.shape[1]
-            batch_size = int(np.ceil(num_samples / (max_neighbs / MAXCOPY)))
-            for bs in range(0, num_samples, batch_size):
-                be = min(num_samples, bs + batch_size)
-                max_amplitudes[bs:be] = torch.max(F.pad(max_amplitudes[bs:be], (0, 1))[:, neighbours_mask], 2)[0]
-
-            # -- deduplication
-            deduplication_indices = torch.nonzero(
-                amplitudes >= max_amplitudes[sample_indices, channel_indices] - 1e-8
-            ).squeeze()
-            if not deduplication_indices.numel():
-                return empty_return_value
-            sample_indices = sample_indices[deduplication_indices] + exclude_sweep_size
-            channel_indices = channel_indices[deduplication_indices]
-            amplitudes = amplitudes[deduplication_indices]
-
-        return sample_indices, channel_indices
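One torch idiom above deserves a gloss: `F.max_pool2d_with_indices` returns, at every position, the flattened index of that window's maximum, so a sample is a window maximum exactly when the index stored at its own position points back at itself. A toy, single-channel illustration (not part of the patch; requires torch):

    import torch
    import torch.nn.functional as F

    x = torch.tensor([0.0, 3.0, 1.0, 0.0, 5.0, 2.0])[None, None, :, None]  # (1, 1, num_samples, 1)
    _, indices = F.max_pool2d_with_indices(x, kernel_size=[3, 1], stride=1, padding=[1, 0])
    indices = indices[0, 0, :, 0]  # flat argmax of each 3-sample window
    positions = torch.arange(x.shape[2])
    # keep the positions whose window argmax is themselves
    print(positions[indices == positions])  # tensor([1, 4])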
-
-
-class DetectPeakLocallyExclusiveOpenCL(PeakDetectorWrapper):
-    name = "locally_exclusive_cl"
-    engine = "opencl"
-    need_noise_levels = True
-    preferred_mp_context = None
-    params_doc = (
-        DetectPeakLocallyExclusive.params_doc
-        + """
-    opencl_context_kwargs: None or dict
-        kwargs to create the opencl context
-    """
-    )
-
-    @classmethod
-    def check_params(
-        cls,
-        recording,
-        peak_sign="neg",
-        detect_threshold=5,
-        exclude_sweep_ms=0.1,
-        radius_um=50,
-        noise_levels=None,
-        random_chunk_kwargs={},
-    ):
-        # TODO refactor with other classes
-        assert peak_sign in ("both", "neg", "pos")
-        if noise_levels is None:
-            noise_levels = get_noise_levels(recording, return_in_uV=False, **random_chunk_kwargs)
-        abs_thresholds = noise_levels * detect_threshold
-        exclude_sweep_size = int(exclude_sweep_ms * recording.get_sampling_frequency() / 1000.0)
-        channel_distance = get_channel_distances(recording)
-        neighbours_mask = channel_distance <= radius_um
-
-        executor = OpenCLDetectPeakExecutor(abs_thresholds, exclude_sweep_size, neighbours_mask, peak_sign)
-
-        return (executor,)
-
-    @classmethod
-    def get_method_margin(cls, *args):
-        executor = args[0]
-        return executor.exclude_sweep_size
-
-    @classmethod
-    def detect_peaks(cls, traces, executor):
-        peak_sample_ind, peak_chan_ind = executor.detect_peak(traces)
-
-        return peak_sample_ind, peak_chan_ind
-
-
-class OpenCLDetectPeakExecutor:
-    def __init__(self, abs_thresholds, exclude_sweep_size, neighbours_mask, peak_sign):
-
-        self.chunk_size = None
-
-        self.abs_thresholds = abs_thresholds.astype("float32")
-        self.exclude_sweep_size = exclude_sweep_size
-        self.neighbours_mask = neighbours_mask.astype("uint8")
-        self.peak_sign = peak_sign
-        self.ctx = None
-        self.queue = None
-        self.x = 0
-
-    def create_buffers_and_compile(self, chunk_size):
-        import pyopencl
-
-        mf = pyopencl.mem_flags
-        try:
-            self.device = pyopencl.get_platforms()[0].get_devices()[0]
-            self.ctx = pyopencl.Context(devices=[self.device])
-        except Exception as e:
-            print("error create context ", e)
-
-        self.queue = pyopencl.CommandQueue(self.ctx)
-        self.max_wg_size = self.ctx.devices[0].get_info(pyopencl.device_info.MAX_WORK_GROUP_SIZE)
-        self.chunk_size = chunk_size
-
-        self.neighbours_mask_cl = pyopencl.Buffer(
-            self.ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=self.neighbours_mask
-        )
-        self.abs_thresholds_cl = pyopencl.Buffer(self.ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=self.abs_thresholds)
-
-        num_channels = self.neighbours_mask.shape[0]
-        self.traces_cl = pyopencl.Buffer(self.ctx, mf.READ_WRITE, size=int(chunk_size * num_channels * 4))
-
-        # TODO estimate smaller
-        self.num_peaks = np.zeros(1, dtype="int32")
-        self.num_peaks_cl = pyopencl.Buffer(self.ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=self.num_peaks)
-
-        nb_max_spike_in_chunk = num_channels * chunk_size
-        self.peaks = np.zeros(nb_max_spike_in_chunk, dtype=[("sample_index", "int32"), ("channel_index", "int32")])
-        self.peaks_cl = pyopencl.Buffer(self.ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=self.peaks)
-
-        variables = dict(
-            chunk_size=int(self.chunk_size),
-            exclude_sweep_size=int(self.exclude_sweep_size),
-            peak_sign={"pos": 1, "neg": -1}[self.peak_sign],
-            num_channels=num_channels,
-        )
-
-        kernel_formated = processor_kernel % variables
-        prg = pyopencl.Program(self.ctx, kernel_formated)
-        self.opencl_prg = prg.build()  # options='-cl-mad-enable'
-        self.kern_detect_peaks = getattr(self.opencl_prg, "detect_peaks")
-
-        self.kern_detect_peaks.set_args(
-            self.traces_cl, self.neighbours_mask_cl, self.abs_thresholds_cl, self.peaks_cl, self.num_peaks_cl
-        )
-
-        s = self.chunk_size - 2 * self.exclude_sweep_size
-        self.global_size = (s,)
-        self.local_size = None
-
-    def detect_peak(self, traces):
-        self.x += 1
-
-        import pyopencl
-
-        if self.chunk_size is None or self.chunk_size != traces.shape[0]:
-            self.create_buffers_and_compile(traces.shape[0])
-        event = pyopencl.enqueue_copy(self.queue, self.traces_cl, traces.astype("float32"))
-
-        pyopencl.enqueue_nd_range_kernel(
-            self.queue,
-            self.kern_detect_peaks,
-            self.global_size,
-            self.local_size,
-        )
-
-        event = pyopencl.enqueue_copy(self.queue, self.traces_cl, traces.astype("float32"))
-        event = pyopencl.enqueue_copy(self.queue, self.traces_cl, traces.astype("float32"))
-        event = pyopencl.enqueue_copy(self.queue, self.num_peaks, self.num_peaks_cl)
-        event = pyopencl.enqueue_copy(self.queue, self.peaks, self.peaks_cl)
-        event.wait()
-
-        n = self.num_peaks[0]
-        peaks = self.peaks[:n]
-        peak_sample_ind = peaks["sample_index"].astype("int64")
-        peak_chan_ind = peaks["channel_index"].astype("int64")
-
-        return peak_sample_ind, peak_chan_ind
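For context, this executor is not called directly: it is reached through the generic `detect_peaks` entry point via the name registered above. A minimal sketch, assuming a working `pyopencl` installation and a simulated recording (the parameter values simply echo the defaults of `check_params`):

    from spikeinterface.core import generate_recording
    from spikeinterface.sortingcomponents.peak_detection import detect_peaks

    recording = generate_recording(num_channels=32, durations=[10.0])
    peaks = detect_peaks(
        recording,
        method="locally_exclusive_cl",  # name registered by DetectPeakLocallyExclusiveOpenCL
        peak_sign="neg",
        detect_threshold=5,
        exclude_sweep_ms=0.1,
        radius_um=50,
    )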
-
-
-processor_kernel = """
-#define chunk_size %(chunk_size)d
-#define exclude_sweep_size %(exclude_sweep_size)d
-#define peak_sign %(peak_sign)d
-#define num_channels %(num_channels)d
-
-
-typedef struct st_peak{
-    int sample_index;
-    int channel_index;
-} st_peak;
-
-
-__kernel void detect_peaks(
-                        //in
-                        __global float *traces,
-                        __global uchar *neighbours_mask,
-                        __global float *abs_thresholds,
-                        //out
-                        __global st_peak *peaks,
-                        volatile __global int *num_peaks
-                ){
-    int pos = get_global_id(0);
-
-    if (pos == 0){
-        *num_peaks = 0;
-    }
-    // this barrier OK if the first group is run first
-    barrier(CLK_GLOBAL_MEM_FENCE);
-
-    if (pos>=(chunk_size - (2 * exclude_sweep_size))){
-        return;
-    }
-
-
-    float v;
-    uchar peak;
-    uchar is_neighbour;
-
-    int index;
-
-    int i_peak;
-
-
-    for (int chan=0; chan<num_channels; chan++){
-
-        index = (pos + exclude_sweep_size) * num_channels + chan;
-        v = traces[index];
-
-        if(peak_sign==1){
-            if (v>abs_thresholds[chan]){peak=1;}
-            else {peak=0;}
-        }
-        else if(peak_sign==-1){
-            if (v<-abs_thresholds[chan]){peak=1;}
-            else {peak=0;}
-        }
-
-        if (peak == 1){
-            for (int chan_neigh=0; chan_neigh<num_channels; chan_neigh++){
-
-                is_neighbour = neighbours_mask[chan * num_channels + chan_neigh];
-                if (is_neighbour == 0){continue;}
-
-                index = (pos + exclude_sweep_size) * num_channels + chan_neigh;
-                if(peak_sign==1){
-                    peak = peak && (v>=traces[index]);
-                }
-                else if(peak_sign==-1){
-                    peak = peak && (v<=traces[index]);
-                }
-
-                if (peak==0){break;}
-
-                if(peak_sign==1){
-                    for (int i=1; i<=exclude_sweep_size; i++){
-                        peak = peak && (v>traces[(pos + exclude_sweep_size - i)*num_channels + chan_neigh]) && (v>=traces[(pos + exclude_sweep_size + i)*num_channels + chan_neigh]);
-                        if (peak==0){break;}
-                    }
-                }
-                else if(peak_sign==-1){
-                    for (int i=1; i<=exclude_sweep_size; i++){
-                        peak = peak && (v<traces[(pos + exclude_sweep_size - i)*num_channels + chan_neigh]) && (v<=traces[(pos + exclude_sweep_size + i)*num_channels + chan_neigh]);