.. code:: ipython3

    %matplotlib inline
    %load_ext autoreload
    %autoreload 2

Handle motion/drift with spikeinterface
=======================================

Spikeinterface offers a very flexible framework to handle drift as a
preprocessing step. If you want to know more, please read the
``motion_correction`` section of the documentation.

Here is a short demo on how to handle drift using the high-level function
``spikeinterface.preprocessing.compute_motion()``.

This function takes a preprocessed recording as input and returns a
``motion`` object, which contains the information required to interpolate
your recording. You can additionally return a ``motion_info`` object which
contains the peaks, peak_locations and parameters used to compute the
``motion`` object by passing ``output_motion_info=True`` to the
``compute_motion`` function. Note that you can alternatively compute the
motion correction and interpolate at the same time using the
``spikeinterface.preprocessing.correct_motion()`` function.

Internally the function ``compute_motion`` runs the following steps (which
can be slow!):

::

    1. detect_peaks()
    2. localize_peaks()
    3. select_peaks() (optional)
    4. estimate_motion()

All these sub-steps can be run with different methods and have many
parameters.

The high-level function suggests several predefined “presets” and we will
explore them using a very well known public dataset recorded by Nick
Steinmetz: `Imposed motion datasets `__

This dataset contains 3 recordings and each recording contains a
Neuropixels 1 and a Neuropixels 2 probe.

Here we will use *dataset1* with *neuropixel1*. This dataset is the *“hello
world”* for drift correction in the spike sorting community!

.. code:: ipython3

    from pathlib import Path
    import matplotlib.pyplot as plt
    import numpy as np
    import shutil

    import spikeinterface.full as si
    from spikeinterface.preprocessing import get_motion_parameters_preset, get_motion_presets

..
parsed-literal:: /home/nolanlab/Work/Developing/motion_correct_docs/spikeinterface/.venv/lib/python3.13/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html from .autonotebook import tqdm as notebook_tqdm .. code:: ipython3 base_folder = Path("/home/nolanlab/Work/Data") dataset_folder = base_folder / "dataset1/NP1" .. code:: ipython3 # read the file raw_rec = si.read_spikeglx(dataset_folder) raw_rec .. raw:: html
SpikeGLXRecordingExtractor: 384 channels - 30.0kHz - 1 segments - 58,715,724 samples - 1,957.19s (32.62 minutes) - int16 dtype - 42.00 GiB
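As a quick sanity check, the figures in the recording summary above are mutually consistent. The sketch below recomputes the duration and the in-memory size from the sample count. It assumes the nominal 30 kHz sampling rate shown in the summary (the exact SpikeGLX rate may differ slightly) and 2 bytes per ``int16`` sample; the variable names are illustrative and not part of spikeinterface.

```python
# Values copied from the recording summary above.
num_samples = 58_715_724
num_channels = 384
sampling_frequency_hz = 30_000.0  # assumption: nominal rate, as displayed
bytes_per_sample = 2              # int16

duration_s = num_samples / sampling_frequency_hz
duration_min = duration_s / 60
size_gib = num_samples * num_channels * bytes_per_sample / 2**30

print(f"{duration_s:,.2f} s ({duration_min:.2f} minutes), {size_gib:.2f} GiB")
```

The same formula with 4 bytes per sample applies once the recording is cast to ``float32``.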
We preprocess the recording with a bandpass filter and a common median
reference. Note that it is better not to whiten the recording before motion
estimation: peak locations are estimated more accurately on unwhitened data!

.. code:: ipython3

    def preprocess_chain(rec):
        # cast to float32, then bandpass filter and apply a global median reference
        rec = rec.astype('float32')
        rec = si.bandpass_filter(rec, freq_min=300.0, freq_max=6000.0)
        rec = si.common_reference(rec, reference="global", operator="median")
        return rec

.. code:: ipython3

    rec = preprocess_chain(raw_rec)

.. code:: ipython3

    job_kwargs = dict(n_jobs=10, chunk_duration="1s", progress_bar=True)

Correcting for drift is easy! You just need to run a single function. We
will try this function with several presets.

Internally, a preset is a dictionary of dictionaries containing the
parameters for every step.

Preset and parameters
~~~~~~~~~~~~~~~~~~~~~

Motion correction consists of several steps, and every step is controlled by
a method and its related parameters.

A preset is a nested dict that contains these methods and parameters.

.. code:: ipython3

    preset_keys = get_motion_presets()
    preset_keys

.. parsed-literal::

    ['dredge',
     'medicine',
     'dredge_fast',
     'nonrigid_accurate',
     'nonrigid_fast_and_accurate',
     'rigid_fast',
     'kilosort_like']

.. code:: ipython3

    one_preset_params = get_motion_parameters_preset("kilosort_like")
    one_preset_params

..
parsed-literal::

    {'doc': 'Mimic the drift correction of kilosort (grid_convolution + iterative_template)',
     'detect_kwargs': {'peak_sign': 'neg',
      'detect_threshold': 8.0,
      'exclude_sweep_ms': 0.1,
      'radius_um': 50,
      'noise_levels': None,
      'random_chunk_kwargs': {},
      'method': 'locally_exclusive'},
     'select_kwargs': {},
     'localize_peaks_kwargs': {'radius_um': 40.0,
      'upsampling_um': 5.0,
      'sigma_ms': 0.25,
      'margin_um': 50.0,
      'prototype': None,
      'percentile': 5.0,
      'peak_sign': 'neg',
      'weight_method': {'mode': 'gaussian_2d',
       'sigma_list_um': array([ 5., 10., 15., 20., 25.])},
      'method': 'grid_convolution'},
     'estimate_motion_kwargs': {'direction': 'y',
      'rigid': False,
      'win_shape': 'rect',
      'win_step_um': 200.0,
      'win_scale_um': 400.0,
      'win_margin_um': None,
      'bin_um': 10.0,
      'hist_margin_um': 0,
      'bin_s': 2.0,
      'num_amp_bins': 20,
      'num_shifts_global': 15,
      'num_iterations': 10,
      'num_shifts_block': 5,
      'smoothing_sigma': 0.5,
      'kriging_sigma': 1,
      'kriging_p': 2,
      'kriging_d': 2,
      'method': 'iterative_template'},
     'interpolate_motion_kwargs': {'border_mode': 'force_extrapolate',
      'spatial_interpolation_method': 'kriging',
      'sigma_um': 20.0,
      'p': 2}}

Run motion correction with one function!
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Here we also save the motion correction results into a folder so that we can
load them later.

.. code:: ipython3

    # let's try these presets
    some_presets = ("rigid_fast", "kilosort_like", "nonrigid_accurate", "nonrigid_fast_and_accurate", "dredge", "dredge_fast", "medicine")

.. code:: ipython3

    # compute motion with these presets
    for preset in some_presets:
        print("Computing with", preset)
        folder = base_folder / "motion_folder_dataset1" / preset
        # start from a clean folder so that stale results are never reloaded
        if folder.exists():
            shutil.rmtree(folder)
        motion, motion_info = si.compute_motion(
            rec, preset=preset, folder=folder, output_motion_info=True, **job_kwargs
        )

.. parsed-literal::

    Computing with rigid_fast

..
parsed-literal:: noise_level (no parallelization): 0%| | 0/20 [00:00 0 Cross correlation: 100%|██████████| 1/1 [00:05<00:00, 5.33s/it] .. parsed-literal:: Computing with kilosort_like .. parsed-literal:: detect and localize (workers: 10 processes): 100%|██████████| 1958/1958 [04:33<00:00, 7.16it/s] .. parsed-literal:: Computing with nonrigid_accurate .. parsed-literal:: detect and localize (workers: 10 processes): 100%|██████████| 1958/1958 [04:47<00:00, 6.82it/s] pairwise displacement: 100%|██████████| 18/18 [01:01<00:00, 3.43s/it] .. parsed-literal:: Computing with nonrigid_fast_and_accurate .. parsed-literal:: detect and localize (workers: 10 processes): 100%|██████████| 1958/1958 [04:15<00:00, 7.67it/s] pairwise displacement: 100%|██████████| 18/18 [01:00<00:00, 3.37s/it] .. parsed-literal:: Computing with dredge .. parsed-literal:: detect and localize (workers: 10 processes): 100%|██████████| 1958/1958 [04:45<00:00, 6.87it/s] Cross correlation: 100%|██████████| 9/9 [01:33<00:00, 10.35s/it] Solve: 100%|██████████| 8/8 [00:30<00:00, 3.85s/it] .. parsed-literal:: Computing with dredge_fast .. parsed-literal:: detect and localize (workers: 10 processes): 100%|██████████| 1958/1958 [04:13<00:00, 7.72it/s] Cross correlation: 100%|██████████| 9/9 [01:29<00:00, 9.94s/it] Solve: 100%|██████████| 8/8 [00:30<00:00, 3.85s/it] .. parsed-literal:: Computing with medicine .. parsed-literal:: detect and localize (workers: 10 processes): 100%|██████████| 1958/1958 [06:52<00:00, 4.75it/s] [INFO] - Constructing Dataset instance [INFO] - Constructing MotionFunction instance with parameters: bound_normalized = 0.10328739313377586 time_range = (np.float64(0.0005666666666666667), np.float64(1957.1912)) time_bin_size = 1.0 time_kernel_width = 30 num_depth_bins = 13 [INFO] - Constructing ActivityNetwork instance with parameters: hidden_features = (256, 256) activation = None [INFO] - Constructing Medicine instance .. parsed-literal:: num_depth_bins 13 .. 
parsed-literal::

    [INFO] - Fitting motion estimation
    100%|██████████| 10000/10000 [02:49<00:00, 59.09it/s]
    [INFO] - Finished fitting motion estimation

Plot the results
~~~~~~~~~~~~~~~~

We load back the results and use the widgets module to explore the estimated
drift motion.

For all methods we have 4 plots:

* top left: time vs estimated peak depth
* top right: time vs peak depth after motion correction
* bottom left: the average motion vector across depths and all motion
  across spatial depths (for non-rigid estimation)
* bottom right: if motion correction is non-rigid, the motion vector across
  depths is plotted as a map, with the color code representing the motion in
  micrometers.

A few comments on the figures:

* The preset **rigid_fast** has only one motion vector for the entire probe
  because it is a “rigid” case. The motion amplitude is globally
  underestimated because it averages across depths. However, the corrected
  peaks are flatter than the non-corrected ones, so the job is partially
  done. The big jump at t=600s, when the probe starts moving, is recovered
  quite well.
* The preset **kilosort_like** gives better results because it is a
  non-rigid case. The motion vector is computed for different depths. The
  corrected peak locations are flatter than in the rigid case. The motion
  vector map is still a bit noisy at some depths (e.g. around 1000um).
* The preset **dredge** is the official DREDge re-implementation in
  spikeinterface. It gives the best result: very fast and smooth motion
  estimation, with very little noise. This method also captures the
  non-rigid motion gradient along the probe very well. It is the best method
  on the market at the moment. An enormous thanks to the dream team: Charlie
  Windolf, Julien Boussard, Erdem Varol, Liam Paninski. Note that in the
  first part of the recording, before the imposed motion (0-600s), we clearly
  have a non-rigid motion: the upper part of the probe (2000-3000um)
  experiences some drift, but the lower part (0-1000um) is relatively stable.
  The method defined by this preset is able to capture this.
* The preset **nonrigid_accurate** is the ancestor of “dredge” before it was
  published. It seems to give good results on this recording, but with a bit
  more noise.
* The preset **dredge_fast** is similar to dredge but faster (using
  grid_convolution).
* The preset **nonrigid_fast_and_accurate** is a variant of
  nonrigid_accurate but faster (using grid_convolution).

.. code:: ipython3

    for preset in some_presets:
        # load
        folder = base_folder / "motion_folder_dataset1" / preset
        motion_info = si.load_motion_info(folder)

        # and plot
        fig = plt.figure(figsize=(14, 8))
        si.plot_motion_info(
            motion_info,
            rec,
            figure=fig,
            depth_lim=(400, 600),
            color_amplitude=True,
            amplitude_cmap="inferno",
            scatter_decimate=10,
        )
        fig.suptitle(f"{preset=}")

.. image:: handle_drift_files/handle_drift_17_0.png

.. image:: handle_drift_files/handle_drift_17_1.png

.. image:: handle_drift_files/handle_drift_17_2.png

.. image:: handle_drift_files/handle_drift_17_3.png

.. image:: handle_drift_files/handle_drift_17_4.png

.. image:: handle_drift_files/handle_drift_17_5.png

.. image:: handle_drift_files/handle_drift_17_6.png

Make an interpolated recording
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Once you have analyzed your results, you can choose the motion correction
method that works best on your dataset and create an interpolated recording
using ``interpolate_motion``. The motion object itself is contained in the
``motion_info`` dict. Suppose we decide to use the ``nonrigid_accurate``
preset to make the interpolated recording. We do this as follows:

.. code:: ipython3

    from spikeinterface.sortingcomponents.motion import interpolate_motion

    preset = "nonrigid_accurate"
    folder = base_folder / "motion_folder_dataset1" / preset
    motion_info = si.load_motion_info(folder)
    motion = motion_info['motion']

    interpolated_recording = interpolate_motion(recording=rec, motion=motion)
    interpolated_recording

.. raw:: html
interpolate_motion (InterpolateMotionRecording): 374 channels - 30.0kHz - 1 segments - 58,715,724 samples - 1,957.19s (32.62 minutes) - float32 dtype - 81.81 GiB
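To build some intuition for what interpolating a recording means, here is a toy, pure-Python sketch. It is not spikeinterface's implementation: it assumes a single column of channels, one rigid displacement, plain linear interpolation between neighbouring channels, and a sign convention chosen for this example in which a positive ``drift_um`` means the signal appears ``drift_um`` higher on the probe than it should. ``undo_rigid_drift`` is a hypothetical helper and the amplitudes are made up. Channels whose corrected value would need data from beyond the probe edge are dropped, which is one plausible reason why the interpolated recording above reports 374 channels instead of 384.

```python
from bisect import bisect_right


def undo_rigid_drift(depths_um, values, drift_um):
    """Resample one time sample so that a rigid drift is undone.

    depths_um: sorted channel depths; values: signal on each channel.
    Returns (kept_depths, corrected_values).
    """
    kept_depths, corrected = [], []
    for depth in depths_um:
        source = depth + drift_um  # where the signal for `depth` was recorded
        if source < depths_um[0] or source > depths_um[-1]:
            continue  # no data beyond the probe edge: drop this channel
        # channel just below `source`, clamped so that i + 1 stays valid
        i = min(bisect_right(depths_um, source) - 1, len(depths_um) - 2)
        d0, d1 = depths_um[i], depths_um[i + 1]
        w = (source - d0) / (d1 - d0)
        corrected.append((1 - w) * values[i] + w * values[i + 1])
        kept_depths.append(depth)
    return kept_depths, corrected


depths = [0.0, 20.0, 40.0, 60.0, 80.0]
true_profile = [0.0, 1.0, 4.0, 1.0, 0.0]     # hypothetical amplitudes, peak at 40 um
drifted_profile = [0.0, 0.0, 1.0, 4.0, 1.0]  # same profile seen 20 um higher
kept, corrected = undo_rigid_drift(depths, drifted_profile, drift_um=20.0)
print(kept)
print(corrected)
```

In this toy case the peak moves back to 40 um and the top channel is lost, because nothing was recorded above the probe.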
You can then use the interpolated recording for e.g. spike sorting.

Plot peak localization
~~~~~~~~~~~~~~~~~~~~~~

We can also use the internal extra results (peaks and peak locations) to
check whether putative clusters have a lower spatial spread after the motion
correction.

Here we plot the estimated peak locations (left) and the corrected peak
locations (right) on top of the probe. The color encodes the peak amplitude.

We can see here that some clusters seem to be more compact on the ‘y’ axis,
especially for the preset “nonrigid_accurate”.

Be aware that there are two ways to correct for the motion:

1. Interpolate traces and detect/localize peaks again
   (``interpolate_motion()``)
2. Compensate for drifts directly on peak locations
   (``correct_motion_on_peaks()``)

Case 1 is used before running a spike sorter, and case 2 is used here to
display the results.

.. code:: ipython3

    from spikeinterface.sortingcomponents.motion import correct_motion_on_peaks

    for preset in some_presets:
        folder = base_folder / "motion_folder_dataset1" / preset
        motion_info = si.load_motion_info(folder)

        motion = motion_info["motion"]

        fig, axs = plt.subplots(ncols=2, figsize=(12, 8), sharey=True)

        ax = axs[0]
        si.plot_probe_map(rec, ax=ax)

        peaks = motion_info["peaks"]
        sr = rec.get_sampling_frequency()
        # keep only the peaks between 750 s and 1500 s
        time_lim0 = 750.0
        time_lim1 = 1500.0
        mask = (peaks["sample_index"] > int(sr * time_lim0)) & (peaks["sample_index"] < int(sr * time_lim1))
        # plot one peak out of five
        sl = slice(None, None, 5)
        # normalize amplitudes by their 95th percentile for the colormap
        amps = np.abs(peaks["amplitude"][mask][sl])
        amps /= np.quantile(amps, 0.95)
        c = plt.get_cmap("inferno")(amps)

        color_kargs = dict(alpha=0.2, s=2, c=c)

        peak_locations = motion_info["peak_locations"]
        ax.scatter(peak_locations["x"][mask][sl], peak_locations["y"][mask][sl], **color_kargs)

        peak_locations2 = correct_motion_on_peaks(peaks, peak_locations, motion, rec)

        ax = axs[1]
        si.plot_probe_map(rec, ax=ax)
        ax.scatter(peak_locations2["x"][mask][sl], peak_locations2["y"][mask][sl], **color_kargs)

        ax.set_ylim(400, 600)
        fig.suptitle(f"{preset=}")

.. image:: handle_drift_files/handle_drift_22_0.png

.. image:: handle_drift_files/handle_drift_22_1.png

.. image:: handle_drift_files/handle_drift_22_2.png

.. image:: handle_drift_files/handle_drift_22_3.png

.. image:: handle_drift_files/handle_drift_22_4.png

.. image:: handle_drift_files/handle_drift_22_5.png

.. image:: handle_drift_files/handle_drift_22_6.png

Run times
---------

Presets and their related methods differ not only in accuracy but also in
computation speed. It is good to keep this in mind!

.. code:: ipython3

    # collect the per-step run times saved with each preset
    run_times = []
    for preset in some_presets:
        folder = base_folder / "motion_folder_dataset1" / preset
        motion_info = si.load_motion_info(folder)
        run_times.append(motion_info["run_times"])
    keys = run_times[0].keys()

    # stacked bar plot: one bar per preset, one segment per step
    bottom = np.zeros(len(run_times))
    fig, ax = plt.subplots(figsize=(14, 6))
    for k in keys:
        rtimes = np.array([rt[k] for rt in run_times])
        if np.any(rtimes > 0.0):
            ax.bar(some_presets, rtimes, bottom=bottom, label=k)
        bottom += rtimes
    ax.legend()

.. image:: handle_drift_files/handle_drift_24_1.png
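For a rough cross-check of the figure, the elapsed times printed by the progress bars during the ``compute_motion`` loop can be summed per preset. The sketch below does this in plain Python. The ``mm:ss`` strings are copied from the log output earlier on this page and ``to_seconds`` is a helper introduced here. Steps that printed no progress bar are left out, as is ``rigid_fast``, whose log is incomplete, so these totals are lower bounds and not the values stored in ``motion_info["run_times"]``.

```python
def to_seconds(mm_ss):
    """Convert a progress-bar 'mm:ss' string to seconds."""
    minutes, seconds = mm_ss.split(":")
    return int(minutes) * 60 + int(seconds)


# Elapsed time of each logged step, as printed by the progress bars above.
logged_steps = {
    "kilosort_like": ["04:33"],
    "nonrigid_accurate": ["04:47", "01:01"],
    "nonrigid_fast_and_accurate": ["04:15", "01:00"],
    "dredge": ["04:45", "01:33", "00:30"],
    "dredge_fast": ["04:13", "01:29", "00:30"],
    "medicine": ["06:52", "02:49"],
}

totals = {
    preset: sum(to_seconds(t) for t in steps)
    for preset, steps in logged_steps.items()
}
for preset, total in sorted(totals.items(), key=lambda item: item[1]):
    print(f"{preset:28s} {total // 60:2d} min {total % 60:02d} s")
```

In these logs the fast variants (``dredge_fast``, ``nonrigid_fast_and_accurate``) save roughly half a minute each, almost all of it in the detect-and-localize step.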