diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index dca57b0c8b..f598cb1949 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -5,8 +5,9 @@ from spikeinterface.core.core_tools import define_function_handling_dict_from_class from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment -from spikeinterface.core import get_random_data_chunks, get_noise_levels +from spikeinterface.core import get_noise_levels from spikeinterface.core.generate import NoiseGeneratorRecording +from spikeinterface.core.job_tools import split_job_kwargs class SilencedPeriodsRecording(BasePreprocessor): @@ -36,7 +37,7 @@ class SilencedPeriodsRecording(BasePreprocessor): - "noise": The periods are filled with a gaussion noise that has the same variance that the one in the recordings, on a per channel basis - **random_chunk_kwargs : Keyword arguments for `spikeinterface.core.get_random_data_chunk()` function + **noise_levels_kwargs : Keyword arguments for `spikeinterface.core.get_noise_levels()` function Returns ------- @@ -44,7 +45,15 @@ class SilencedPeriodsRecording(BasePreprocessor): The recording extractor after silencing some periods """ - def __init__(self, recording, list_periods, mode="zeros", noise_levels=None, seed=None, **random_chunk_kwargs): + def __init__( + self, + recording, + list_periods, + mode="zeros", + noise_levels=None, + seed=None, + **noise_levels_kwargs, + ): available_modes = ("zeros", "noise") num_seg = recording.get_num_segments() @@ -71,11 +80,9 @@ def __init__(self, recording, list_periods, mode="zeros", noise_levels=None, see if mode in ["noise"]: if noise_levels is None: - random_slices_kwargs = random_chunk_kwargs.copy() - random_slices_kwargs["seed"] = seed - noise_levels = get_noise_levels( - recording, return_in_uV=False, random_slices_kwargs=random_slices_kwargs - ) + noise_levels_kwargs["return_in_uV"] = False + noise_levels_kwargs["seed"] = seed + noise_levels = get_noise_levels(recording, **noise_levels_kwargs) noise_generator = NoiseGeneratorRecording( num_channels=recording.get_num_channels(), sampling_frequency=recording.sampling_frequency, @@ -97,8 +104,10 @@ def __init__(self, recording, list_periods, mode="zeros", noise_levels=None, see rec_segment = SilencedPeriodsRecordingSegment(parent_segment, periods, mode, noise_generator, seg_index) self.add_recording_segment(rec_segment) - self._kwargs = dict(recording=recording, list_periods=list_periods, mode=mode, seed=seed) - self._kwargs.update(random_chunk_kwargs) + self._kwargs = dict( + recording=recording, list_periods=list_periods, mode=mode, seed=seed, noise_levels=noise_levels + ) + self._kwargs.update(noise_levels_kwargs) class SilencedPeriodsRecordingSegment(BasePreprocessorSegment):