Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 19 additions & 10 deletions src/spikeinterface/preprocessing/silence_periods.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from spikeinterface.core.core_tools import define_function_handling_dict_from_class
from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment

from spikeinterface.core import get_random_data_chunks, get_noise_levels
from spikeinterface.core import get_noise_levels
from spikeinterface.core.generate import NoiseGeneratorRecording
from spikeinterface.core.job_tools import split_job_kwargs


class SilencedPeriodsRecording(BasePreprocessor):
Expand Down Expand Up @@ -36,15 +37,23 @@ class SilencedPeriodsRecording(BasePreprocessor):
- "noise": The periods are filled with a gaussion noise that has the
same variance that the one in the recordings, on a per channel
basis
**random_chunk_kwargs : Keyword arguments for `spikeinterface.core.get_random_data_chunk()` function
**noise_levels_kwargs : Keyword arguments for `spikeinterface.core.get_noise_levels()` function

Returns
-------
silence_recording : SilencedPeriodsRecording
The recording extractor after silencing some periods
"""

def __init__(self, recording, list_periods, mode="zeros", noise_levels=None, seed=None, **random_chunk_kwargs):
def __init__(
self,
recording,
list_periods,
mode="zeros",
noise_levels=None,
seed=None,
**noise_levels_kwargs,
):
available_modes = ("zeros", "noise")
num_seg = recording.get_num_segments()

Expand All @@ -71,11 +80,9 @@ def __init__(self, recording, list_periods, mode="zeros", noise_levels=None, see

if mode in ["noise"]:
if noise_levels is None:
random_slices_kwargs = random_chunk_kwargs.copy()
random_slices_kwargs["seed"] = seed
noise_levels = get_noise_levels(
recording, return_in_uV=False, random_slices_kwargs=random_slices_kwargs
)
noise_levels_kwargs["return_in_uV"] = False
noise_levels_kwargs["seed"] = seed
noise_levels = get_noise_levels(recording, **noise_levels_kwargs)
noise_generator = NoiseGeneratorRecording(
num_channels=recording.get_num_channels(),
sampling_frequency=recording.sampling_frequency,
Expand All @@ -97,8 +104,10 @@ def __init__(self, recording, list_periods, mode="zeros", noise_levels=None, see
rec_segment = SilencedPeriodsRecordingSegment(parent_segment, periods, mode, noise_generator, seg_index)
self.add_recording_segment(rec_segment)

self._kwargs = dict(recording=recording, list_periods=list_periods, mode=mode, seed=seed)
self._kwargs.update(random_chunk_kwargs)
self._kwargs = dict(
recording=recording, list_periods=list_periods, mode=mode, seed=seed, noise_levels=noise_levels
)
self._kwargs.update(noise_levels_kwargs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't play with the _kwargs. Will this update lead to incompatible keys between spikeinterface versions?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually this is a good question, and we should discuss this PR with @samuelgarcia . Indeed, moving the computation of noise levels into the detect_peaks() and not in the node themselves is the only option to control finely the job_kwargs. Not sure we need to save the noise_level_kwargs because the noise_levels are cached per recording I think, and not recomputed during parallel processing. I'll double check with @samuelgarcia



class SilencedPeriodsRecordingSegment(BasePreprocessorSegment):
Expand Down