27 changes: 17 additions & 10 deletions brainbox/io/one.py
@@ -1050,8 +1050,8 @@ def raster(self, spikes, channels, save_dir=None, br=None, label='raster', time_
         :param **kwargs: kwargs passed to `driftmap()` (optional)
         :return:
         """
-        br = br or BrainRegions()
-        time_series = time_series or {}
+        br = BrainRegions() if br is None else br
+        time_series = {} if time_series is None else time_series
         fig, axs = plt.subplots(2, 2, gridspec_kw={
             'width_ratios': [.95, .05], 'height_ratios': [.1, .9]}, figsize=(16, 9), sharex='col')
         axs[0, 1].set_axis_off()
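
Note on the pattern change above: `x = x or default` replaces any falsy argument, not just a missing one, so a caller-supplied empty dict for time_series would silently be swapped for a fresh one. A minimal standalone sketch of the pitfall (names are illustrative, not IBL code):

def with_or(time_series=None):
    return time_series or {}          # replaces ANY falsy value

def with_none_check(time_series=None):
    return {} if time_series is None else time_series

shared = {}                             # caller wants this exact dict used
assert with_or(shared) is not shared    # silently replaced
assert with_none_check(shared) is shared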
@@ -1094,13 +1094,20 @@ def plot_rawdata_snippet(self, sr, spikes, clusters, t0,
                              save_dir=None,
                              label='raster',
                              gain=-93,
-                             title=None):
+                             title=None,
+                             alpha=0.3,
+                             processing='destripe'):
 
         # compute the raw data offset and destripe, we take 400ms around t0
         first_sample, last_sample = (int((t0 - 0.2) * sr.fs), int((t0 + 0.2) * sr.fs))
         raw = sr[first_sample:last_sample, :-sr.nsync].T
         channel_labels = channels['labels'] if (channels is not None) and ('labels' in channels) else True
-        destriped = ibldsp.voltage.destripe(raw, sr.fs, channel_labels=channel_labels)
+        if processing == 'destripe':
+            samples = ibldsp.voltage.destripe(raw, sr.fs, channel_labels=channel_labels)
+        else:
+            import scipy.signal
+            sos = scipy.signal.butter(**{"N": 3, "Wn": 300 / sr.fs * 2, "btype": "highpass"}, output="sos")
+            samples = scipy.signal.sosfiltfilt(sos, raw)
         # filter out the spikes according to good/bad clusters and to the time slice
         spike_sel = slice(*np.searchsorted(spikes['samples'], [first_sample, last_sample]))
         ss = spikes['samples'][spike_sel]
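
The new non-destripe branch above is a zero-phase 300 Hz Butterworth high-pass. A self-contained sketch of the same filter on synthetic data (the 30 kHz rate and array shape are assumptions, not values from this diff):

import numpy as np
import scipy.signal

fs = 30_000                          # assumed AP-band sampling rate, Hz
raw = np.random.randn(384, 12_000)   # (channels, samples) ~ 400 ms snippet
# 3rd-order high-pass at 300 Hz; Wn is normalised to Nyquist,
# so 300 / fs * 2 == 300 / (fs / 2)
sos = scipy.signal.butter(N=3, Wn=300 / fs * 2, btype='highpass', output='sos')
# forward-backward filtering along the last axis: zero phase distortion
hp = scipy.signal.sosfiltfilt(sos, raw)
assert hp.shape == raw.shape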
@@ -1110,9 +1117,9 @@ def plot_rawdata_snippet(self, sr, spikes, clusters, t0,
             title = self._default_plot_title(spikes)
         # display the raw data snippet with spikes overlaid
         fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9), sharex='col')
-        Density(destriped, fs=sr.fs, taxis=1, gain=gain, ax=axs[0], t0=t0 - 0.2, unit='s')
-        axs[0].scatter(ss[sok] / sr.fs, sc[sok], color="green", alpha=0.5)
-        axs[0].scatter(ss[~sok] / sr.fs, sc[~sok], color="red", alpha=0.5)
+        Density(samples, fs=sr.fs, taxis=1, gain=gain, ax=axs[0], t0=t0 - 0.2, unit='s')
+        axs[0].scatter(ss[sok] / sr.fs, sc[sok], color="green", alpha=alpha)
+        axs[0].scatter(ss[~sok] / sr.fs, sc[~sok], color="red", alpha=alpha)
         axs[0].set(title=title, xlim=[t0 - 0.035, t0 + 0.035])
         # adds the channel locations if available
         if (channels is not None) and ('atlas_id' in channels):
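
With the two new keyword arguments, a caller can fade the spike overlay and skip destriping in one call. A hedged usage sketch (ssl, sr, spikes, clusters and channels are assumed to come from the usual SpikeSortingLoader workflow; any processing value other than 'destripe' selects the high-pass branch):

ssl.plot_rawdata_snippet(
    sr, spikes, clusters, t0=600.0,
    channels=channels,
    alpha=0.15,              # fainter markers than the old hard-coded 0.5
    processing='highpass',   # bypass ibldsp destriping
)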
@@ -1314,7 +1321,7 @@ def _find_behaviour_collection(self, obj):
                            f'e.g sl.load_{obj}(collection="{collections[0]}")')
             raise ALFMultipleCollectionsFound
 
-    def load_trials(self, collection=None):
+    def load_trials(self, collection=None, revision=None):
         """
         Function to load trials data into SessionLoader.trials
 
@@ -1323,13 +1330,13 @@ def load_trials(self, collection=None):
         collection: str
             Alf collection of trials data
         """
-
+        revision = self.revision if revision is None else revision
         if not collection:
             collection = self._find_behaviour_collection('trials')
         # itiDuration frequently has a mismatched dimension, and we don't need it, exclude using regex
         self.one.wildcards = False
         self.trials = self.one.load_object(
-            self.eid, 'trials', collection=collection, attribute=r'(?!itiDuration).*', revision=self.revision or None).to_df()
+            self.eid, 'trials', collection=collection, attribute=r'(?!itiDuration).*', revision=revision or None).to_df()
         self.one.wildcards = True
         self.data_info.loc[self.data_info['name'] == 'trials', 'is_loaded'] = True
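
The new revision keyword overrides the loader-level revision for this call only; omitting it preserves the old behaviour. A hedged usage sketch (the ONE instance and eid are placeholders):

from one.api import ONE
from brainbox.io.one import SessionLoader

one = ONE()
sl = SessionLoader(one=one, eid='session-eid')  # placeholder eid

sl.load_trials()                       # uses sl.revision, as before
sl.load_trials(revision='2024-05-06')  # one-off override; sl.revision is untouched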

42 changes: 18 additions & 24 deletions ibllib/pipes/ephys_tasks.py
@@ -1,8 +1,10 @@
+import importlib
 import logging
 from pathlib import Path
 import re
 import shutil
 import subprocess
+import sys
 import traceback
 
 import packaging.version
@@ -124,7 +126,7 @@ class EphysCompressNP1(base_tasks.EphysTask):
     priority = 90
     cpu = 2
     io_charge = 100  # this jobs reads raw ap files
-    job_size = 'small'
+    job_size = 'large'
 
     @property
     def signature(self):
@@ -592,7 +594,6 @@ class SpikeSorting(base_tasks.EphysTask, CellQCMixin):
     SHELL_SCRIPT = Path.home().joinpath(
         f"Documents/PYTHON/iblscripts/deploy/serverpc/{_sortername}/sort_recording.sh"
     )
-    SPIKE_SORTER_NAME = 'iblsorter'
     SORTER_REPOSITORY = Path.home().joinpath('Documents/PYTHON/SPIKE_SORTING/ibl-sorter')
 
     @property
@@ -608,11 +609,12 @@ def signature(self):
             # ./raw_ephys_data/{self.pname}/
             ('_iblqc_ephysTimeRmsAP.rms.npy', f'{self.device_collection}/{self.pname}/', True),
             ('_iblqc_ephysTimeRmsAP.timestamps.npy', f'{self.device_collection}/{self.pname}/', True),
-            ('_iblqc_ephysSaturation.samples.npy', f'{self.device_collection}/{self.pname}/', True),
+            ('_iblqc_ephysSaturation.samples.pqt', f'{self.device_collection}/{self.pname}/', True),
             # ./spike_sorters/iblsorter/{self.pname}
             ('_kilosort_raw.output.tar', f'spike_sorters/{self._sortername}/{self.pname}/', True),
             # ./alf/{self.pname}/iblsorter
-            (f'_ibl_log.info_{self.SPIKE_SORTER_NAME}.log', f'alf/{self.pname}/{self._sortername}', True),
+            (f'{self._sortername}_parameters.yaml', f'alf/{self.pname}/{self._sortername}', True),
+            (f'_ibl_log.info_{self._sortername}.log', f'alf/{self.pname}/{self._sortername}', True),
             ('_kilosort_whitening.matrix.npy', f'alf/{self.pname}/{self._sortername}/', True),
             ('_phy_spikes_subset.channels.npy', f'alf/{self.pname}/{self._sortername}/', True),
             ('_phy_spikes_subset.spikes.npy', f'alf/{self.pname}/{self._sortername}/', True),
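
Each entry in the signature above is a (dataset_name, collection, required) tuple whose f-string fields resolve per probe at runtime. An illustrative expansion of the new log entry (values are examples; the real resolution happens inside the ibllib task machinery):

pname, sortername = 'probe00', 'iblsorter'
dataset, collection, required = (
    f'_ibl_log.info_{sortername}.log', f'alf/{pname}/{sortername}', True)
print(f'{collection}/{dataset}  (required={required})')
# alf/probe00/iblsorter/_ibl_log.info_iblsorter.log  (required=True)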
@@ -657,15 +659,7 @@ def scratch_folder_run(self):
         For a scratch drive at /mnt/h0 we would have the following temp dir:
         /mnt/h0/iblsorter_1.8.0_CSHL071_2020-10-04_001_probe01/
         """
-        # get the scratch drive from the shell script
-        if self.scratch_folder is None:
-            with open(self.SHELL_SCRIPT) as fid:
-                lines = fid.readlines()
-            line = [line for line in lines if line.startswith("SCRATCH_DRIVE=")][0]
-            m = re.search(r"\=(.*?)(\#|\n)", line)[0]
-            scratch_drive = Path(m[1:-1].strip())
-        else:
-            scratch_drive = self.scratch_folder
+        scratch_drive = self.scratch_folder if self.scratch_folder else Path('/scratch')
         assert scratch_drive.exists(), f"Scratch drive {scratch_drive} not found"
         # get the version of the sorter
         self.version = self._fetch_iblsorter_version(self.SORTER_REPOSITORY)
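
The property no longer scrapes SCRATCH_DRIVE= out of the deployment shell script; it falls back to /scratch whenever no folder was configured. A minimal sketch of the new resolution order (the function wrapper is mine; the logic mirrors the diff):

from pathlib import Path

def resolve_scratch(scratch_folder=None):
    # explicit configuration wins; otherwise assume the conventional /scratch mount
    scratch_drive = scratch_folder if scratch_folder else Path('/scratch')
    assert scratch_drive.exists(), f"Scratch drive {scratch_drive} not found"
    return scratch_drive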
@@ -718,15 +712,15 @@ def _fetch_iblsorter_run_version(log_file):
     def _run_iblsort(self, ap_file):
         """
         Runs the ks2 matlab spike sorting for one probe dataset
-        the raw spike sorting output is in session_path/spike_sorters/{self.SPIKE_SORTER_NAME}/probeXX folder
+        the raw spike sorting output is in session_path/spike_sorters/{self._sortername}/probeXX folder
         (discontinued support for old spike sortings in the probe folder <1.5.5)
         :return: path of the folder containing ks2 spike sorting output
         """
         iblutil.util.setup_logger('iblsorter', level='INFO')
-        sorter_dir = self.session_path.joinpath("spike_sorters", self.SPIKE_SORTER_NAME, self.pname)
+        sorter_dir = self.session_path.joinpath("spike_sorters", self._sortername, self.pname)
         self.FORCE_RERUN = False
         if not self.FORCE_RERUN:
-            log_file = sorter_dir.joinpath(f"_ibl_log.info_{self.SPIKE_SORTER_NAME}.log")
+            log_file = sorter_dir.joinpath(f"_ibl_log.info_{self._sortername}.log")
             if log_file.exists():
                 run_version = self._fetch_iblsorter_run_version(log_file)
                 if packaging.version.parse(run_version) >= packaging.version.parse('1.7.0'):
@@ -737,11 +731,11 @@ def _run_iblsort(self, ap_file):
                 self.FORCE_RERUN = True
         self.scratch_folder_run.mkdir(parents=True, exist_ok=True)
         check_nvidia_driver()
-        try:
-            # if pykilosort is in the environment, use the installed version within the task
+        # this is the best way I found to check if iblsorter is installed and available without a try block
+        if 'iblsorter' in sys.modules and importlib.util.find_spec('iblsorter.ibl') is not None:
             import iblsorter.ibl  # noqa
             iblsorter.ibl.run_spike_sorting_ibl(bin_file=ap_file, scratch_dir=self.scratch_folder_run, delete=False)
-        except ImportError:
+        else:
             command2run = f"{self.SHELL_SCRIPT} {ap_file} {self.scratch_folder_run}"
             _logger.info(command2run)
             process = subprocess.Popen(
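
The try/except ImportError is replaced above by an explicit availability probe. A standalone sketch of that probe (the function name is mine; note that find_spec raises ModuleNotFoundError when the parent package itself is missing, which is presumably why the sys.modules guard comes first):

import importlib.util
import sys

def module_available(dotted_name, parent):
    # cheap check first: only consider packages already imported in this process
    if parent not in sys.modules:
        return False
    # find_spec returns None for a missing submodule of an importable parent
    return importlib.util.find_spec(dotted_name) is not None

print(module_available('iblsorter.ibl', 'iblsorter'))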
@@ -762,7 +756,7 @@ def _run(self):
                             log = fid.read()
                             _logger.error(log)
                         break
-            raise RuntimeError(f"{self.SPIKE_SORTER_NAME} {info_str}, {error_str}")
+            raise RuntimeError(f"{self._sortername} {info_str}, {error_str}")
         shutil.copytree(self.scratch_folder_run.joinpath('output'), sorter_dir, dirs_exist_ok=True)
         return sorter_dir
 
@@ -783,7 +777,7 @@ def _run(self):
         out_files = []
         sorter_dir = self._run_iblsort(ap_file)  # runs the sorter, skips if it already ran
         # convert the data to ALF in the ./alf/probeXX/SPIKE_SORTER_NAME folder
-        probe_out_path = self.session_path.joinpath("alf", label, self.SPIKE_SORTER_NAME)
+        probe_out_path = self.session_path.joinpath("alf", label, self._sortername)
         shutil.rmtree(probe_out_path, ignore_errors=True)
         probe_out_path.mkdir(parents=True, exist_ok=True)
         ibllib.ephys.spikes.ks2_to_alf(
@@ -793,9 +787,9 @@ def _run(self):
             bin_file=ap_file,
             ampfactor=self._sample2v(ap_file),
         )
-        logfile = sorter_dir.joinpath(f"_ibl_log.info_{self.SPIKE_SORTER_NAME}.log")
+        logfile = sorter_dir.joinpath(f"_ibl_log.info_{self._sortername}.log")
         if logfile.exists():
-            shutil.copyfile(logfile, probe_out_path.joinpath(f"_ibl_log.info_{self.SPIKE_SORTER_NAME}.log"))
+            shutil.copyfile(logfile, probe_out_path.joinpath(f"_ibl_log.info_{self._sortername}.log"))
         # recover the QC files from the spike sorting output and copy them
         for file_qc in sorter_dir.glob('_iblqc_*.npy'):
             shutil.move(file_qc, file_qc_out := ap_file.parent.joinpath(file_qc.name))
@@ -809,7 +803,7 @@ def _run(self):
         # convert ks2_output into tar file and also register
         # Make this in case spike sorting is in old raw_ephys_data folders, for new
         # sessions it should already exist
-        tar_dir = self.session_path.joinpath('spike_sorters', self.SPIKE_SORTER_NAME, label)
+        tar_dir = self.session_path.joinpath('spike_sorters', self._sortername, label)
         tar_dir.mkdir(parents=True, exist_ok=True)
         out = ibllib.ephys.spikes.ks2_to_tar(sorter_dir, tar_dir, force=self.FORCE_RERUN)
         out_files.extend(out)