{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Model Based Curation Tutorial\n", "#### This notebook provides a step-by-step guide on how to use a machine learning classifier to curate spike-sorted output into 'noise', 'multi-unit activity' (MUA) and 'single-unit activity' (SUA) using SpikeInterface.\n", "\n", "\n", "#### To use this, you should have already done spike sorting.\n", "#### In this notebook we will compute quality metrics and load a pretrained machine learning model to predict curation labels for previously uncurated electrophysiology data.\n", "#### The classifier is trained on Neuropixels data from 11 mice recorded in V1, SC and ALM.\n", "\n" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of cores set to: 23\n" ] } ], "source": [ "from pathlib import Path\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import spikeinterface as si\n", "import spikeinterface.extractors as se\n", "import spikeinterface.postprocessing as spost\n", "import spikeinterface.qualitymetrics as sqm\n", "import os\n", "from os import cpu_count\n", "import json\n", "# Set the number of CPU cores to be used globally - defaults to all cores -1\n", "n_cores = cpu_count() -1\n", "si.set_global_job_kwargs(n_jobs = n_cores)\n", "print(f\"Number of cores set to: {n_cores}\")\n", "\n", "# SET OUTPUT FOLDER\n", "output_folder = Path(r\"E:\\spikeinterface_outputs\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import spikeinterface\n", "print(spikeinterface.__version__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Load data " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## First step: Loading the recording and sorting objects" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For the tutorial, we are using simulated data to create recording and sorting 
objects." ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "# Generate a synthetic recording\n", "recording = si.generate_recording(num_channels=50, sampling_frequency=30000.,\n", " durations=[30], set_probe=True)\n", "# Load your recording depending on the acquisition software you used, for example:\n", "# recording = se.read_spikeglx(recording_path, stream_name='imec0.ap')\n", "\n", "# Generate a synthetic sorting\n", "sorting = si.generate_sorting(num_units=100, sampling_frequency=30000., durations=[30],\n", " firing_rates=15, refractory_period_ms=1.5)\n", "# Load your sorting depending on which spike sorter you used, for example:\n", "# sorting = se.read_kilosort(folder_path)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Second step: Create SortingAnalyzer " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "SortingAnalyzer is a postprocessing module which is used to compute quality metrics and template metrics.\n", "To learn more about it, please refer to: https://spikeinterface.readthedocs.io/en/latest/modules/postprocessing.html\n", "\n", "If you already have a WaveformExtractor from a previous run, you can use it to create a SortingAnalyzer. \n", "Please refer to: https://spikeinterface.readthedocs.io/en/latest/tutorials/waveform_extractor_to_sorting_analyzer.html\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "analyzer = si.create_sorting_analyzer(sorting = sorting, recording = recording, sparse = True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Third step: Compute metrics" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Quality metrics serve as features for the machine learning model. We generate these metrics and use the model to predict labels based on them." 
] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "# Defines a function to compute all analyzer properties and quality metrics\n", "# Note: this can be a time-consuming step, especially computing PCA-based metrics for long recordings\n", "\n", "def compute_all_metrics(analyzer):\n", "\n", " # Compute required extensions for quality metrics\n", " analyzer.compute({\n", "\t'noise_levels': {},\n", "\t'random_spikes': {'max_spikes_per_unit': 1_000},\n", "\t'templates': {'ms_before': 1.5, 'ms_after': 3.5},\n", "\t'spike_amplitudes': {},\n", "\t'waveforms': {},\n", "\t'principal_components': {},\n", "\t'spike_locations': {},\n", "\t'unit_locations': {},\n", "\t})\n", "\n", " # Compute all available quality metrics\n", " analyzer.compute(\"quality_metrics\", metric_names = sqm.get_quality_metric_list() + sqm.get_quality_pca_metric_list())\n", " analyzer.compute(\"template_metrics\", metric_names = spost.get_template_metric_names())\n", "\n", "\t# Make metric dataframe\n", " quality_metrics = analyzer.extensions['quality_metrics'].data[\"metrics\"]\n", " template_metrics = analyzer.extensions['template_metrics'].data[\"metrics\"]\n", " calculated_metrics = pd.concat([quality_metrics, template_metrics], axis = 1)\n", "\n", " return calculated_metrics" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Compute all metrics\n", "metrics = compute_all_metrics(analyzer)\n", "metrics.index.name = 'cluster_id'" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# save the analyzer\n", "analyzer.save_as(folder=output_folder / 'sorting_analyzer', format=\"binary_folder\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load a pretrained model from the Hugging Face Hub" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## First we use noise vs neuron classifier" ] }, { "cell_type": "code", "execution_count": 20, 
"metadata": {}, "outputs": [], "source": [ "analyzer = si.load_sorting_analyzer(folder= output_folder / 'sorting_analyzer', format=\"binary_folder\")" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "C:\\Users\\jain\\Documents\\Github_extend\\spikeinterface\\src\\spikeinterface\\curation\\model_based_curation.py:353: UserWarning: No 'model_info.json' file found in folder. No metadata can be checked.\n", " warnings.warn(\"No 'model_info.json' file found in folder. No metadata can be checked.\")\n" ] } ], "source": [ "import skops.io as sio\n", "from huggingface_hub import hf_hub_download\n", "\n", "model_path = hf_hub_download(\"AnoushkaJain3/curation_machine_learning_models\", \"noise_neuron_model.skops\")\n", "untrusted_types = sio.get_untrusted_types(file=model_path)\n", "\n", "# Load pretrained noise/neural activity model and predict on unlabelled data\n", "from spikeinterface.curation.model_based_curation import auto_label_units\n", "\n", "label_conversion = {1: 'noise',0: 'neuron'}\n", "\n", "label_dict_noise = auto_label_units(sorting_analyzer = analyzer,\n", " repo_id = \"AnoushkaJain3/curation_machine_learning_models\",\n", " model_name= 'noise_neuron_model.skops',\n", " label_conversion=label_conversion,\n", " trusted= untrusted_types\n", " )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Second we use sua vs mua classifier" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "C:\\Users\\jain\\Documents\\Github_extend\\spikeinterface\\src\\spikeinterface\\curation\\model_based_curation.py:353: UserWarning: No 'model_info.json' file found in folder. No metadata can be checked.\n", " warnings.warn(\"No 'model_info.json' file found in folder. 
No metadata can be checked.\")\n" ] } ], "source": [ "model_path = hf_hub_download(\"AnoushkaJain3/curation_machine_learning_models\", \"sua_mua_model.skops\")\n", "untrusted_types = sio.get_untrusted_types(file=model_path)\n", "\n", "label_conversion = {1: 'sua',0: 'mua'}\n", "\n", "label_dict_sua= auto_label_units(sorting_analyzer = analyzer,\n", " repo_id = \"AnoushkaJain3/curation_machine_learning_models\",\n", " model_name= 'sua_mua_model.skops',\n", " label_conversion=label_conversion,\n", " trusted= untrusted_types\n", " )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Combine the predictions and probabilities from the two models" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{0: ('noise', 0.6136904761904761), 1: ('noise', 0.5225), 2: ('noise', 0.6099285714285715), 3: ('noise', 0.6553571428571427), 4: ('noise', 0.6146904761904761), 5: ('noise', 0.6320238095238095), 6: ('noise', 0.6036904761904761), 7: ('noise', 0.6036904761904761), 8: ('noise', 0.7125), 9: ('noise', 0.6025), 10: ('noise', 0.5891666666666666), 11: ('noise', 0.6911666666666666), 12: ('noise', 0.7113571428571429), 13: ('noise', 0.6826904761904761), 14: ('noise', 0.6248333333333334), 15: ('noise', 0.6308333333333335), 16: ('noise', 0.5936904761904762), 17: ('noise', 0.6193571428571428), 18: ('noise', 0.5803571428571428), 19: ('noise', 0.5836904761904761), 20: ('noise', 0.6753571428571428), 21: ('noise', 0.5824999999999999), 22: ('noise', 0.6633333333333333), 23: ('noise', 0.6806904761904762), 24: ('noise', 0.6986904761904762), 25: ('noise', 0.7125), 26: ('noise', 0.6036904761904761), 27: ('noise', 0.6633571428571429), 28: ('noise', 0.6958333333333333), 29: ('noise', 0.6136904761904761), 30: ('noise', 0.6093571428571428), 31: ('noise', 0.5338333333333334), 32: ('noise', 0.5577380952380951), 33: ('noise', 0.6425), 34: ('noise', 0.5836904761904761), 35: ('noise', 0.5910714285714286), 36: 
('noise', 0.5458333333333333), 37: ('noise', 0.6320238095238095), 38: ('noise', 0.6353571428571428), 39: ('noise', 0.6591666666666667), 40: ('noise', 0.6453571428571427), 41: ('noise', 0.6508333333333334), 42: ('noise', 0.6420238095238094), 43: ('noise', 0.7020238095238095), 44: ('noise', 0.6541666666666666), 45: ('noise', 0.6625), 46: ('noise', 0.6986904761904762), 47: ('noise', 0.6886904761904762), 48: ('noise', 0.6236904761904761), 49: ('noise', 0.6008333333333333), 50: ('noise', 0.6786904761904762), 51: ('noise', 0.6276666666666667), 52: ('noise', 0.5775), 53: ('noise', 0.7025), 54: ('noise', 0.6225), 55: ('noise', 0.5825), 56: ('noise', 0.6708333333333334), 57: ('noise', 0.6203571428571428), 58: ('noise', 0.6753571428571428), 59: ('noise', 0.6653571428571428), 60: ('noise', 0.7226904761904762), 61: ('noise', 0.6225), 62: ('noise', 0.6753571428571428), 63: ('noise', 0.6548333333333334), 64: ('noise', 0.7358333333333333), 65: ('noise', 0.6593571428571429), 66: ('noise', 0.6600238095238095), 67: ('noise', 0.5974999999999999), 68: ('noise', 0.6036904761904761), 69: ('noise', 0.6386904761904761), 70: ('noise', 0.6886904761904762), 71: ('noise', 0.6403571428571427), 72: ('noise', 0.6586904761904762), 73: ('noise', 0.7053571428571429), 74: ('noise', 0.5651666666666667), 75: ('noise', 0.5941666666666666), 76: ('noise', 0.5975), 77: ('noise', 0.6986904761904762), 78: ('noise', 0.7058333333333333), 79: ('noise', 0.6830238095238096), 80: ('noise', 0.6886904761904762), 81: ('noise', 0.6796666666666666), 82: ('noise', 0.6886904761904762), 83: ('noise', 0.6141666666666666), 84: ('noise', 0.5966666666666667), 85: ('noise', 0.7021666666666667), 86: ('noise', 0.6608333333333334), 87: ('noise', 0.6926904761904763), 88: ('noise', 0.6720238095238095), 89: ('noise', 0.5858333333333333), 90: ('noise', 0.6425), 91: ('noise', 0.5908333333333333), 92: ('noise', 0.6175), 93: ('noise', 0.6170238095238095), 94: ('noise', 0.6308333333333332), 95: ('noise', 0.6041666666666666), 96: 
('noise', 0.6236904761904761), 97: ('noise', 0.5891666666666666), 98: ('noise', 0.6248333333333334), 99: ('noise', 0.6805000000000001)}\n" ] } ], "source": [ "# Combine data\n", "decoder_output = {}\n", "\n", "for key, value in label_dict_noise.items():\n", " label, prob1 = value\n", " if label == 'noise':\n", " decoder_output[key] = (label, prob1) # Keep 'noise' and its probability\n", " else:\n", " sua_mua_label, prob2 = label_dict_sua[key] # Get 'sua' or 'mua' label and its probability\n", " decoder_output[key] = (sua_mua_label, prob2) # Replace 'neuron' with 'sua'/'mua' and their probabilities\n", "\n", "# Print the resulting combined data\n", "print(decoder_output)\n", "\n", "# Save the decoder output\n", "with open(output_folder / 'decoder_output.json', 'w') as f:\n", " json.dump(decoder_output, f)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Assess model performance by comparing with human labels" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To assess the performance of the model relative to human labels, we can load (or here generate randomly) some labels, and plot a confusion matrix of predicted vs human labels for all clusters" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from sklearn.metrics import confusion_matrix, balanced_accuracy_score\n", "import seaborn as sns\n", "\n", "# Use 'ground-truth' labels to check prediction accuracy\n", "# These are assigned randomly here but you could load these from phy 'cluster_group.tsv', from the 'quality' property of the sorting, or similar\n", "human_labels = np.random.choice(list(label_conversion.values()), analyzer.get_num_units())\n", "\n", "# Get labels from phy sorting (if loaded) using:\n", "# human_labels = unlabelled_analyzer.sorting.get_property('quality')\n", "\n", "predictions = analyzer.sorting.get_property('label_prediction')\n", "\n", "conf_matrix = confusion_matrix(human_labels, predictions)\n", "\n", "# Calculate balanced accuracy for the confusion matrix\n", "balanced_accuracy = balanced_accuracy_score(human_labels, predictions)\n", "\n", "sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='viridis')\n", "plt.xlabel('Predicted Label')\n", "plt.ylabel('Human Label')\n", "plt.xticks(ticks = [0.5, 1.5], labels = list(label_conversion.values()))\n", "plt.yticks(ticks = [0.5, 1.5], labels = list(label_conversion.values()))\n", "plt.title('Predicted vs Human Label')\n", "plt.suptitle(f\"Balanced Accuracy: {balanced_accuracy}\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also see how the model's confidence relates to the probability that the model label matches the human label\n", "\n", "This could be used to set a threshold above which you might accept the model's classification, and only manually curate those which it is less sure of" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "confidences = analyzer.sorting.get_property('label_confidence')\n", "\n", "# Make dataframe of human label, model label, and confidence\n", "label_df = pd.DataFrame(data = {\n", " 'phy_label': human_labels,\n", " 'decoder_label': predictions,\n", " 'confidence': confidences},\n", " index = decoder_output.keys())\n", "\n", "# Calculate the proportion of agreed labels by confidence decile\n", "label_df['model_x_human_agreement'] = label_df['phy_label'] == label_df['decoder_label']\n", "\n", "def calculate_moving_avg(label_df, confidence_label, window_size):\n", "\n", " label_df[f'{confidence_label}_decile'] = pd.cut(label_df[confidence_label], 10, labels=False, duplicates='drop')\n", " # Group by decile and calculate the proportion of correct labels (agreement)\n", " p_label_grouped = label_df.groupby(f'{confidence_label}_decile')['model_x_human_agreement'].mean()\n", " # Convert decile to range 0-1\n", " p_label_grouped.index = p_label_grouped.index / 10\n", " # Sort the DataFrame by confidence scores\n", " label_df_sorted = label_df.sort_values(by=confidence_label)\n", "\n", " p_label_moving_avg = label_df_sorted['model_x_human_agreement'].rolling(window=window_size).mean()\n", "\n", " return label_df_sorted[confidence_label], p_label_moving_avg\n", "\n", "p_agreement_sorted, p_agreement_moving_avg = calculate_moving_avg(label_df, 'confidence', 20)\n", "\n", "# Plot the moving average of agreement\n", "plt.figure(figsize=(6, 6))\n", "plt.plot(p_agreement_sorted, p_agreement_moving_avg, label = 'Moving Average')\n", "plt.axhline(y=1/len(np.unique(predictions)), color='black', linestyle='--', label='Chance')\n", "plt.xlabel('Confidence'); plt.xlim(0.5, 1)\n", "plt.ylabel('Proportion Agreement with Human Label'); plt.ylim(0, 1)\n", "plt.title('Agreement vs Confidence (Moving Average)')\n", "plt.legend(); plt.grid(True); plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# If 
you only have metrics files" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you don't have access to the sorting analyzer for a particular recording, you can still use the pretrained model to predict on new data" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Load your files\n", "\n", "data_directory = Path(r\"Y:\\invivo_ephys\\Neuropixels\")\n", "recording_folder = '2838_20240620'\n", "metrics_folder = data_directory / recording_folder / 'metrics'\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# load quality metrics and template metrics csv files\n", "\n", "quality_metrics = pd.read_csv(metrics_folder / 'quality_metrics.csv')\n", "template_metrics = pd.read_csv(metrics_folder / 'template_metrics.csv')\n", "\n", "# merge them based on the index\n", "metrics = pd.merge(quality_metrics, template_metrics, left_index=True, right_index=True)\n", "\n", "# Replace infinities with NaNs and convert to float32\n", "metrics_dataframe = metrics.replace([np.inf, -np.inf], np.nan).astype(\"float32\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## First we use noise vs neuron classifier" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "import skops.io as sio\n", "from huggingface_hub import hf_hub_download\n", "\n", "json_path = hf_hub_download(repo_id=\"AnoushkaJain3/curation_machine_learning_models\", filename=\"features_sequence.json\")\n", "\n", "# Load the JSON file\n", "with open(json_path, 'r') as file:\n", " column_list = json.load(file)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "model_path_noise = hf_hub_download(repo_id=\"AnoushkaJain3/curation_test\", filename=\"noise_neuron_model.skops\")\n", "\n", "# Get the untrusted types from the skops file\n", "untrusted_types = sio.get_untrusted_types(file = model_path_noise)\n", "clf = 
sio.load(model_path_noise, trusted=untrusted_types)\n", "\n", "# Perform prediction\n", "noise_predictions = clf.predict(metrics_dataframe[column_list])\n", "noise_probabs = clf.predict_proba(metrics_dataframe[column_list])\n", "\n", "metrics_dataframe['noise_label'] = noise_predictions\n", "metrics_dataframe['noise_probs'] = noise_probabs[:,1]\n", "metrics_dataframe['noise_label'] = metrics_dataframe['noise_label'].map({1: 'noise', 0: 'neural'})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Second we use sua vs mua classifier" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "model_path_sua = hf_hub_download(repo_id=\"AnoushkaJain3/curation_test\", filename=\"sua_mua_model.skops\")\n", "\n", "# Get the untrusted types from the skops file\n", "untrusted_types = sio.get_untrusted_types(file = model_path_sua)\n", "clf = sio.load(model_path_sua, trusted=untrusted_types)\n", "\n", "# Perform prediction\n", "sua_predictions = clf.predict(metrics_dataframe[column_list])\n", "sua_probabs = clf.predict_proba(metrics_dataframe[column_list])\n", "metrics_dataframe['sua_label'] = sua_predictions\n", "metrics_dataframe['sua_probs'] = sua_probabs[:,1]\n", "metrics_dataframe['sua_label'] = metrics_dataframe['sua_label'].map({1 : 'sua', 0 :'mua'})" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "# Combine the predictions and probabilities from the two models\n", "\n", "# Create the 'decoder_label' column\n", "metrics_dataframe['decoder_label'] = metrics_dataframe.apply(lambda row: row['noise_label'] if row['noise_label'] == 'noise' else row['sua_label'], axis=1)\n", " \n", "# Create the 'decoder_probs' column\n", "metrics_dataframe['decoder_probs'] = metrics_dataframe.apply(lambda row: row['noise_probs'] if row['decoder_label'] == 'noise' else row['sua_probs'], axis=1)\n", "\n", "# Update 'decoder_probs' for 'mua' values\n", 
"metrics_dataframe.loc[metrics_dataframe['decoder_label'] == 'mua', 'decoder_probs'] = 1 - metrics_dataframe['decoder_probs']" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "metrics_dataframe['decoder_label']" ] } ], "metadata": { "kernelspec": { "display_name": "spike_interface", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 2 }