@@ -498,24 +498,35 @@ def time_to_sample_index(self, time_s, segment_index=None):
498498 rs = self ._recording_segments [segment_index ]
499499 return rs .time_to_sample_index (time_s )
500500
501- def _save (self , format = "binary" , verbose : bool = False , ** save_kwargs ):
501+ def _get_t_starts (self ):
502502 # handle t_starts
503503 t_starts = []
504504 has_time_vectors = []
505- for segment_index , rs in enumerate ( self ._recording_segments ) :
505+ for rs in self ._recording_segments :
506506 d = rs .get_times_kwargs ()
507507 t_starts .append (d ["t_start" ])
508- has_time_vectors .append (d ["time_vector" ] is not None )
509508
510509 if all (t_start is None for t_start in t_starts ):
511510 t_starts = None
511+ return t_starts
512512
513+ def _get_time_vectors (self ):
514+ time_vectors = []
515+ for rs in self ._recording_segments :
516+ d = rs .get_times_kwargs ()
517+ time_vectors .append (d ["time_vector" ])
518+ if all (time_vector is None for time_vector in time_vectors ):
519+ time_vectors = None
520+ return time_vectors
521+
522+ def _save (self , format = "binary" , verbose : bool = False , ** save_kwargs ):
513523 kwargs , job_kwargs = split_job_kwargs (save_kwargs )
514524
515525 if format == "binary" :
516526 folder = kwargs ["folder" ]
517527 file_paths = [folder / f"traces_cached_seg{ i } .raw" for i in range (self .get_num_segments ())]
518528 dtype = kwargs .get ("dtype" , None ) or self .get_dtype ()
529+ t_starts = self ._get_t_starts ()
519530
520531 write_binary_recording (self , file_paths = file_paths , dtype = dtype , verbose = verbose , ** job_kwargs )
521532
@@ -572,11 +583,11 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs):
572583 probegroup = self .get_probegroup ()
573584 cached .set_probegroup (probegroup )
574585
575- for segment_index , rs in enumerate ( self ._recording_segments ):
576- d = rs . get_times_kwargs ()
577- time_vector = d [ " time_vector" ]
578- if time_vector is not None :
579- cached ._recording_segments [ segment_index ]. time_vector = time_vector
586+ time_vectors = self ._get_time_vectors ()
587+ if time_vectors is not None :
588+ for segment_index , time_vector in enumerate ( time_vectors ):
589+ if time_vector is not None :
590+ cached .set_times ( time_vector , segment_index = segment_index )
580591
581592 return cached
582593
0 commit comments