.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/curation/plot_2_train_a_model.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_tutorials_curation_plot_2_train_a_model.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_curation_plot_2_train_a_model.py:

Training a model for automated curation
=======================================

If the pretrained models do not give satisfactory performance on your data,
it is easy to train your own classifier using SpikeInterface.

.. GENERATED FROM PYTHON SOURCE LINES 10-14

Step 1: Generate and label data
-------------------------------

First we will import our dependencies:

.. GENERATED FROM PYTHON SOURCE LINES 14-28

.. code-block:: Python

    import warnings
    warnings.filterwarnings("ignore")

    from pathlib import Path
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt

    import spikeinterface.core as si
    import spikeinterface.curation as sc
    import spikeinterface.widgets as sw

    # Note: you can set the number of cores to use with e.g.
    # si.set_global_job_kwargs(n_jobs=8)

.. GENERATED FROM PYTHON SOURCE LINES 29-39

For this tutorial, we will use simulated data to create ``recording`` and ``sorting`` objects.
We'll create two sorting objects: :code:`sorting_1` is coupled to the real recording, so its
spike times perfectly match the spikes in the recording; hence it will contain good units.
In contrast, :code:`sorting_2` is uncoupled from the recording, so its spike times will not
match the spikes in the recording; hence these units will mostly be random noise. We'll combine
the "good" and "noise" sortings into one sorting object using :code:`si.aggregate_units`.

(When making your own model, you should `load your own recording `_ and `do a sorting `_
on your data.)

.. GENERATED FROM PYTHON SOURCE LINES 39-45
.. code-block:: Python

    recording, sorting_1 = si.generate_ground_truth_recording(num_channels=4, seed=1, num_units=5)
    _, sorting_2 = si.generate_ground_truth_recording(num_channels=4, seed=2, num_units=5)
    both_sortings = si.aggregate_units([sorting_1, sorting_2])

.. GENERATED FROM PYTHON SOURCE LINES 46-48

To do some visualisation and postprocessing, we need to create a sorting analyzer and compute
some extensions:

.. GENERATED FROM PYTHON SOURCE LINES 48-52

.. code-block:: Python

    analyzer = si.create_sorting_analyzer(sorting=both_sortings, recording=recording)
    analyzer.compute(['noise_levels', 'random_spikes', 'waveforms', 'templates'])

Now we can plot the unit templates:

.. code-block:: Python

    sw.plot_unit_templates(analyzer)

.. GENERATED FROM PYTHON SOURCE LINES 60-64

This is as expected: great! (Find out more about plotting `using widgets `_.) We've set up
our system so that the first five units are 'good' and the next five are 'bad'. So we can
make a list of labels which contains this information. For real data, you could use a manual
curation tool to make your own list.

.. GENERATED FROM PYTHON SOURCE LINES 64-67

.. code-block:: Python

    labels = ['good', 'good', 'good', 'good', 'good', 'bad', 'bad', 'bad', 'bad', 'bad']

.. GENERATED FROM PYTHON SOURCE LINES 68-76

Step 2: Train our model
-----------------------

We'll now train a model based on our labelled data. The model will be trained using properties
of the units, and can then be applied to units from other sortings. The properties we use are
the `quality metrics `_ and `template metrics `_. Hence we need to compute these, using some
``sorting_analyzer`` extensions.

.. GENERATED FROM PYTHON SOURCE LINES 76-79

.. code-block:: Python

    analyzer.compute(['spike_locations', 'spike_amplitudes', 'correlograms', 'principal_components', 'quality_metrics', 'template_metrics'])
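To make concrete what the trainer consumes, here is a hypothetical, miniature version of the
data layout: one row of metric values per unit, paired with one label per unit (the metric
names are borrowed from the quality metrics; all values are invented):

```python
import pandas as pd

# Hypothetical miniature of the training inputs: one row of metric values per
# unit (values invented for illustration), plus one label per unit.
metrics = pd.DataFrame(
    {
        "snr": [12.1, 9.8, 1.2, 0.9],
        "nn_hit_rate": [0.95, 0.90, 0.20, 0.15],
    },
    index=[0, 1, 2, 3],  # unit ids
)
labels = ["good", "good", "bad", "bad"]  # one label per unit, in the same order

# The labels must line up one-to-one with the rows of the metrics table
assert len(labels) == len(metrics)
```

The real tables are much wider, with one column per quality or template metric computed above.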
Now that we have metrics and labels, we're ready to train the model using the ``train_model``
function. The trainer tries out combinations of classifier, imputation strategy and scaling
technique, and keeps the most accurate one:

.. code-block:: Python

    # Train the model and save it (with its metadata and accuracy scores) in `my_folder`
    trainer = sc.train_model(
        mode="analyzers",
        labels=[labels],
        analyzers=[analyzer],
        folder="my_folder",
    )

    best_model = trainer.best_pipeline

Under the hood, each candidate model is a scikit-learn pipeline. You can pass in different
`classifiers `_, `imputation strategies `_ and `scalers `_, although the documentation is
quite overwhelming. You can find the classifiers we've tried out using the
``sc.get_default_classifier_search_spaces`` function.

The above code saves the model in ``model.skops``, some metadata in ``model_info.json`` and
the model accuracies in ``model_accuracies.csv`` in the specified ``folder`` (in this case
``'my_folder'``). (``skops`` is a file format: you can think of it as a more secure pickle
file. `Read more `_.)

The ``model_accuracies.csv`` file contains the accuracy, precision and recall of the tested
models. Let's take a look:

.. GENERATED FROM PYTHON SOURCE LINES 123-127

.. code-block:: Python

    accuracies = pd.read_csv(Path("my_folder") / "model_accuracies.csv", index_col=0)
    accuracies.head()
.. rst-class:: sphx-glr-script-out

.. code-block:: none

       classifier name         imputation_strategy  scaling_strategy  balanced_accuracy  precision  recall  model_id                                        best_params
    0  RandomForestClassifier  median               StandardScaler()                1.0        1.0     1.0         0  {'n_estimators': 150, 'min_samples_split': 4, ...
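Since ``model_accuracies.csv`` is an ordinary CSV file, you can also rank the candidates
yourself with plain pandas. A minimal sketch, using an invented table (the rows and scores
below are made up, not real output):

```python
import io
import pandas as pd

# Invented stand-in for my_folder/model_accuracies.csv (scores are made up)
csv_text = """,classifier name,imputation_strategy,scaling_strategy,balanced_accuracy,precision,recall
0,RandomForestClassifier,median,StandardScaler(),0.95,0.94,0.96
1,LogisticRegression,median,StandardScaler(),0.88,0.90,0.86
"""

accuracies = pd.read_csv(io.StringIO(csv_text), index_col=0)

# Pick the candidate with the highest balanced accuracy
best_row = accuracies.sort_values("balanced_accuracy", ascending=False).iloc[0]
print(best_row["classifier name"])  # RandomForestClassifier
```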
.. GENERATED FROM PYTHON SOURCE LINES 128-135

Our model is perfect!! This is because the task was *very* easy: we had 10 units, half of
which were pure noise and half of which were not.

The model also contains some more information, such as which features are "important", as
defined by sklearn (learn about the feature importance of a `Random Forest Classifier `_).
We can plot these:

.. GENERATED FROM PYTHON SOURCE LINES 135-159

.. code-block:: Python

    # Plot feature importances
    importances = best_model.named_steps['classifier'].feature_importances_
    indices = np.argsort(importances)[::-1]

    # The sklearn importances are not computed for inputs whose values are all `nan`.
    # Hence, we need to pick out the non-`nan` columns of our metrics
    features = best_model.feature_names_in_
    n_features = best_model.n_features_in_

    metrics = pd.concat([analyzer.get_extension('quality_metrics').get_data(), analyzer.get_extension('template_metrics').get_data()], axis=1)
    non_null_metrics = ~(metrics.isnull().all()).values

    features = features[non_null_metrics]
    n_features = len(features)

    plt.figure(figsize=(12, 7))
    plt.title("Feature Importances")
    plt.bar(range(n_features), importances[indices], align="center")
    plt.xticks(range(n_features), features[indices], rotation=90)
    plt.xlim([-1, n_features])
    plt.subplots_adjust(bottom=0.3)
    plt.show()

.. image-sg:: /tutorials/curation/images/sphx_glr_plot_2_train_a_model_002.png
   :alt: Feature Importances
   :srcset: /tutorials/curation/images/sphx_glr_plot_2_train_a_model_002.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 160-169

Roughly, this means the model is using metrics such as "nn_hit_rate" and "l_ratio", but is
not using "sync_spike_4" and "rp_contamination". This is a toy model, so don't take these
results seriously. But using this information, you could retrain another, simpler model using
a subset of the metrics, by passing e.g. ``metric_names = ['nn_hit_rate', 'l_ratio',...]`` to
the ``train_model`` function.
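The feature-importance machinery used above is standard scikit-learn behaviour, not something
SpikeInterface-specific. As a self-contained sketch with invented data: a random forest
trained on one column that tracks the label and one pure-noise column concentrates its
importance on the former:

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

# Invented stand-in for a unit-metrics table: one informative column and one
# pure-noise column, with binary good/bad labels
rng = np.random.default_rng(0)
n_units = 200
y = rng.integers(0, 2, size=n_units)                    # 0 = "bad", 1 = "good"
informative = y + rng.normal(scale=0.3, size=n_units)   # correlates with the label
noise = rng.normal(size=n_units)                        # carries no signal
X = np.column_stack([informative, noise])

clf = RandomForestClassifier(random_state=0).fit(X, y)

# feature_importances_ is normalised to sum to 1; argsort in reverse gives the
# most important feature first, as in the plotting code above
importances = clf.feature_importances_
indices = np.argsort(importances)[::-1]
```

Here ``indices[0]`` picks out the informative column; dropping features with near-zero
importance is exactly the kind of simplification the ``metric_names`` argument enables.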
Now that you have a model, you can `apply it to another sorting `_ or
`upload it to HuggingFaceHub `_.

.. rst-class:: sphx-glr-timing

**Total running time of the script:** (0 minutes 9.604 seconds)

.. _sphx_glr_download_tutorials_curation_plot_2_train_a_model.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_2_train_a_model.ipynb <plot_2_train_a_model.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_2_train_a_model.py <plot_2_train_a_model.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_2_train_a_model.zip <plot_2_train_a_model.zip>`

.. only:: html

  .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_