{ "cells": [ { "cell_type": "code", "execution_count": 35, "id": "29893746-48a4-4439-ac69-a1514c048653", "metadata": { "execution": { "iopub.execute_input": "2024-05-23T15:43:14.698999Z", "iopub.status.busy": "2024-05-23T15:43:14.698643Z", "iopub.status.idle": "2024-05-23T15:43:23.296300Z", "shell.execute_reply": "2024-05-23T15:43:23.295249Z", "shell.execute_reply.started": "2024-05-23T15:43:14.698975Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: pynvml in /usr/local/lib/python3.11/dist-packages (11.5.0)\n", "Requirement already satisfied: numba in /usr/local/lib/python3.11/dist-packages (0.59.1)\n", "Requirement already satisfied: llvmlite<0.43,>=0.42.0dev0 in /usr/local/lib/python3.11/dist-packages (from numba) (0.42.0)\n", "Requirement already satisfied: numpy<1.27,>=1.22 in /usr/local/lib/python3.11/dist-packages (from numba) (1.26.3)\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0mCollecting evaluate\n", " Downloading evaluate-0.4.2-py3-none-any.whl.metadata (9.3 kB)\n", "Requirement already satisfied: datasets>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from evaluate) (2.14.5)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.11/dist-packages (from evaluate) (1.26.3)\n", "Requirement already satisfied: dill in /usr/local/lib/python3.11/dist-packages (from evaluate) (0.3.7)\n", "Requirement already satisfied: pandas in /usr/local/lib/python3.11/dist-packages (from evaluate) (2.2.0)\n", "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.11/dist-packages (from evaluate) (2.31.0)\n", "Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.11/dist-packages (from evaluate) (4.66.1)\n", "Requirement already satisfied: xxhash in /usr/local/lib/python3.11/dist-packages (from evaluate) (3.4.1)\n", "Requirement already satisfied: multiprocess in /usr/local/lib/python3.11/dist-packages (from evaluate) (0.70.15)\n", "Requirement already satisfied: fsspec>=2021.05.0 in /usr/local/lib/python3.11/dist-packages (from fsspec[http]>=2021.05.0->evaluate) (2023.6.0)\n", "Requirement already satisfied: huggingface-hub>=0.7.0 in /usr/local/lib/python3.11/dist-packages (from evaluate) (0.20.3)\n", "Requirement already satisfied: packaging in /usr/local/lib/python3.11/dist-packages (from evaluate) (23.2)\n", "Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.11/dist-packages (from datasets>=2.0.0->evaluate) (15.0.0)\n", "Requirement already satisfied: aiohttp in /usr/local/lib/python3.11/dist-packages (from datasets>=2.0.0->evaluate) (3.9.1)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/lib/python3/dist-packages (from datasets>=2.0.0->evaluate) (5.4.1)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from 
huggingface-hub>=0.7.0->evaluate) (3.13.1)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub>=0.7.0->evaluate) (4.9.0)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests>=2.19.0->evaluate) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/lib/python3/dist-packages (from requests>=2.19.0->evaluate) (3.3)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests>=2.19.0->evaluate) (2.0.7)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/lib/python3/dist-packages (from requests>=2.19.0->evaluate) (2020.6.20)\n", "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas->evaluate) (2.8.2)\n", "Requirement already satisfied: pytz>=2020.1 in /usr/lib/python3/dist-packages (from pandas->evaluate) (2022.1)\n", "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas->evaluate) (2023.4)\n", "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (23.1.0)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (6.0.4)\n", "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (1.9.4)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (1.4.1)\n", "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (1.3.1)\n", "Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.8.2->pandas->evaluate) (1.16.0)\n", 
"Downloading evaluate-0.4.2-py3-none-any.whl (84 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m84.1/84.1 kB\u001b[0m \u001b[31m1.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n", "\u001b[?25hInstalling collected packages: evaluate\n", "Successfully installed evaluate-0.4.2\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0m" ] } ], "source": [ "!pip install pynvml numba\n", "!pip install evaluate" ] }, { "cell_type": "code", "execution_count": 2, "id": "54e5daed-d4ba-4609-bc59-49b799947d95", "metadata": { "execution": { "iopub.execute_input": "2024-05-23T15:16:55.852893Z", "iopub.status.busy": "2024-05-23T15:16:55.851997Z", "iopub.status.idle": "2024-05-23T15:16:55.858275Z", "shell.execute_reply": "2024-05-23T15:16:55.856882Z", "shell.execute_reply.started": "2024-05-23T15:16:55.852865Z" } }, "outputs": [], "source": [ "import random\n", "from collections import Counter\n", "from tqdm import tqdm\n", "\n", "import torch\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "import seaborn as sns" ] }, { "cell_type": "code", "execution_count": 3, "id": "4bf4396d-8c76-4025-b722-fa981cb4ce09", "metadata": { "execution": { "iopub.execute_input": "2024-05-23T15:17:24.664545Z", "iopub.status.busy": "2024-05-23T15:17:24.663873Z", "iopub.status.idle": "2024-05-23T15:17:25.701955Z", "shell.execute_reply": "2024-05-23T15:17:25.701129Z", "shell.execute_reply.started": "2024-05-23T15:17:24.664519Z" } }, "outputs": [], "source": [ "import torch \n", "import torch.nn as nn\n", "from transformers import EsmTokenizer, EsmForSequenceClassification" ] }, { "cell_type": "code", "execution_count": 4, "id": "40d5d019-676b-4484-bac7-bc9f3bf8e74a", "metadata": { 
# %% Select the compute device
# Use the first CUDA GPU when PyTorch can see one, otherwise run on the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# %%
device

# %% Inspect the CUDA runtime
import torch


def describe_cuda():
    """Summarise the CUDA devices visible to PyTorch.

    Returns:
        A 5-tuple ``(available, device_count, current_index, device_0, device_0_name)``.
        When CUDA is not available the last three entries are ``None`` instead of
        being queried, so the notebook still runs top to bottom on a CPU-only
        machine (matching the ``"cpu"`` fallback used for ``device``).
    """
    available = torch.cuda.is_available()
    count = torch.cuda.device_count()
    if not available:
        # current_device() / get_device_name() initialise CUDA and fail on a
        # machine without a usable GPU, so skip them entirely here.
        return (available, count, None, None, None)
    return (
        available,
        count,
        torch.cuda.current_device(),
        torch.cuda.device(0),
        torch.cuda.get_device_name(0),
    )


# %%
describe_cuda()
# %% GPU monitoring helpers
# NOTE(review): the wildcard import is kept because cells outside this view may
# rely on other pynvml names. The helpers below only need nvmlInit,
# nvmlShutdown, nvmlDeviceGetHandleByIndex and nvmlDeviceGetMemoryInfo.
from pynvml import *


def print_gpu_utilization(device_index=0):
    """Print the memory currently in use on one GPU, as reported by NVML.

    Args:
        device_index: NVML index of the GPU to query. Defaults to 0, which is
            the GPU this helper always queried before the parameter existed.
    """
    nvmlInit()
    try:
        handle = nvmlDeviceGetHandleByIndex(device_index)
        info = nvmlDeviceGetMemoryInfo(handle)
        # info.used is in bytes; integer-divide by 1024**2 for the printed figure.
        print(f"GPU memory occupied: {info.used//1024**2} MB.")
    finally:
        # Every nvmlInit() must be paired with nvmlShutdown(); without it each
        # call leaves another NVML reference open for the life of the kernel.
        nvmlShutdown()


def print_summary(result):
    """Print run time, throughput and current GPU memory use for a training run.

    Args:
        result: Object exposing a ``metrics`` mapping with the keys
            ``'train_runtime'`` and ``'train_samples_per_second'``
            (presumably the value returned by a Hugging Face ``Trainer.train()``
            call -- TODO confirm against the caller).
    """
    print(f"Time: {result.metrics['train_runtime']:.2f}")
    print(f"Samples/second: {result.metrics['train_samples_per_second']:.2f}")
    print_gpu_utilization()


# %% Reset the GPU before loading the model
from numba import cuda

# Bound to its own name on purpose: the notebook's `device` variable holds the
# torch device string ("cuda:0" / "cpu") that later cells pass as `device=device`,
# and must not be overwritten with a numba Device object.
numba_device = cuda.get_current_device()
# NOTE(review): reset() tears down the CUDA context for this device; confirm no
# live torch tensors/models exist on the GPU when this cell is run.
numba_device.reset()

# %%
print_gpu_utilization()

# %%
test_loader  # defined in a cell outside this view
# %% Build the inference pipeline
# Import the factory under an alias so that binding the resulting pipeline
# object does not overwrite the factory itself. With the original
# `pipeline = pipeline(...)`, re-running this cell called the pipeline *object*
# with factory keyword arguments instead of building a new pipeline.
from transformers import pipeline as hf_pipeline

# `model`, `tokenizer` and `device` come from earlier cells of the notebook.
classifier = hf_pipeline(task="text-classification", model=model, tokenizer=tokenizer, device=device)

# Backward-compatible alias: other cells may still call `pipeline(...)` expecting
# the classifier object rather than the transformers factory.
pipeline = classifier

# %% Run inference on the held-out sequences
# Each prediction is a dict with a 'label' and a 'score' entry.
predictions = classifier(X_test)
0.9950073957443237},\n", " {'label': 'LABEL_0', 'score': 0.9934017062187195},\n", " {'label': 'LABEL_0', 'score': 0.9901747703552246},\n", " {'label': 'LABEL_0', 'score': 0.9942784309387207},\n", " {'label': 'LABEL_1', 'score': 0.9712179899215698},\n", " {'label': 'LABEL_0', 'score': 0.9931582808494568},\n", " {'label': 'LABEL_0', 'score': 0.9989894032478333},\n", " {'label': 'LABEL_0', 'score': 0.9962522387504578},\n", " {'label': 'LABEL_0', 'score': 0.8317552208900452},\n", " {'label': 'LABEL_0', 'score': 0.980193555355072},\n", " {'label': 'LABEL_0', 'score': 0.7707980275154114},\n", " {'label': 'LABEL_0', 'score': 0.8872848153114319},\n", " {'label': 'LABEL_0', 'score': 0.9598217606544495},\n", " {'label': 'LABEL_1', 'score': 0.9657743573188782},\n", " {'label': 'LABEL_1', 'score': 0.9803487658500671},\n", " {'label': 'LABEL_1', 'score': 0.9939471483230591},\n", " {'label': 'LABEL_0', 'score': 0.6964303851127625},\n", " {'label': 'LABEL_1', 'score': 0.9915476441383362},\n", " {'label': 'LABEL_1', 'score': 0.993975818157196},\n", " {'label': 'LABEL_1', 'score': 0.9934781789779663},\n", " {'label': 'LABEL_0', 'score': 0.7910520434379578},\n", " {'label': 'LABEL_1', 'score': 0.9921974539756775},\n", " {'label': 'LABEL_1', 'score': 0.9815288782119751},\n", " {'label': 'LABEL_1', 'score': 0.9945481419563293},\n", " {'label': 'LABEL_1', 'score': 0.9962074756622314},\n", " {'label': 'LABEL_1', 'score': 0.8752002716064453},\n", " {'label': 'LABEL_1', 'score': 0.9936671853065491},\n", " {'label': 'LABEL_0', 'score': 0.9653744697570801},\n", " {'label': 'LABEL_0', 'score': 0.9754136204719543},\n", " {'label': 'LABEL_1', 'score': 0.9944813251495361},\n", " {'label': 'LABEL_1', 'score': 0.9775446653366089},\n", " {'label': 'LABEL_0', 'score': 0.7639273405075073},\n", " {'label': 'LABEL_0', 'score': 0.667346715927124},\n", " {'label': 'LABEL_0', 'score': 0.7142799496650696},\n", " {'label': 'LABEL_1', 'score': 0.9928463101387024},\n", " {'label': 'LABEL_0', 'score': 
0.5917879343032837},\n", " {'label': 'LABEL_1', 'score': 0.890602707862854},\n", " {'label': 'LABEL_0', 'score': 0.9238371849060059},\n", " {'label': 'LABEL_1', 'score': 0.9933255910873413},\n", " {'label': 'LABEL_1', 'score': 0.9891610741615295},\n", " {'label': 'LABEL_0', 'score': 0.9632930159568787},\n", " {'label': 'LABEL_0', 'score': 0.9220459461212158},\n", " {'label': 'LABEL_0', 'score': 0.9919925928115845},\n", " {'label': 'LABEL_1', 'score': 0.5947464108467102},\n", " {'label': 'LABEL_0', 'score': 0.9746901988983154},\n", " {'label': 'LABEL_0', 'score': 0.9779336452484131},\n", " {'label': 'LABEL_1', 'score': 0.9941219687461853},\n", " {'label': 'LABEL_1', 'score': 0.8267089128494263},\n", " {'label': 'LABEL_0', 'score': 0.996107280254364},\n", " {'label': 'LABEL_1', 'score': 0.9836999773979187},\n", " {'label': 'LABEL_1', 'score': 0.9613105058670044},\n", " {'label': 'LABEL_0', 'score': 0.9625087380409241},\n", " {'label': 'LABEL_0', 'score': 0.9945961833000183},\n", " {'label': 'LABEL_1', 'score': 0.9716067910194397},\n", " {'label': 'LABEL_0', 'score': 0.9654806852340698},\n", " {'label': 'LABEL_0', 'score': 0.9887083768844604},\n", " {'label': 'LABEL_1', 'score': 0.9224319458007812},\n", " {'label': 'LABEL_1', 'score': 0.9960983991622925},\n", " {'label': 'LABEL_0', 'score': 0.9977193474769592},\n", " {'label': 'LABEL_1', 'score': 0.9953898191452026},\n", " {'label': 'LABEL_1', 'score': 0.9955762624740601},\n", " {'label': 'LABEL_0', 'score': 0.9048295021057129},\n", " {'label': 'LABEL_1', 'score': 0.9224010109901428},\n", " {'label': 'LABEL_1', 'score': 0.9839083552360535},\n", " {'label': 'LABEL_1', 'score': 0.7599542737007141},\n", " {'label': 'LABEL_1', 'score': 0.9917185306549072},\n", " {'label': 'LABEL_0', 'score': 0.8768678307533264},\n", " {'label': 'LABEL_1', 'score': 0.9816514253616333},\n", " {'label': 'LABEL_0', 'score': 0.5276628732681274},\n", " {'label': 'LABEL_1', 'score': 0.9766970872879028},\n", " {'label': 'LABEL_0', 'score': 
0.9330600500106812},\n", " {'label': 'LABEL_0', 'score': 0.9660871624946594},\n", " {'label': 'LABEL_1', 'score': 0.9845715761184692},\n", " {'label': 'LABEL_1', 'score': 0.9950500130653381},\n", " {'label': 'LABEL_1', 'score': 0.996354341506958},\n", " {'label': 'LABEL_1', 'score': 0.9961453676223755},\n", " {'label': 'LABEL_0', 'score': 0.9704436659812927},\n", " {'label': 'LABEL_0', 'score': 0.8818399906158447},\n", " {'label': 'LABEL_1', 'score': 0.9842308163642883},\n", " {'label': 'LABEL_1', 'score': 0.6171972751617432},\n", " {'label': 'LABEL_0', 'score': 0.9949376583099365},\n", " {'label': 'LABEL_0', 'score': 0.9831938743591309},\n", " {'label': 'LABEL_0', 'score': 0.993530809879303},\n", " {'label': 'LABEL_1', 'score': 0.9809488654136658},\n", " {'label': 'LABEL_1', 'score': 0.9838874936103821},\n", " {'label': 'LABEL_0', 'score': 0.9927709698677063},\n", " {'label': 'LABEL_1', 'score': 0.672204315662384},\n", " {'label': 'LABEL_0', 'score': 0.9932838678359985},\n", " {'label': 'LABEL_1', 'score': 0.9767532348632812},\n", " {'label': 'LABEL_0', 'score': 0.9813488125801086},\n", " {'label': 'LABEL_1', 'score': 0.994283139705658},\n", " {'label': 'LABEL_1', 'score': 0.6814839243888855},\n", " {'label': 'LABEL_0', 'score': 0.9924652576446533},\n", " {'label': 'LABEL_1', 'score': 0.7334816455841064},\n", " {'label': 'LABEL_0', 'score': 0.9909833669662476},\n", " {'label': 'LABEL_0', 'score': 0.9856436252593994},\n", " {'label': 'LABEL_0', 'score': 0.9319406151771545},\n", " {'label': 'LABEL_0', 'score': 0.8773038983345032},\n", " {'label': 'LABEL_1', 'score': 0.9863520860671997},\n", " {'label': 'LABEL_1', 'score': 0.9160218834877014},\n", " {'label': 'LABEL_1', 'score': 0.9942693710327148},\n", " {'label': 'LABEL_1', 'score': 0.7459067106246948},\n", " {'label': 'LABEL_0', 'score': 0.9660363793373108},\n", " {'label': 'LABEL_0', 'score': 0.8596163988113403},\n", " {'label': 'LABEL_0', 'score': 0.9941113591194153},\n", " {'label': 'LABEL_0', 'score': 
0.9466278553009033},\n", " {'label': 'LABEL_0', 'score': 0.873200535774231},\n", " {'label': 'LABEL_1', 'score': 0.9967073202133179},\n", " {'label': 'LABEL_1', 'score': 0.8866151571273804},\n", " {'label': 'LABEL_1', 'score': 0.9727884531021118},\n", " {'label': 'LABEL_0', 'score': 0.9961073994636536},\n", " {'label': 'LABEL_1', 'score': 0.9898661971092224},\n", " {'label': 'LABEL_1', 'score': 0.9957423806190491},\n", " {'label': 'LABEL_1', 'score': 0.9925336837768555},\n", " {'label': 'LABEL_1', 'score': 0.9566472172737122},\n", " {'label': 'LABEL_0', 'score': 0.9743568301200867},\n", " {'label': 'LABEL_1', 'score': 0.6052708029747009},\n", " {'label': 'LABEL_0', 'score': 0.963996946811676},\n", " {'label': 'LABEL_1', 'score': 0.9576812982559204},\n", " {'label': 'LABEL_1', 'score': 0.8900694251060486},\n", " {'label': 'LABEL_0', 'score': 0.9907097816467285},\n", " {'label': 'LABEL_1', 'score': 0.9793062210083008},\n", " {'label': 'LABEL_0', 'score': 0.9792136549949646},\n", " {'label': 'LABEL_0', 'score': 0.9821605682373047},\n", " {'label': 'LABEL_1', 'score': 0.990317702293396},\n", " {'label': 'LABEL_0', 'score': 0.9872243404388428},\n", " {'label': 'LABEL_1', 'score': 0.964670717716217},\n", " {'label': 'LABEL_0', 'score': 0.8030292391777039},\n", " {'label': 'LABEL_1', 'score': 0.6403388381004333},\n", " {'label': 'LABEL_0', 'score': 0.9047842621803284},\n", " {'label': 'LABEL_1', 'score': 0.9462425708770752},\n", " {'label': 'LABEL_0', 'score': 0.9106733798980713},\n", " {'label': 'LABEL_1', 'score': 0.9901905059814453},\n", " {'label': 'LABEL_1', 'score': 0.9941834807395935},\n", " {'label': 'LABEL_1', 'score': 0.9732745885848999},\n", " {'label': 'LABEL_0', 'score': 0.9958937168121338},\n", " {'label': 'LABEL_0', 'score': 0.6962946057319641},\n", " {'label': 'LABEL_1', 'score': 0.9949690699577332},\n", " {'label': 'LABEL_1', 'score': 0.9963836669921875},\n", " {'label': 'LABEL_1', 'score': 0.9950699806213379},\n", " {'label': 'LABEL_1', 'score': 
0.9299766421318054},\n", " {'label': 'LABEL_1', 'score': 0.9953497648239136},\n", " {'label': 'LABEL_1', 'score': 0.9944887161254883},\n", " {'label': 'LABEL_1', 'score': 0.9949355721473694},\n", " {'label': 'LABEL_1', 'score': 0.9922817349433899},\n", " {'label': 'LABEL_0', 'score': 0.9634330868721008},\n", " {'label': 'LABEL_1', 'score': 0.9927870631217957},\n", " {'label': 'LABEL_1', 'score': 0.9881426692008972},\n", " {'label': 'LABEL_0', 'score': 0.9957276582717896},\n", " {'label': 'LABEL_0', 'score': 0.9653200507164001},\n", " {'label': 'LABEL_0', 'score': 0.9674235582351685},\n", " {'label': 'LABEL_0', 'score': 0.7343799471855164},\n", " {'label': 'LABEL_1', 'score': 0.992625892162323},\n", " {'label': 'LABEL_1', 'score': 0.9907028675079346},\n", " {'label': 'LABEL_0', 'score': 0.9948910474777222},\n", " {'label': 'LABEL_1', 'score': 0.7471431493759155},\n", " {'label': 'LABEL_1', 'score': 0.9757040739059448},\n", " {'label': 'LABEL_1', 'score': 0.9639447927474976},\n", " {'label': 'LABEL_1', 'score': 0.9619860649108887},\n", " {'label': 'LABEL_1', 'score': 0.9964457154273987},\n", " {'label': 'LABEL_1', 'score': 0.9969297051429749},\n", " {'label': 'LABEL_1', 'score': 0.9627381563186646},\n", " {'label': 'LABEL_1', 'score': 0.6666115522384644},\n", " {'label': 'LABEL_0', 'score': 0.9931562542915344},\n", " {'label': 'LABEL_1', 'score': 0.9928695559501648},\n", " {'label': 'LABEL_1', 'score': 0.9817350506782532},\n", " {'label': 'LABEL_1', 'score': 0.9804060459136963},\n", " {'label': 'LABEL_1', 'score': 0.9960583448410034},\n", " {'label': 'LABEL_1', 'score': 0.9965972304344177},\n", " {'label': 'LABEL_1', 'score': 0.9945698380470276},\n", " {'label': 'LABEL_1', 'score': 0.9966057538986206},\n", " {'label': 'LABEL_0', 'score': 0.9204991459846497},\n", " {'label': 'LABEL_1', 'score': 0.99128657579422},\n", " {'label': 'LABEL_0', 'score': 0.9790450930595398},\n", " {'label': 'LABEL_0', 'score': 0.9350587129592896},\n", " {'label': 'LABEL_0', 'score': 
0.9830670952796936},\n", " {'label': 'LABEL_0', 'score': 0.9946940541267395},\n", " {'label': 'LABEL_0', 'score': 0.9659106135368347},\n", " {'label': 'LABEL_1', 'score': 0.9770876169204712},\n", " {'label': 'LABEL_0', 'score': 0.9846975803375244},\n", " {'label': 'LABEL_1', 'score': 0.5085996389389038},\n", " {'label': 'LABEL_1', 'score': 0.9591556191444397},\n", " {'label': 'LABEL_1', 'score': 0.6862562298774719},\n", " {'label': 'LABEL_0', 'score': 0.9033928513526917},\n", " {'label': 'LABEL_0', 'score': 0.9760305881500244},\n", " {'label': 'LABEL_1', 'score': 0.8040069937705994},\n", " {'label': 'LABEL_1', 'score': 0.9912001490592957},\n", " {'label': 'LABEL_0', 'score': 0.9618396759033203},\n", " {'label': 'LABEL_1', 'score': 0.6953979134559631},\n", " {'label': 'LABEL_1', 'score': 0.9941722750663757},\n", " {'label': 'LABEL_1', 'score': 0.9810013175010681},\n", " {'label': 'LABEL_0', 'score': 0.9971204996109009},\n", " {'label': 'LABEL_0', 'score': 0.770469605922699},\n", " {'label': 'LABEL_0', 'score': 0.9928150177001953},\n", " {'label': 'LABEL_1', 'score': 0.9852375388145447},\n", " {'label': 'LABEL_0', 'score': 0.9729052782058716},\n", " {'label': 'LABEL_1', 'score': 0.9801779389381409},\n", " {'label': 'LABEL_1', 'score': 0.9669545292854309},\n", " {'label': 'LABEL_0', 'score': 0.9891452193260193},\n", " {'label': 'LABEL_1', 'score': 0.9948595762252808},\n", " {'label': 'LABEL_1', 'score': 0.9631994366645813},\n", " {'label': 'LABEL_1', 'score': 0.9868159890174866},\n", " {'label': 'LABEL_0', 'score': 0.9939367175102234},\n", " {'label': 'LABEL_1', 'score': 0.9374150037765503},\n", " {'label': 'LABEL_0', 'score': 0.9290514588356018},\n", " {'label': 'LABEL_1', 'score': 0.9712733030319214},\n", " {'label': 'LABEL_1', 'score': 0.6352583169937134},\n", " {'label': 'LABEL_1', 'score': 0.9746260046958923},\n", " {'label': 'LABEL_1', 'score': 0.9918363690376282},\n", " {'label': 'LABEL_0', 'score': 0.6233887076377869},\n", " {'label': 'LABEL_0', 'score': 
0.9778889417648315},\n", " {'label': 'LABEL_1', 'score': 0.6556466221809387},\n", " {'label': 'LABEL_1', 'score': 0.6855087280273438},\n", " {'label': 'LABEL_0', 'score': 0.9843446612358093},\n", " {'label': 'LABEL_0', 'score': 0.6279629468917847},\n", " {'label': 'LABEL_0', 'score': 0.9902299046516418},\n", " {'label': 'LABEL_1', 'score': 0.9677106142044067},\n", " {'label': 'LABEL_1', 'score': 0.95401930809021},\n", " {'label': 'LABEL_1', 'score': 0.8383669853210449},\n", " {'label': 'LABEL_0', 'score': 0.9565296173095703},\n", " {'label': 'LABEL_0', 'score': 0.9821329712867737},\n", " {'label': 'LABEL_1', 'score': 0.8867975473403931},\n", " {'label': 'LABEL_1', 'score': 0.9596521258354187},\n", " {'label': 'LABEL_1', 'score': 0.9953678846359253},\n", " {'label': 'LABEL_1', 'score': 0.6180222630500793},\n", " {'label': 'LABEL_1', 'score': 0.839966893196106},\n", " {'label': 'LABEL_1', 'score': 0.9966834187507629},\n", " {'label': 'LABEL_0', 'score': 0.988347589969635},\n", " {'label': 'LABEL_1', 'score': 0.9169568419456482},\n", " {'label': 'LABEL_0', 'score': 0.9921265840530396},\n", " {'label': 'LABEL_0', 'score': 0.8399364352226257},\n", " {'label': 'LABEL_0', 'score': 0.9207215905189514},\n", " {'label': 'LABEL_0', 'score': 0.9106295108795166},\n", " {'label': 'LABEL_0', 'score': 0.9921837449073792},\n", " {'label': 'LABEL_1', 'score': 0.9959458708763123},\n", " {'label': 'LABEL_0', 'score': 0.9894911050796509},\n", " {'label': 'LABEL_1', 'score': 0.9743772745132446},\n", " {'label': 'LABEL_0', 'score': 0.9118412733078003},\n", " {'label': 'LABEL_0', 'score': 0.9683268666267395},\n", " {'label': 'LABEL_0', 'score': 0.9877650141716003},\n", " {'label': 'LABEL_1', 'score': 0.8257834315299988},\n", " {'label': 'LABEL_1', 'score': 0.9933189153671265},\n", " {'label': 'LABEL_1', 'score': 0.9968757629394531},\n", " {'label': 'LABEL_1', 'score': 0.9589669108390808},\n", " {'label': 'LABEL_0', 'score': 0.9949743747711182},\n", " {'label': 'LABEL_0', 'score': 
0.9604752063751221},\n", " {'label': 'LABEL_0', 'score': 0.6470544934272766},\n", " {'label': 'LABEL_0', 'score': 0.6516719460487366},\n", " {'label': 'LABEL_0', 'score': 0.8422739505767822},\n", " {'label': 'LABEL_0', 'score': 0.9935320615768433},\n", " {'label': 'LABEL_1', 'score': 0.9534354209899902},\n", " {'label': 'LABEL_1', 'score': 0.9863380789756775},\n", " {'label': 'LABEL_1', 'score': 0.9699581265449524},\n", " {'label': 'LABEL_1', 'score': 0.5838924050331116},\n", " {'label': 'LABEL_0', 'score': 0.9731244444847107},\n", " {'label': 'LABEL_0', 'score': 0.8859334588050842},\n", " {'label': 'LABEL_1', 'score': 0.9310123324394226},\n", " {'label': 'LABEL_0', 'score': 0.9858405590057373},\n", " {'label': 'LABEL_1', 'score': 0.8960093259811401},\n", " {'label': 'LABEL_0', 'score': 0.6809247136116028},\n", " {'label': 'LABEL_1', 'score': 0.9651084542274475},\n", " {'label': 'LABEL_1', 'score': 0.9708428978919983},\n", " {'label': 'LABEL_1', 'score': 0.8113129138946533},\n", " {'label': 'LABEL_1', 'score': 0.9679713249206543},\n", " {'label': 'LABEL_1', 'score': 0.9831263422966003},\n", " {'label': 'LABEL_1', 'score': 0.9862723350524902},\n", " {'label': 'LABEL_1', 'score': 0.9903726577758789},\n", " {'label': 'LABEL_0', 'score': 0.9874128103256226},\n", " {'label': 'LABEL_0', 'score': 0.9676929712295532},\n", " {'label': 'LABEL_0', 'score': 0.9208645224571228},\n", " {'label': 'LABEL_1', 'score': 0.7963953018188477},\n", " {'label': 'LABEL_1', 'score': 0.9790200591087341},\n", " {'label': 'LABEL_0', 'score': 0.9700548052787781},\n", " {'label': 'LABEL_1', 'score': 0.9285778403282166},\n", " {'label': 'LABEL_0', 'score': 0.9298763275146484},\n", " {'label': 'LABEL_0', 'score': 0.9788060784339905},\n", " {'label': 'LABEL_0', 'score': 0.5070981979370117},\n", " {'label': 'LABEL_0', 'score': 0.6101786494255066},\n", " {'label': 'LABEL_1', 'score': 0.9401609897613525},\n", " {'label': 'LABEL_0', 'score': 0.5705845952033997},\n", " {'label': 'LABEL_0', 'score': 
0.715858519077301},\n", " {'label': 'LABEL_1', 'score': 0.9841576218605042},\n", " {'label': 'LABEL_0', 'score': 0.9133334755897522},\n", " {'label': 'LABEL_1', 'score': 0.8897599577903748},\n", " {'label': 'LABEL_0', 'score': 0.9941587448120117},\n", " {'label': 'LABEL_1', 'score': 0.9589704275131226},\n", " {'label': 'LABEL_1', 'score': 0.9959990978240967},\n", " {'label': 'LABEL_0', 'score': 0.9623426198959351},\n", " {'label': 'LABEL_1', 'score': 0.9778594374656677},\n", " {'label': 'LABEL_1', 'score': 0.9872311949729919},\n", " {'label': 'LABEL_1', 'score': 0.980732262134552},\n", " {'label': 'LABEL_1', 'score': 0.9541248083114624},\n", " {'label': 'LABEL_1', 'score': 0.9724259376525879},\n", " {'label': 'LABEL_1', 'score': 0.9902752637863159},\n", " {'label': 'LABEL_1', 'score': 0.9823142290115356},\n", " {'label': 'LABEL_1', 'score': 0.9673748016357422},\n", " {'label': 'LABEL_1', 'score': 0.9039739370346069},\n", " {'label': 'LABEL_0', 'score': 0.9813490509986877},\n", " {'label': 'LABEL_0', 'score': 0.9910780191421509},\n", " {'label': 'LABEL_1', 'score': 0.7190228700637817},\n", " {'label': 'LABEL_0', 'score': 0.955741822719574},\n", " {'label': 'LABEL_1', 'score': 0.9780182242393494},\n", " {'label': 'LABEL_0', 'score': 0.9955571293830872},\n", " {'label': 'LABEL_0', 'score': 0.9368971586227417},\n", " {'label': 'LABEL_0', 'score': 0.9864141345024109},\n", " {'label': 'LABEL_0', 'score': 0.9919256567955017},\n", " {'label': 'LABEL_1', 'score': 0.9556246399879456},\n", " {'label': 'LABEL_0', 'score': 0.9608051180839539},\n", " {'label': 'LABEL_0', 'score': 0.990741491317749},\n", " {'label': 'LABEL_0', 'score': 0.9546958804130554},\n", " {'label': 'LABEL_0', 'score': 0.9897760152816772},\n", " {'label': 'LABEL_0', 'score': 0.9305821657180786},\n", " {'label': 'LABEL_1', 'score': 0.9401752352714539},\n", " {'label': 'LABEL_0', 'score': 0.8841428756713867},\n", " {'label': 'LABEL_0', 'score': 0.9952380657196045},\n", " {'label': 'LABEL_0', 'score': 
0.9911361932754517},\n", " {'label': 'LABEL_0', 'score': 0.9772043824195862},\n", " {'label': 'LABEL_0', 'score': 0.765044093132019},\n", " {'label': 'LABEL_0', 'score': 0.8548526763916016},\n", " {'label': 'LABEL_1', 'score': 0.8744843006134033},\n", " {'label': 'LABEL_0', 'score': 0.9837722182273865},\n", " {'label': 'LABEL_0', 'score': 0.9745046496391296},\n", " {'label': 'LABEL_0', 'score': 0.9850363731384277},\n", " {'label': 'LABEL_0', 'score': 0.9176458120346069},\n", " {'label': 'LABEL_0', 'score': 0.9215735197067261},\n", " {'label': 'LABEL_1', 'score': 0.6565162539482117},\n", " {'label': 'LABEL_0', 'score': 0.98956698179245},\n", " {'label': 'LABEL_0', 'score': 0.8922049403190613},\n", " {'label': 'LABEL_0', 'score': 0.9962350726127625},\n", " {'label': 'LABEL_0', 'score': 0.9958523511886597},\n", " {'label': 'LABEL_0', 'score': 0.976225733757019},\n", " {'label': 'LABEL_0', 'score': 0.9917077422142029},\n", " {'label': 'LABEL_0', 'score': 0.9804897904396057},\n", " {'label': 'LABEL_1', 'score': 0.8051390051841736},\n", " {'label': 'LABEL_1', 'score': 0.9766478538513184},\n", " {'label': 'LABEL_0', 'score': 0.9742063283920288},\n", " {'label': 'LABEL_0', 'score': 0.9911614656448364},\n", " {'label': 'LABEL_1', 'score': 0.6183955073356628},\n", " {'label': 'LABEL_0', 'score': 0.9954544305801392},\n", " {'label': 'LABEL_0', 'score': 0.9946867227554321},\n", " {'label': 'LABEL_0', 'score': 0.6935546398162842},\n", " {'label': 'LABEL_1', 'score': 0.8637236952781677},\n", " {'label': 'LABEL_1', 'score': 0.9793212413787842},\n", " {'label': 'LABEL_1', 'score': 0.9931596517562866},\n", " {'label': 'LABEL_0', 'score': 0.9968294501304626},\n", " {'label': 'LABEL_1', 'score': 0.9714514017105103},\n", " {'label': 'LABEL_0', 'score': 0.9883481860160828},\n", " {'label': 'LABEL_0', 'score': 0.8043408989906311},\n", " {'label': 'LABEL_0', 'score': 0.9928232431411743},\n", " {'label': 'LABEL_0', 'score': 0.9866151809692383},\n", " {'label': 'LABEL_0', 'score': 
0.9845389723777771},\n", " {'label': 'LABEL_0', 'score': 0.8769729733467102},\n", " {'label': 'LABEL_0', 'score': 0.8691015243530273},\n", " {'label': 'LABEL_0', 'score': 0.8536194562911987},\n", " {'label': 'LABEL_0', 'score': 0.9304572343826294},\n", " {'label': 'LABEL_0', 'score': 0.9983558058738708},\n", " {'label': 'LABEL_0', 'score': 0.9976704716682434},\n", " {'label': 'LABEL_0', 'score': 0.9949118494987488},\n", " {'label': 'LABEL_1', 'score': 0.9864810705184937},\n", " {'label': 'LABEL_1', 'score': 0.6623121500015259},\n", " {'label': 'LABEL_0', 'score': 0.9275172352790833},\n", " {'label': 'LABEL_1', 'score': 0.9489859938621521},\n", " {'label': 'LABEL_1', 'score': 0.9946613907814026},\n", " {'label': 'LABEL_0', 'score': 0.9962383508682251},\n", " {'label': 'LABEL_0', 'score': 0.6611701250076294},\n", " {'label': 'LABEL_0', 'score': 0.9595711827278137},\n", " {'label': 'LABEL_0', 'score': 0.7789456844329834},\n", " {'label': 'LABEL_0', 'score': 0.98863285779953},\n", " {'label': 'LABEL_0', 'score': 0.8365599513053894},\n", " {'label': 'LABEL_0', 'score': 0.9642247557640076},\n", " {'label': 'LABEL_0', 'score': 0.9932066202163696},\n", " {'label': 'LABEL_0', 'score': 0.6040222644805908},\n", " {'label': 'LABEL_0', 'score': 0.9800452589988708},\n", " {'label': 'LABEL_1', 'score': 0.8107476234436035},\n", " {'label': 'LABEL_0', 'score': 0.9928464293479919},\n", " {'label': 'LABEL_0', 'score': 0.978429913520813},\n", " {'label': 'LABEL_0', 'score': 0.9875631332397461},\n", " {'label': 'LABEL_1', 'score': 0.9946447610855103},\n", " {'label': 'LABEL_0', 'score': 0.9569626450538635},\n", " {'label': 'LABEL_0', 'score': 0.5825687050819397},\n", " {'label': 'LABEL_1', 'score': 0.9850578904151917},\n", " {'label': 'LABEL_0', 'score': 0.9718114137649536},\n", " {'label': 'LABEL_1', 'score': 0.9784877896308899},\n", " {'label': 'LABEL_0', 'score': 0.9373846054077148},\n", " {'label': 'LABEL_1', 'score': 0.9798907041549683},\n", " {'label': 'LABEL_1', 'score': 
0.8990122675895691},\n", " {'label': 'LABEL_1', 'score': 0.966254472732544},\n", " {'label': 'LABEL_0', 'score': 0.9869864583015442},\n", " {'label': 'LABEL_0', 'score': 0.9455236196517944},\n", " {'label': 'LABEL_1', 'score': 0.9763216972351074},\n", " {'label': 'LABEL_0', 'score': 0.9844340682029724},\n", " {'label': 'LABEL_0', 'score': 0.9957629442214966},\n", " {'label': 'LABEL_0', 'score': 0.9980828762054443},\n", " {'label': 'LABEL_0', 'score': 0.9765481352806091},\n", " {'label': 'LABEL_0', 'score': 0.9793199896812439},\n", " {'label': 'LABEL_0', 'score': 0.9791136384010315},\n", " {'label': 'LABEL_0', 'score': 0.9887147545814514},\n", " {'label': 'LABEL_1', 'score': 0.9902543425559998},\n", " {'label': 'LABEL_0', 'score': 0.9907160997390747},\n", " {'label': 'LABEL_0', 'score': 0.9892561435699463},\n", " {'label': 'LABEL_0', 'score': 0.99406498670578},\n", " {'label': 'LABEL_0', 'score': 0.9581699371337891},\n", " {'label': 'LABEL_1', 'score': 0.5796918869018555},\n", " {'label': 'LABEL_0', 'score': 0.9482484459877014},\n", " {'label': 'LABEL_1', 'score': 0.9893319606781006},\n", " {'label': 'LABEL_0', 'score': 0.8491203784942627},\n", " {'label': 'LABEL_0', 'score': 0.9950813055038452},\n", " {'label': 'LABEL_0', 'score': 0.9937944412231445},\n", " {'label': 'LABEL_0', 'score': 0.9934796094894409},\n", " {'label': 'LABEL_0', 'score': 0.9109570384025574},\n", " {'label': 'LABEL_1', 'score': 0.9654124975204468},\n", " {'label': 'LABEL_0', 'score': 0.9397651553153992},\n", " {'label': 'LABEL_1', 'score': 0.8101767897605896},\n", " {'label': 'LABEL_0', 'score': 0.9851425290107727},\n", " {'label': 'LABEL_0', 'score': 0.8464933633804321},\n", " {'label': 'LABEL_0', 'score': 0.7086097598075867},\n", " {'label': 'LABEL_1', 'score': 0.9942412376403809},\n", " {'label': 'LABEL_0', 'score': 0.9564889669418335},\n", " {'label': 'LABEL_0', 'score': 0.6702427268028259},\n", " {'label': 'LABEL_0', 'score': 0.9953863024711609},\n", " {'label': 'LABEL_0', 'score': 
0.9961366057395935},\n", " {'label': 'LABEL_0', 'score': 0.9579523801803589},\n", " {'label': 'LABEL_0', 'score': 0.9929841756820679},\n", " {'label': 'LABEL_1', 'score': 0.8427727818489075},\n", " {'label': 'LABEL_0', 'score': 0.988649845123291},\n", " {'label': 'LABEL_0', 'score': 0.9902087450027466},\n", " {'label': 'LABEL_0', 'score': 0.992279052734375},\n", " {'label': 'LABEL_0', 'score': 0.9860447645187378},\n", " {'label': 'LABEL_0', 'score': 0.7342256307601929},\n", " {'label': 'LABEL_1', 'score': 0.627812922000885},\n", " {'label': 'LABEL_1', 'score': 0.9554335474967957},\n", " {'label': 'LABEL_0', 'score': 0.9830189347267151},\n", " {'label': 'LABEL_0', 'score': 0.861741304397583},\n", " {'label': 'LABEL_1', 'score': 0.9945268034934998},\n", " {'label': 'LABEL_0', 'score': 0.9706934094429016},\n", " {'label': 'LABEL_0', 'score': 0.9862181544303894},\n", " {'label': 'LABEL_1', 'score': 0.6970356702804565},\n", " {'label': 'LABEL_1', 'score': 0.9795743823051453},\n", " {'label': 'LABEL_1', 'score': 0.8964106440544128},\n", " {'label': 'LABEL_1', 'score': 0.990230143070221},\n", " {'label': 'LABEL_0', 'score': 0.9843607544898987},\n", " {'label': 'LABEL_1', 'score': 0.9737773537635803},\n", " {'label': 'LABEL_1', 'score': 0.9561145901679993},\n", " {'label': 'LABEL_1', 'score': 0.7726802229881287},\n", " {'label': 'LABEL_0', 'score': 0.9867532849311829},\n", " {'label': 'LABEL_1', 'score': 0.9936423301696777},\n", " {'label': 'LABEL_1', 'score': 0.8904270529747009},\n", " {'label': 'LABEL_0', 'score': 0.8102100491523743},\n", " {'label': 'LABEL_1', 'score': 0.7072275876998901},\n", " {'label': 'LABEL_1', 'score': 0.9506065845489502},\n", " {'label': 'LABEL_1', 'score': 0.6668100357055664},\n", " {'label': 'LABEL_0', 'score': 0.9742982983589172},\n", " {'label': 'LABEL_1', 'score': 0.8297302722930908},\n", " {'label': 'LABEL_1', 'score': 0.976436972618103},\n", " {'label': 'LABEL_0', 'score': 0.965576171875},\n", " {'label': 'LABEL_1', 'score': 
0.6782581806182861},\n", " {'label': 'LABEL_0', 'score': 0.9925404191017151},\n", " {'label': 'LABEL_1', 'score': 0.552829921245575},\n", " {'label': 'LABEL_1', 'score': 0.9796808958053589},\n", " {'label': 'LABEL_0', 'score': 0.8962493538856506},\n", " {'label': 'LABEL_0', 'score': 0.9819111227989197},\n", " {'label': 'LABEL_1', 'score': 0.9811175465583801},\n", " {'label': 'LABEL_0', 'score': 0.9360445737838745},\n", " {'label': 'LABEL_1', 'score': 0.9926239252090454},\n", " {'label': 'LABEL_1', 'score': 0.9821844100952148},\n", " {'label': 'LABEL_0', 'score': 0.9439947009086609},\n", " {'label': 'LABEL_1', 'score': 0.9955853223800659},\n", " {'label': 'LABEL_1', 'score': 0.9958295226097107},\n", " {'label': 'LABEL_1', 'score': 0.6333702802658081},\n", " {'label': 'LABEL_0', 'score': 0.9083454012870789},\n", " {'label': 'LABEL_0', 'score': 0.8881974220275879},\n", " {'label': 'LABEL_1', 'score': 0.6101353168487549},\n", " {'label': 'LABEL_0', 'score': 0.957714855670929},\n", " {'label': 'LABEL_0', 'score': 0.9776718020439148},\n", " {'label': 'LABEL_0', 'score': 0.5999106168746948},\n", " {'label': 'LABEL_0', 'score': 0.989844560623169},\n", " {'label': 'LABEL_1', 'score': 0.9848566651344299},\n", " {'label': 'LABEL_1', 'score': 0.7018373012542725},\n", " {'label': 'LABEL_1', 'score': 0.9768227338790894},\n", " {'label': 'LABEL_1', 'score': 0.9928419589996338},\n", " {'label': 'LABEL_1', 'score': 0.9943158030509949},\n", " {'label': 'LABEL_0', 'score': 0.8971171379089355},\n", " {'label': 'LABEL_0', 'score': 0.9800693988800049},\n", " {'label': 'LABEL_1', 'score': 0.9963098168373108},\n", " {'label': 'LABEL_1', 'score': 0.9947733283042908},\n", " {'label': 'LABEL_0', 'score': 0.9846504926681519},\n", " {'label': 'LABEL_0', 'score': 0.9964327812194824},\n", " {'label': 'LABEL_0', 'score': 0.9576809406280518},\n", " {'label': 'LABEL_1', 'score': 0.9294152855873108},\n", " {'label': 'LABEL_0', 'score': 0.7247462272644043},\n", " {'label': 'LABEL_0', 'score': 
0.9721958041191101},\n", " {'label': 'LABEL_1', 'score': 0.9864373803138733},\n", " {'label': 'LABEL_1', 'score': 0.9966200590133667},\n", " {'label': 'LABEL_1', 'score': 0.9962632060050964},\n", " {'label': 'LABEL_1', 'score': 0.5823416113853455},\n", " {'label': 'LABEL_1', 'score': 0.9945600628852844},\n", " {'label': 'LABEL_1', 'score': 0.9933189153671265},\n", " {'label': 'LABEL_1', 'score': 0.9757441282272339},\n", " {'label': 'LABEL_1', 'score': 0.9938511848449707},\n", " {'label': 'LABEL_1', 'score': 0.9950587749481201},\n", " {'label': 'LABEL_1', 'score': 0.9951475262641907},\n", " {'label': 'LABEL_1', 'score': 0.9965439438819885},\n", " {'label': 'LABEL_1', 'score': 0.756714940071106},\n", " {'label': 'LABEL_0', 'score': 0.9147056341171265},\n", " {'label': 'LABEL_1', 'score': 0.9834950566291809},\n", " {'label': 'LABEL_1', 'score': 0.9941120743751526},\n", " {'label': 'LABEL_1', 'score': 0.9966141581535339},\n", " {'label': 'LABEL_0', 'score': 0.9373542070388794},\n", " {'label': 'LABEL_1', 'score': 0.9953212141990662},\n", " {'label': 'LABEL_1', 'score': 0.9967867136001587},\n", " {'label': 'LABEL_1', 'score': 0.9939202070236206},\n", " {'label': 'LABEL_1', 'score': 0.7412394285202026},\n", " {'label': 'LABEL_0', 'score': 0.9623532295227051},\n", " {'label': 'LABEL_1', 'score': 0.9957146048545837},\n", " {'label': 'LABEL_1', 'score': 0.9939712882041931},\n", " {'label': 'LABEL_1', 'score': 0.9923322200775146},\n", " {'label': 'LABEL_1', 'score': 0.9660599827766418},\n", " {'label': 'LABEL_1', 'score': 0.9844948649406433},\n", " {'label': 'LABEL_1', 'score': 0.7453780174255371},\n", " {'label': 'LABEL_1', 'score': 0.9607815742492676},\n", " {'label': 'LABEL_1', 'score': 0.9240431785583496},\n", " {'label': 'LABEL_1', 'score': 0.9936596751213074},\n", " {'label': 'LABEL_1', 'score': 0.5215803384780884},\n", " {'label': 'LABEL_1', 'score': 0.9923531413078308},\n", " {'label': 'LABEL_1', 'score': 0.996246874332428},\n", " {'label': 'LABEL_1', 'score': 
0.996537446975708},\n", " {'label': 'LABEL_1', 'score': 0.9740337133407593},\n", " {'label': 'LABEL_1', 'score': 0.9953761100769043},\n", " {'label': 'LABEL_1', 'score': 0.9968298077583313},\n", " {'label': 'LABEL_1', 'score': 0.9939036965370178},\n", " {'label': 'LABEL_1', 'score': 0.9916537404060364},\n", " {'label': 'LABEL_1', 'score': 0.9962269067764282},\n", " {'label': 'LABEL_1', 'score': 0.9958069324493408},\n", " {'label': 'LABEL_1', 'score': 0.9956639409065247},\n", " {'label': 'LABEL_0', 'score': 0.9886506795883179},\n", " {'label': 'LABEL_1', 'score': 0.9480882287025452},\n", " {'label': 'LABEL_1', 'score': 0.9821166396141052},\n", " {'label': 'LABEL_1', 'score': 0.9922047257423401},\n", " {'label': 'LABEL_1', 'score': 0.9872769713401794},\n", " {'label': 'LABEL_1', 'score': 0.991014301776886},\n", " {'label': 'LABEL_1', 'score': 0.9588004946708679},\n", " {'label': 'LABEL_0', 'score': 0.782096266746521},\n", " {'label': 'LABEL_1', 'score': 0.9705247282981873},\n", " {'label': 'LABEL_0', 'score': 0.9984493255615234},\n", " {'label': 'LABEL_1', 'score': 0.987313449382782},\n", " {'label': 'LABEL_1', 'score': 0.9935043454170227},\n", " {'label': 'LABEL_1', 'score': 0.9941995143890381},\n", " {'label': 'LABEL_1', 'score': 0.9954388737678528},\n", " {'label': 'LABEL_1', 'score': 0.9934248924255371},\n", " {'label': 'LABEL_1', 'score': 0.5153225660324097},\n", " {'label': 'LABEL_1', 'score': 0.9926658272743225},\n", " {'label': 'LABEL_1', 'score': 0.9900650382041931},\n", " {'label': 'LABEL_1', 'score': 0.990449845790863},\n", " {'label': 'LABEL_1', 'score': 0.9940094947814941},\n", " {'label': 'LABEL_1', 'score': 0.9649341702461243},\n", " {'label': 'LABEL_0', 'score': 0.9854132533073425},\n", " {'label': 'LABEL_0', 'score': 0.6902562975883484},\n", " {'label': 'LABEL_0', 'score': 0.9953790903091431},\n", " {'label': 'LABEL_1', 'score': 0.9951647520065308},\n", " {'label': 'LABEL_1', 'score': 0.9959852695465088},\n", " {'label': 'LABEL_1', 'score': 
0.979231595993042},\n", " {'label': 'LABEL_0', 'score': 0.5359320640563965},\n", " {'label': 'LABEL_0', 'score': 0.9935789108276367},\n", " {'label': 'LABEL_1', 'score': 0.9128223061561584},\n", " {'label': 'LABEL_1', 'score': 0.9885913729667664},\n", " {'label': 'LABEL_0', 'score': 0.9898375272750854},\n", " {'label': 'LABEL_0', 'score': 0.9732145071029663},\n", " {'label': 'LABEL_1', 'score': 0.9945152401924133},\n", " {'label': 'LABEL_0', 'score': 0.9821017980575562},\n", " {'label': 'LABEL_0', 'score': 0.5703974962234497},\n", " {'label': 'LABEL_0', 'score': 0.9287216663360596},\n", " {'label': 'LABEL_1', 'score': 0.9972570538520813},\n", " {'label': 'LABEL_0', 'score': 0.8542720675468445},\n", " {'label': 'LABEL_1', 'score': 0.995238184928894},\n", " {'label': 'LABEL_1', 'score': 0.992724597454071},\n", " {'label': 'LABEL_1', 'score': 0.9938576817512512},\n", " {'label': 'LABEL_1', 'score': 0.9965469241142273},\n", " {'label': 'LABEL_0', 'score': 0.9916887879371643},\n", " {'label': 'LABEL_1', 'score': 0.9956098198890686},\n", " {'label': 'LABEL_1', 'score': 0.9948581457138062},\n", " {'label': 'LABEL_1', 'score': 0.9800918102264404},\n", " {'label': 'LABEL_0', 'score': 0.9850999712944031},\n", " {'label': 'LABEL_1', 'score': 0.9944234490394592},\n", " {'label': 'LABEL_1', 'score': 0.9922235608100891},\n", " {'label': 'LABEL_1', 'score': 0.9933009147644043},\n", " {'label': 'LABEL_1', 'score': 0.9839582443237305},\n", " {'label': 'LABEL_0', 'score': 0.8630751371383667},\n", " {'label': 'LABEL_0', 'score': 0.9115880131721497},\n", " {'label': 'LABEL_0', 'score': 0.9788351058959961},\n", " {'label': 'LABEL_1', 'score': 0.9253813624382019},\n", " {'label': 'LABEL_0', 'score': 0.9877164959907532},\n", " {'label': 'LABEL_1', 'score': 0.9453893303871155},\n", " {'label': 'LABEL_1', 'score': 0.946543276309967},\n", " {'label': 'LABEL_0', 'score': 0.8820732235908508},\n", " {'label': 'LABEL_1', 'score': 0.9947957396507263},\n", " {'label': 'LABEL_1', 'score': 
0.9912304878234863},\n", " {'label': 'LABEL_1', 'score': 0.9177975654602051},\n", " {'label': 'LABEL_0', 'score': 0.9687201380729675},\n", " {'label': 'LABEL_1', 'score': 0.9947293400764465},\n", " {'label': 'LABEL_1', 'score': 0.970740556716919},\n", " {'label': 'LABEL_0', 'score': 0.8022850155830383},\n", " {'label': 'LABEL_0', 'score': 0.9579288959503174},\n", " {'label': 'LABEL_0', 'score': 0.9937421679496765},\n", " {'label': 'LABEL_1', 'score': 0.994292140007019},\n", " {'label': 'LABEL_1', 'score': 0.9968717694282532},\n", " {'label': 'LABEL_0', 'score': 0.9592112898826599},\n", " {'label': 'LABEL_0', 'score': 0.6592299342155457},\n", " {'label': 'LABEL_1', 'score': 0.9956451654434204},\n", " {'label': 'LABEL_1', 'score': 0.9911348819732666},\n", " {'label': 'LABEL_1', 'score': 0.9951386451721191},\n", " {'label': 'LABEL_0', 'score': 0.5056076049804688},\n", " {'label': 'LABEL_1', 'score': 0.8203694224357605},\n", " {'label': 'LABEL_1', 'score': 0.9595959186553955},\n", " {'label': 'LABEL_1', 'score': 0.9836644530296326},\n", " {'label': 'LABEL_1', 'score': 0.9942445158958435},\n", " {'label': 'LABEL_0', 'score': 0.991622269153595},\n", " {'label': 'LABEL_1', 'score': 0.9954179525375366},\n", " {'label': 'LABEL_0', 'score': 0.52507483959198},\n", " {'label': 'LABEL_0', 'score': 0.9515404105186462},\n", " {'label': 'LABEL_0', 'score': 0.9718201756477356},\n", " {'label': 'LABEL_0', 'score': 0.6818857192993164},\n", " {'label': 'LABEL_0', 'score': 0.9904976487159729},\n", " {'label': 'LABEL_1', 'score': 0.9965739250183105},\n", " {'label': 'LABEL_1', 'score': 0.6346263289451599},\n", " {'label': 'LABEL_1', 'score': 0.9970904588699341},\n", " {'label': 'LABEL_1', 'score': 0.9544928073883057},\n", " {'label': 'LABEL_1', 'score': 0.9378147721290588},\n", " {'label': 'LABEL_1', 'score': 0.9889683723449707},\n", " {'label': 'LABEL_1', 'score': 0.9771023988723755},\n", " {'label': 'LABEL_1', 'score': 0.8697611689567566},\n", " {'label': 'LABEL_1', 'score': 
0.9960363507270813},\n", " {'label': 'LABEL_1', 'score': 0.9348815679550171},\n", " {'label': 'LABEL_1', 'score': 0.9703534841537476},\n", " {'label': 'LABEL_0', 'score': 0.9966253042221069},\n", " {'label': 'LABEL_1', 'score': 0.8299302458763123},\n", " {'label': 'LABEL_0', 'score': 0.9111083149909973},\n", " {'label': 'LABEL_1', 'score': 0.8398126363754272},\n", " {'label': 'LABEL_1', 'score': 0.9752505421638489},\n", " {'label': 'LABEL_1', 'score': 0.9928673505783081},\n", " {'label': 'LABEL_1', 'score': 0.9740899205207825},\n", " {'label': 'LABEL_1', 'score': 0.9677812457084656},\n", " {'label': 'LABEL_1', 'score': 0.96604984998703},\n", " {'label': 'LABEL_1', 'score': 0.9920864701271057},\n", " {'label': 'LABEL_1', 'score': 0.9934585690498352},\n", " {'label': 'LABEL_1', 'score': 0.991157054901123},\n", " {'label': 'LABEL_0', 'score': 0.9470826983451843},\n", " {'label': 'LABEL_1', 'score': 0.9873031973838806},\n", " {'label': 'LABEL_1', 'score': 0.9779534339904785},\n", " {'label': 'LABEL_1', 'score': 0.9912458062171936},\n", " {'label': 'LABEL_1', 'score': 0.9942110776901245},\n", " {'label': 'LABEL_1', 'score': 0.993436336517334},\n", " {'label': 'LABEL_1', 'score': 0.9949154853820801},\n", " {'label': 'LABEL_1', 'score': 0.9935457110404968},\n", " {'label': 'LABEL_1', 'score': 0.9942652583122253},\n", " {'label': 'LABEL_1', 'score': 0.9925038814544678},\n", " {'label': 'LABEL_1', 'score': 0.9942684173583984},\n", " {'label': 'LABEL_1', 'score': 0.8534121513366699},\n", " {'label': 'LABEL_1', 'score': 0.9949132204055786},\n", " {'label': 'LABEL_0', 'score': 0.9427997469902039},\n", " {'label': 'LABEL_1', 'score': 0.9734408259391785},\n", " {'label': 'LABEL_1', 'score': 0.9921699166297913},\n", " {'label': 'LABEL_1', 'score': 0.9898235201835632},\n", " {'label': 'LABEL_1', 'score': 0.7765575647354126},\n", " {'label': 'LABEL_1', 'score': 0.9793251156806946},\n", " {'label': 'LABEL_1', 'score': 0.8394061326980591},\n", " {'label': 'LABEL_1', 'score': 
0.9897447228431702},\n", " {'label': 'LABEL_0', 'score': 0.9154225587844849},\n", " {'label': 'LABEL_0', 'score': 0.9888221025466919},\n", " {'label': 'LABEL_1', 'score': 0.9890047907829285},\n", " {'label': 'LABEL_1', 'score': 0.9717256426811218},\n", " {'label': 'LABEL_1', 'score': 0.9893767833709717},\n", " {'label': 'LABEL_1', 'score': 0.9924948215484619},\n", " {'label': 'LABEL_1', 'score': 0.9964051246643066},\n", " {'label': 'LABEL_1', 'score': 0.9560785293579102},\n", " {'label': 'LABEL_0', 'score': 0.7070202231407166},\n", " {'label': 'LABEL_1', 'score': 0.9044116735458374},\n", " {'label': 'LABEL_1', 'score': 0.9941132664680481},\n", " {'label': 'LABEL_1', 'score': 0.9947010278701782},\n", " {'label': 'LABEL_1', 'score': 0.9965517520904541},\n", " {'label': 'LABEL_1', 'score': 0.9773184657096863},\n", " {'label': 'LABEL_1', 'score': 0.9909886717796326},\n", " {'label': 'LABEL_1', 'score': 0.9947931170463562},\n", " {'label': 'LABEL_1', 'score': 0.9928774237632751},\n", " {'label': 'LABEL_1', 'score': 0.9927247762680054},\n", " {'label': 'LABEL_0', 'score': 0.9351725578308105},\n", " {'label': 'LABEL_1', 'score': 0.9962174296379089},\n", " {'label': 'LABEL_1', 'score': 0.9845442175865173},\n", " {'label': 'LABEL_0', 'score': 0.9229084849357605},\n", " {'label': 'LABEL_1', 'score': 0.9958617687225342},\n", " {'label': 'LABEL_0', 'score': 0.988365888595581},\n", " {'label': 'LABEL_0', 'score': 0.9961277842521667},\n", " {'label': 'LABEL_0', 'score': 0.995089590549469},\n", " {'label': 'LABEL_1', 'score': 0.9151782393455505},\n", " {'label': 'LABEL_1', 'score': 0.7796429991722107},\n", " {'label': 'LABEL_0', 'score': 0.6353229880332947},\n", " {'label': 'LABEL_1', 'score': 0.5075734853744507},\n", " {'label': 'LABEL_1', 'score': 0.9626388549804688},\n", " {'label': 'LABEL_1', 'score': 0.991780698299408},\n", " {'label': 'LABEL_1', 'score': 0.9936391711235046},\n", " {'label': 'LABEL_1', 'score': 0.9885025024414062},\n", " {'label': 'LABEL_1', 'score': 
0.9949436783790588},\n", " {'label': 'LABEL_1', 'score': 0.9955767393112183},\n", " {'label': 'LABEL_1', 'score': 0.9393144845962524},\n", " {'label': 'LABEL_0', 'score': 0.8385283946990967},\n", " {'label': 'LABEL_0', 'score': 0.6774574518203735},\n", " {'label': 'LABEL_1', 'score': 0.9914259910583496},\n", " {'label': 'LABEL_1', 'score': 0.9934968948364258},\n", " {'label': 'LABEL_1', 'score': 0.8619126677513123},\n", " {'label': 'LABEL_1', 'score': 0.9947986602783203},\n", " {'label': 'LABEL_1', 'score': 0.9928877949714661},\n", " {'label': 'LABEL_1', 'score': 0.9960314631462097},\n", " {'label': 'LABEL_0', 'score': 0.8468936681747437},\n", " {'label': 'LABEL_1', 'score': 0.9907916784286499},\n", " {'label': 'LABEL_1', 'score': 0.692417562007904},\n", " {'label': 'LABEL_1', 'score': 0.9829657673835754},\n", " {'label': 'LABEL_0', 'score': 0.9849721789360046},\n", " {'label': 'LABEL_0', 'score': 0.9317151308059692},\n", " {'label': 'LABEL_1', 'score': 0.6366862654685974},\n", " {'label': 'LABEL_1', 'score': 0.9936919212341309},\n", " {'label': 'LABEL_1', 'score': 0.9850026965141296},\n", " {'label': 'LABEL_0', 'score': 0.933652400970459},\n", " {'label': 'LABEL_1', 'score': 0.9690437912940979},\n", " {'label': 'LABEL_0', 'score': 0.936476469039917},\n", " {'label': 'LABEL_1', 'score': 0.9966297745704651},\n", " {'label': 'LABEL_1', 'score': 0.9681738615036011},\n", " {'label': 'LABEL_0', 'score': 0.9943897128105164},\n", " {'label': 'LABEL_0', 'score': 0.8239411115646362},\n", " {'label': 'LABEL_0', 'score': 0.998040497303009},\n", " {'label': 'LABEL_1', 'score': 0.9954612851142883},\n", " {'label': 'LABEL_1', 'score': 0.9367863535881042},\n", " {'label': 'LABEL_1', 'score': 0.7423374056816101},\n", " {'label': 'LABEL_0', 'score': 0.5968354940414429},\n", " {'label': 'LABEL_0', 'score': 0.9948377013206482},\n", " {'label': 'LABEL_0', 'score': 0.8263901472091675},\n", " {'label': 'LABEL_0', 'score': 0.988389253616333},\n", " {'label': 'LABEL_1', 'score': 
0.995772659778595},\n", " {'label': 'LABEL_0', 'score': 0.8763516545295715},\n", " {'label': 'LABEL_0', 'score': 0.9964577555656433},\n", " {'label': 'LABEL_1', 'score': 0.9791182279586792},\n", " {'label': 'LABEL_0', 'score': 0.9895098209381104},\n", " {'label': 'LABEL_0', 'score': 0.8862988948822021},\n", " {'label': 'LABEL_0', 'score': 0.9519324898719788},\n", " {'label': 'LABEL_1', 'score': 0.9921212196350098},\n", " {'label': 'LABEL_1', 'score': 0.6741978526115417},\n", " {'label': 'LABEL_0', 'score': 0.5454573631286621},\n", " {'label': 'LABEL_0', 'score': 0.5335167646408081},\n", " {'label': 'LABEL_1', 'score': 0.5119876265525818},\n", " {'label': 'LABEL_1', 'score': 0.632554829120636},\n", " {'label': 'LABEL_0', 'score': 0.9117580652236938},\n", " {'label': 'LABEL_0', 'score': 0.9932965636253357},\n", " {'label': 'LABEL_0', 'score': 0.9813398718833923},\n", " {'label': 'LABEL_0', 'score': 0.9965149164199829},\n", " {'label': 'LABEL_0', 'score': 0.9989803433418274},\n", " {'label': 'LABEL_1', 'score': 0.9901706576347351},\n", " {'label': 'LABEL_1', 'score': 0.9950527548789978},\n", " {'label': 'LABEL_1', 'score': 0.9674326181411743},\n", " {'label': 'LABEL_1', 'score': 0.9943349957466125},\n", " {'label': 'LABEL_0', 'score': 0.9982045888900757},\n", " {'label': 'LABEL_1', 'score': 0.9984146356582642},\n", " {'label': 'LABEL_0', 'score': 0.9807737469673157},\n", " {'label': 'LABEL_1', 'score': 0.998051643371582},\n", " {'label': 'LABEL_1', 'score': 0.9973867535591125},\n", " {'label': 'LABEL_1', 'score': 0.9974567294120789},\n", " {'label': 'LABEL_1', 'score': 0.9983263611793518},\n", " {'label': 'LABEL_1', 'score': 0.9982770681381226},\n", " {'label': 'LABEL_1', 'score': 0.9981163740158081},\n", " {'label': 'LABEL_1', 'score': 0.9136704802513123},\n", " {'label': 'LABEL_1', 'score': 0.8776158094406128},\n", " {'label': 'LABEL_1', 'score': 0.9977473616600037},\n", " {'label': 'LABEL_1', 'score': 0.9981617331504822},\n", " {'label': 'LABEL_1', 'score': 
0.9980959296226501},\n", " {'label': 'LABEL_0', 'score': 0.5451762676239014},\n", " {'label': 'LABEL_1', 'score': 0.5463171601295471},\n", " {'label': 'LABEL_0', 'score': 0.9952754974365234},\n", " {'label': 'LABEL_1', 'score': 0.9979830980300903},\n", " {'label': 'LABEL_1', 'score': 0.9973657727241516},\n", " {'label': 'LABEL_1', 'score': 0.9979636669158936},\n", " {'label': 'LABEL_1', 'score': 0.9980011582374573},\n", " {'label': 'LABEL_0', 'score': 0.609296441078186},\n", " {'label': 'LABEL_1', 'score': 0.9980497360229492},\n", " {'label': 'LABEL_1', 'score': 0.9953547716140747},\n", " {'label': 'LABEL_1', 'score': 0.9984027743339539},\n", " {'label': 'LABEL_1', 'score': 0.9978079199790955},\n", " {'label': 'LABEL_1', 'score': 0.9980959296226501},\n", " {'label': 'LABEL_1', 'score': 0.9979496598243713},\n", " {'label': 'LABEL_1', 'score': 0.9981176853179932},\n", " {'label': 'LABEL_1', 'score': 0.9982970356941223},\n", " {'label': 'LABEL_1', 'score': 0.9970557689666748},\n", " {'label': 'LABEL_1', 'score': 0.9968823194503784},\n", " {'label': 'LABEL_1', 'score': 0.9980416297912598},\n", " {'label': 'LABEL_1', 'score': 0.9300060272216797},\n", " {'label': 'LABEL_1', 'score': 0.9972960352897644},\n", " {'label': 'LABEL_1', 'score': 0.9978334307670593},\n", " {'label': 'LABEL_1', 'score': 0.9977651834487915},\n", " {'label': 'LABEL_1', 'score': 0.9975778460502625},\n", " {'label': 'LABEL_1', 'score': 0.9978219270706177},\n", " {'label': 'LABEL_1', 'score': 0.9968376159667969},\n", " {'label': 'LABEL_1', 'score': 0.9980719089508057},\n", " {'label': 'LABEL_1', 'score': 0.997948944568634},\n", " {'label': 'LABEL_1', 'score': 0.997114896774292},\n", " {'label': 'LABEL_1', 'score': 0.9974072575569153},\n", " {'label': 'LABEL_1', 'score': 0.9976442456245422},\n", " {'label': 'LABEL_1', 'score': 0.9970853924751282},\n", " {'label': 'LABEL_1', 'score': 0.9979702830314636},\n", " {'label': 'LABEL_1', 'score': 0.9972410202026367},\n", " {'label': 'LABEL_1', 'score': 
0.9980522394180298},\n", " {'label': 'LABEL_1', 'score': 0.9978965520858765},\n", " {'label': 'LABEL_1', 'score': 0.9981445074081421},\n", " {'label': 'LABEL_1', 'score': 0.9973828196525574},\n", " {'label': 'LABEL_1', 'score': 0.9975137710571289},\n", " {'label': 'LABEL_1', 'score': 0.9973871111869812},\n", " {'label': 'LABEL_1', 'score': 0.9975990653038025},\n", " {'label': 'LABEL_1', 'score': 0.9978379607200623},\n", " {'label': 'LABEL_1', 'score': 0.9977128505706787},\n", " {'label': 'LABEL_1', 'score': 0.9977173805236816},\n", " {'label': 'LABEL_1', 'score': 0.9973612427711487},\n", " {'label': 'LABEL_1', 'score': 0.9964808821678162},\n", " {'label': 'LABEL_1', 'score': 0.9978213310241699},\n", " {'label': 'LABEL_1', 'score': 0.9965618252754211},\n", " {'label': 'LABEL_1', 'score': 0.9972785115242004},\n", " {'label': 'LABEL_1', 'score': 0.9975953698158264},\n", " {'label': 'LABEL_1', 'score': 0.9975850582122803},\n", " {'label': 'LABEL_1', 'score': 0.9974740147590637},\n", " {'label': 'LABEL_1', 'score': 0.996859073638916},\n", " {'label': 'LABEL_1', 'score': 0.997346043586731},\n", " {'label': 'LABEL_1', 'score': 0.9979947805404663},\n", " {'label': 'LABEL_1', 'score': 0.9974393844604492},\n", " {'label': 'LABEL_1', 'score': 0.9974966645240784},\n", " {'label': 'LABEL_1', 'score': 0.9979947805404663},\n", " {'label': 'LABEL_1', 'score': 0.9983198046684265},\n", " {'label': 'LABEL_1', 'score': 0.9958517551422119},\n", " {'label': 'LABEL_1', 'score': 0.9918461441993713},\n", " {'label': 'LABEL_1', 'score': 0.9947928786277771},\n", " {'label': 'LABEL_1', 'score': 0.9898167848587036},\n", " {'label': 'LABEL_1', 'score': 0.9971463084220886},\n", " {'label': 'LABEL_1', 'score': 0.9970598816871643},\n", " {'label': 'LABEL_1', 'score': 0.997458279132843},\n", " {'label': 'LABEL_1', 'score': 0.9978926777839661},\n", " {'label': 'LABEL_1', 'score': 0.997327446937561},\n", " {'label': 'LABEL_1', 'score': 0.9979981780052185},\n", " {'label': 'LABEL_1', 'score': 
0.9981091022491455},\n", " {'label': 'LABEL_1', 'score': 0.998021125793457},\n", " {'label': 'LABEL_1', 'score': 0.9979619979858398},\n", " {'label': 'LABEL_1', 'score': 0.9978579878807068},\n", " {'label': 'LABEL_1', 'score': 0.9977995753288269},\n", " {'label': 'LABEL_1', 'score': 0.9970647692680359},\n", " {'label': 'LABEL_1', 'score': 0.9957797527313232},\n", " {'label': 'LABEL_1', 'score': 0.9966141581535339},\n", " {'label': 'LABEL_1', 'score': 0.9957832098007202},\n", " {'label': 'LABEL_1', 'score': 0.9975501894950867},\n", " {'label': 'LABEL_1', 'score': 0.9969447255134583},\n", " {'label': 'LABEL_1', 'score': 0.9977447986602783},\n", " {'label': 'LABEL_1', 'score': 0.9958977103233337},\n", " {'label': 'LABEL_1', 'score': 0.9965307116508484},\n", " {'label': 'LABEL_1', 'score': 0.998040497303009},\n", " {'label': 'LABEL_1', 'score': 0.9974597096443176},\n", " {'label': 'LABEL_1', 'score': 0.9975069165229797},\n", " {'label': 'LABEL_1', 'score': 0.9968619346618652},\n", " {'label': 'LABEL_1', 'score': 0.7999141216278076},\n", " {'label': 'LABEL_1', 'score': 0.9966764450073242},\n", " {'label': 'LABEL_1', 'score': 0.9978440999984741},\n", " {'label': 'LABEL_1', 'score': 0.9975869655609131},\n", " {'label': 'LABEL_1', 'score': 0.9980499744415283},\n", " {'label': 'LABEL_1', 'score': 0.9977290034294128},\n", " {'label': 'LABEL_0', 'score': 0.9947840571403503},\n", " {'label': 'LABEL_1', 'score': 0.6376346349716187},\n", " {'label': 'LABEL_1', 'score': 0.9190598726272583},\n", " {'label': 'LABEL_0', 'score': 0.9695489406585693},\n", " {'label': 'LABEL_0', 'score': 0.9246589541435242},\n", " {'label': 'LABEL_1', 'score': 0.5756272673606873},\n", " {'label': 'LABEL_0', 'score': 0.8085498213768005},\n", " {'label': 'LABEL_0', 'score': 0.9974361062049866},\n", " {'label': 'LABEL_1', 'score': 0.7996165752410889},\n", " {'label': 'LABEL_0', 'score': 0.9922124743461609},\n", " {'label': 'LABEL_1', 'score': 0.9789844155311584},\n", " {'label': 'LABEL_0', 'score': 
0.9910444021224976},\n", " {'label': 'LABEL_0', 'score': 0.9960230588912964},\n", " {'label': 'LABEL_0', 'score': 0.9831807017326355},\n", " {'label': 'LABEL_0', 'score': 0.9928321838378906},\n", " {'label': 'LABEL_0', 'score': 0.9855900406837463},\n", " {'label': 'LABEL_0', 'score': 0.9944981932640076},\n", " {'label': 'LABEL_0', 'score': 0.9949136972427368},\n", " {'label': 'LABEL_0', 'score': 0.9957302212715149},\n", " {'label': 'LABEL_0', 'score': 0.9968340992927551},\n", " {'label': 'LABEL_0', 'score': 0.9956672191619873},\n", " {'label': 'LABEL_1', 'score': 0.9852407574653625},\n", " {'label': 'LABEL_0', 'score': 0.9968200922012329},\n", " {'label': 'LABEL_0', 'score': 0.9808366298675537},\n", " {'label': 'LABEL_0', 'score': 0.9898589253425598},\n", " {'label': 'LABEL_0', 'score': 0.9969015121459961},\n", " {'label': 'LABEL_0', 'score': 0.9971805810928345},\n", " {'label': 'LABEL_0', 'score': 0.9982355833053589},\n", " {'label': 'LABEL_0', 'score': 0.9987756609916687},\n", " {'label': 'LABEL_0', 'score': 0.9926889538764954},\n", " {'label': 'LABEL_0', 'score': 0.9980269074440002},\n", " {'label': 'LABEL_0', 'score': 0.995913565158844},\n", " {'label': 'LABEL_0', 'score': 0.9521064162254333},\n", " {'label': 'LABEL_0', 'score': 0.9973504543304443},\n", " {'label': 'LABEL_0', 'score': 0.9951818585395813},\n", " {'label': 'LABEL_0', 'score': 0.9943795800209045},\n", " {'label': 'LABEL_0', 'score': 0.993586003780365},\n", " {'label': 'LABEL_0', 'score': 0.996479332447052},\n", " {'label': 'LABEL_0', 'score': 0.997653067111969},\n", " {'label': 'LABEL_0', 'score': 0.9808756113052368},\n", " {'label': 'LABEL_0', 'score': 0.9970345497131348},\n", " {'label': 'LABEL_0', 'score': 0.9431267380714417},\n", " {'label': 'LABEL_0', 'score': 0.9003729224205017},\n", " {'label': 'LABEL_0', 'score': 0.9930608868598938},\n", " {'label': 'LABEL_1', 'score': 0.7583749294281006},\n", " {'label': 'LABEL_0', 'score': 0.9948474168777466},\n", " {'label': 'LABEL_1', 'score': 
0.5063807964324951},\n", " {'label': 'LABEL_0', 'score': 0.9589183330535889},\n", " {'label': 'LABEL_1', 'score': 0.9781742095947266},\n", " {'label': 'LABEL_0', 'score': 0.9977648258209229},\n", " {'label': 'LABEL_0', 'score': 0.9987187385559082},\n", " {'label': 'LABEL_0', 'score': 0.9974794983863831},\n", " {'label': 'LABEL_0', 'score': 0.9986817240715027},\n", " {'label': 'LABEL_0', 'score': 0.996134877204895},\n", " {'label': 'LABEL_0', 'score': 0.9976504445075989},\n", " {'label': 'LABEL_0', 'score': 0.9987467527389526},\n", " {'label': 'LABEL_0', 'score': 0.9985254406929016},\n", " {'label': 'LABEL_0', 'score': 0.9932849407196045},\n", " {'label': 'LABEL_0', 'score': 0.5109266638755798},\n", " {'label': 'LABEL_0', 'score': 0.9801786541938782},\n", " {'label': 'LABEL_0', 'score': 0.9933167695999146},\n", " {'label': 'LABEL_0', 'score': 0.9007285833358765},\n", " {'label': 'LABEL_0', 'score': 0.986695408821106},\n", " {'label': 'LABEL_0', 'score': 0.9742431640625},\n", " {'label': 'LABEL_1', 'score': 0.7220762372016907},\n", " {'label': 'LABEL_0', 'score': 0.9923557043075562},\n", " {'label': 'LABEL_0', 'score': 0.9706437587738037},\n", " {'label': 'LABEL_0', 'score': 0.9988356232643127},\n", " {'label': 'LABEL_0', 'score': 0.9987884163856506},\n", " {'label': 'LABEL_0', 'score': 0.9797398447990417},\n", " {'label': 'LABEL_0', 'score': 0.987711489200592},\n", " {'label': 'LABEL_0', 'score': 0.9961549639701843},\n", " {'label': 'LABEL_0', 'score': 0.6292486190795898},\n", " {'label': 'LABEL_0', 'score': 0.9972221851348877},\n", " {'label': 'LABEL_0', 'score': 0.9987063407897949},\n", " {'label': 'LABEL_0', 'score': 0.998217761516571},\n", " {'label': 'LABEL_0', 'score': 0.995513379573822},\n", " {'label': 'LABEL_0', 'score': 0.9983007311820984},\n", " {'label': 'LABEL_0', 'score': 0.9979398846626282},\n", " {'label': 'LABEL_0', 'score': 0.9986240863800049},\n", " {'label': 'LABEL_0', 'score': 0.9977601766586304},\n", " {'label': 'LABEL_0', 'score': 
0.9974849224090576},\n", " {'label': 'LABEL_0', 'score': 0.949386715888977},\n", " {'label': 'LABEL_0', 'score': 0.969242513179779},\n", " {'label': 'LABEL_0', 'score': 0.978050708770752},\n", " {'label': 'LABEL_0', 'score': 0.9853933453559875},\n", " {'label': 'LABEL_0', 'score': 0.9960567951202393},\n", " {'label': 'LABEL_0', 'score': 0.9965749382972717},\n", " {'label': 'LABEL_0', 'score': 0.9987173080444336},\n", " {'label': 'LABEL_0', 'score': 0.9980658888816833},\n", " {'label': 'LABEL_0', 'score': 0.9988677501678467},\n", " {'label': 'LABEL_0', 'score': 0.9987745881080627},\n", " {'label': 'LABEL_0', 'score': 0.9977912902832031},\n", " {'label': 'LABEL_1', 'score': 0.8811283707618713},\n", " {'label': 'LABEL_0', 'score': 0.9986131191253662},\n", " {'label': 'LABEL_0', 'score': 0.9301124811172485},\n", " {'label': 'LABEL_0', 'score': 0.9980668425559998},\n", " {'label': 'LABEL_1', 'score': 0.8134693503379822},\n", " {'label': 'LABEL_0', 'score': 0.9951452612876892},\n", " {'label': 'LABEL_0', 'score': 0.9973406195640564},\n", " {'label': 'LABEL_0', 'score': 0.9889533519744873},\n", " {'label': 'LABEL_0', 'score': 0.9961967468261719},\n", " {'label': 'LABEL_0', 'score': 0.9389035105705261},\n", " {'label': 'LABEL_0', 'score': 0.9983087778091431},\n", " {'label': 'LABEL_0', 'score': 0.9805503487586975},\n", " {'label': 'LABEL_0', 'score': 0.9979610443115234},\n", " {'label': 'LABEL_0', 'score': 0.9981979727745056},\n", " {'label': 'LABEL_0', 'score': 0.9982859492301941},\n", " {'label': 'LABEL_0', 'score': 0.9986773133277893},\n", " {'label': 'LABEL_0', 'score': 0.9940134882926941},\n", " {'label': 'LABEL_0', 'score': 0.9958133101463318},\n", " {'label': 'LABEL_0', 'score': 0.9986673593521118},\n", " {'label': 'LABEL_0', 'score': 0.9969382286071777},\n", " {'label': 'LABEL_0', 'score': 0.9913495779037476},\n", " {'label': 'LABEL_0', 'score': 0.9988558292388916},\n", " {'label': 'LABEL_0', 'score': 0.9801749587059021},\n", " {'label': 'LABEL_1', 'score': 
0.9942229390144348},\n", " {'label': 'LABEL_0', 'score': 0.9904804825782776},\n", " {'label': 'LABEL_0', 'score': 0.9985619187355042},\n", " {'label': 'LABEL_0', 'score': 0.9981953501701355},\n", " {'label': 'LABEL_0', 'score': 0.9989521503448486},\n", " {'label': 'LABEL_0', 'score': 0.9985705614089966},\n", " {'label': 'LABEL_0', 'score': 0.8995286822319031},\n", " {'label': 'LABEL_0', 'score': 0.998676598072052},\n", " {'label': 'LABEL_0', 'score': 0.9984286427497864},\n", " {'label': 'LABEL_0', 'score': 0.6487972140312195},\n", " {'label': 'LABEL_0', 'score': 0.9394513368606567},\n", " ...]" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictions" ] }, { "cell_type": "code", "execution_count": 36, "id": "3f35d373-a3a4-4186-8663-d320edb22092", "metadata": { "execution": { "iopub.execute_input": "2024-05-23T15:43:24.194303Z", "iopub.status.busy": "2024-05-23T15:43:24.193637Z", "iopub.status.idle": "2024-05-23T15:43:24.224829Z", "shell.execute_reply": "2024-05-23T15:43:24.224212Z", "shell.execute_reply.started": "2024-05-23T15:43:24.194276Z" } }, "outputs": [], "source": [ "import numpy as np\n", "import evaluate\n", "from sklearn.metrics import accuracy_score" ] }, { "cell_type": "code", "execution_count": 42, "id": "80515948-0db5-466e-8aa6-1aeb53e79cb6", "metadata": { "execution": { "iopub.execute_input": "2024-05-23T15:46:26.164515Z", "iopub.status.busy": "2024-05-23T15:46:26.164217Z", "iopub.status.idle": "2024-05-23T15:46:26.169211Z", "shell.execute_reply": "2024-05-23T15:46:26.168571Z", "shell.execute_reply.started": "2024-05-23T15:46:26.164488Z" } }, "outputs": [], "source": [ "# Hard class decision per sequence: the integer suffix of each reported \"LABEL_k\" tag.\n", "preds = [int(item[\"label\"].split(\"_\")[1]) for item in predictions]" ] }, { "cell_type": "code", "execution_count": 63, "id": "b97ff6d7-6460-4ab9-b163-680a47135cbc", "metadata": { "execution": { "iopub.execute_input": "2024-05-23T16:07:07.653818Z", "iopub.status.busy": "2024-05-23T16:07:07.653404Z", "iopub.status.idle":
"2024-05-23T16:07:07.659790Z", "shell.execute_reply": "2024-05-23T16:07:07.658732Z", "shell.execute_reply.started": "2024-05-23T16:07:07.653793Z" } }, "outputs": [], "source": [ "predictions_proba = np.array([x[\"score\"] if int(x[\"label\"].split(\"_\")[1]) == 1 else (1.0 - x[\"score\"]) for x in predictions])" ] }, { "cell_type": "code", "execution_count": 64, "id": "ffa8d866-1174-4564-8776-468e7a69a0dd", "metadata": { "execution": { "iopub.execute_input": "2024-05-23T16:07:12.709461Z", "iopub.status.busy": "2024-05-23T16:07:12.709136Z", "iopub.status.idle": "2024-05-23T16:07:12.719229Z", "shell.execute_reply": "2024-05-23T16:07:12.718143Z", "shell.execute_reply.started": "2024-05-23T16:07:12.709436Z" } }, "outputs": [], "source": [ "predictions_probs = np.array([np.array([x[\"score\"], (1.0 - x[\"score\"])]) if int(x[\"label\"].split(\"_\")[1]) == 0 else np.array([(1.0 - x[\"score\"]), x[\"score\"]]) for x in predictions])" ] }, { "cell_type": "code", "execution_count": 65, "id": "f4e20d3c-241a-4c08-adf1-527a1e3fa19c", "metadata": { "execution": { "iopub.execute_input": "2024-05-23T16:07:14.013838Z", "iopub.status.busy": "2024-05-23T16:07:14.013485Z", "iopub.status.idle": "2024-05-23T16:07:14.021152Z", "shell.execute_reply": "2024-05-23T16:07:14.019934Z", "shell.execute_reply.started": "2024-05-23T16:07:14.013813Z" } }, "outputs": [ { "data": { "text/plain": [ "(1, 1, 0.9920824766159058, array([0.00791752, 0.99208248]))" ] }, "execution_count": 65, "metadata": {}, "output_type": "execute_result" } ], "source": [ "preds[0], y_test[0], predictions_proba[0], predictions_probs[0]" ] }, { "cell_type": "code", "execution_count": 66, "id": "553f215e-cb9c-45be-8e09-33c840a1c095", "metadata": { "execution": { "iopub.execute_input": "2024-05-23T16:07:20.087631Z", "iopub.status.busy": "2024-05-23T16:07:20.086457Z", "iopub.status.idle": "2024-05-23T16:07:20.111547Z", "shell.execute_reply": "2024-05-23T16:07:20.110831Z", "shell.execute_reply.started": 
"2024-05-23T16:07:20.087591Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", "non-crystallizable 0.73 0.97 0.83 1000\n", " crystallizable 0.94 0.60 0.73 898\n", "\n", " accuracy 0.79 1898\n", " macro avg 0.83 0.78 0.78 1898\n", " weighted avg 0.83 0.79 0.78 1898\n", "\n", "[[966 34]\n", " [362 536]]\n", "0.9467594654788418\n" ] } ], "source": [ "from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score\n", "target_names = ['non-crystallizable', 'crystallizable']\n", "\n", "print(classification_report(y_test, preds, target_names=target_names))\n", "print(confusion_matrix(y_test, preds))\n", "print(roc_auc_score(y_test, predictions_proba))" ] }, { "cell_type": "code", "execution_count": 67, "id": "44be3b34-5e77-4814-ac8d-682785bf58b4", "metadata": { "execution": { "iopub.execute_input": "2024-05-23T16:07:23.695122Z", "iopub.status.busy": "2024-05-23T16:07:23.694785Z", "iopub.status.idle": "2024-05-23T16:07:23.702798Z", "shell.execute_reply": "2024-05-23T16:07:23.702046Z", "shell.execute_reply.started": "2024-05-23T16:07:23.695098Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.9467594654788418\n" ] } ], "source": [ "print(roc_auc_score(y_test, predictions_proba))" ] }, { "cell_type": "code", "execution_count": 59, "id": "ea12fbb0-2105-473d-9774-833d9f987a12", "metadata": { "execution": { "iopub.execute_input": "2024-05-23T16:00:26.155037Z", "iopub.status.busy": "2024-05-23T16:00:26.154548Z", "iopub.status.idle": "2024-05-23T16:00:26.166013Z", "shell.execute_reply": "2024-05-23T16:00:26.165322Z", "shell.execute_reply.started": "2024-05-23T16:00:26.155000Z" } }, "outputs": [], "source": [ "from sklearn.metrics import roc_curve, auc\n", "n_classes = 2\n", "# Compute ROC curve and ROC area for each class\n", "fpr = dict()\n", "tpr = dict()\n", "roc_auc = dict()\n", "for i in range(n_classes):\n", " fpr[i], tpr[i], _ = roc_curve(y_test, 
predictions_probs[:, i])\n", " roc_auc[i] = auc(fpr[i], tpr[i])\n", "\n", "# Compute micro-average ROC curve and ROC area\n", "fpr[\"micro\"], tpr[\"micro\"], _ = roc_curve(y_test, predictions_proba)\n", "roc_auc[\"micro\"] = auc(fpr[\"micro\"], tpr[\"micro\"])" ] }, { "cell_type": "code", "execution_count": 60, "id": "5775d0d0-b46a-4b77-8610-a2b9b6ed1cd7", "metadata": { "execution": { "iopub.execute_input": "2024-05-23T16:01:15.759696Z", "iopub.status.busy": "2024-05-23T16:01:15.759387Z", "iopub.status.idle": "2024-05-23T16:01:15.996034Z", "shell.execute_reply": "2024-05-23T16:01:15.995052Z", "shell.execute_reply.started": "2024-05-23T16:01:15.759673Z" } }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "plt.figure()\n", "lw = 2\n", "plt.plot(\n", " fpr[1],\n", " tpr[1],\n", " color=\"darkorange\",\n", " lw=lw,\n", " label=\"ROC curve (area = %0.2f)\" % roc_auc[1],\n", ")\n", "plt.plot([0, 1], [0, 1], color=\"navy\", lw=lw, linestyle=\"--\")\n", "plt.xlim([0.0, 1.0])\n", "plt.ylim([0.0, 1.05])\n", "plt.xlabel(\"False Positive Rate\")\n", "plt.ylabel(\"True Positive Rate\")\n", "plt.title(\"Receiver operating characteristic example\")\n", "plt.legend(loc=\"lower right\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 74, "id": "54717414-d116-4dd3-9f70-2e823986dc46", "metadata": { "execution": { "iopub.execute_input": "2024-05-23T16:14:49.010576Z", "iopub.status.busy": "2024-05-23T16:14:49.009575Z", "iopub.status.idle": "2024-05-23T16:14:49.211862Z", "shell.execute_reply": "2024-05-23T16:14:49.211110Z", "shell.execute_reply.started": "2024-05-23T16:14:49.010549Z" } }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from sklearn.metrics import PrecisionRecallDisplay\n", "\n", "import matplotlib.pyplot as plt\n", "fig, ax = plt.subplots()\n", "display = PrecisionRecallDisplay.from_predictions(y_test, predictions_proba, name=\"ESMCrystal\", ax=ax)\n", "_ = display.ax_.set_title(\"2-class Precision-Recall curve\")\n", "plt.show()\n" ] }, { "cell_type": "code", "execution_count": 88, "id": "1fd70e82-cbd6-4399-b9ea-8bbace1f34c8", "metadata": { "execution": { "iopub.execute_input": "2024-05-23T16:32:02.066439Z", "iopub.status.busy": "2024-05-23T16:32:02.065781Z", "iopub.status.idle": "2024-05-23T16:32:02.080057Z", "shell.execute_reply": "2024-05-23T16:32:02.079358Z", "shell.execute_reply.started": "2024-05-23T16:32:02.066411Z" } }, "outputs": [], "source": [ "import csv\n", "\n", "testdatacsvfilepath = \"Datasets/BCrystal_Balanced_Test_set/test.fasta\"\n", "testcsvfilepath = \"Datasets/BCrystal_Balanced_Test_set/y_test.csv\"\n", "\n", "X_test_B = []\n", "y_test_B = []\n", "\n", "with open(testdatacsvfilepath) as testcsvfile:\n", " csvreader = csv.reader(testcsvfile)\n", " for row in csvreader:\n", " #print(row)\n", " if '>' not in row[0]:\n", " X_test_B.append(row[0])\n", " else:\n", " pass\n", " \n", "with open(testcsvfilepath) as testcsvfile:\n", " csvreader = csv.reader(testcsvfile)\n", " for row in csvreader:\n", " #print(row)\n", " if '1' in row:\n", " y_test_B.append(1)\n", " elif '0' in row:\n", " y_test_B.append(0)\n", " else:\n", " pass" ] }, { "cell_type": "code", "execution_count": 89, "id": "f90b2e57-0b59-476f-9d52-6682f14d26f4", "metadata": { "execution": { "iopub.execute_input": "2024-05-23T16:32:04.442180Z", "iopub.status.busy": "2024-05-23T16:32:04.441861Z", "iopub.status.idle": "2024-05-23T16:32:04.449920Z", "shell.execute_reply": "2024-05-23T16:32:04.449064Z", "shell.execute_reply.started": "2024-05-23T16:32:04.442155Z" } }, "outputs": [ { "data": { "text/plain": [ 
"('MRVLFIGDVFGQPGRRVLQNHLPTIRPQFDFVIVNMENSAGGFGMHRDAARGALEAGAGCLTLGNHAWHHKDIYPMLSEDTYPIVRPLNYADPGTPGVGWRTFDVNGEKLTVVNLLGRVFMEAVDNPFRTMDALLERDDLGTVFVDFHAEATSEKEAMGWHLAGRVAAVIGTHTHVPTADTRILKGGTAYQTDAGFTGPHDSIIGSAIEGPLQRFLTERPHRYGVAEGRAELNGVALHFEGGKATAAERYRFIED',\n", " 255,\n", " 1787,\n", " 1,\n", " 1787)" ] }, "execution_count": 89, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_test_B[0], len(X_test_B[0]), len(X_test_B), y_test_B[0], len(y_test_B)" ] }, { "cell_type": "code", "execution_count": 91, "id": "fa5ad501-5962-41b1-90ff-75797f683cc7", "metadata": { "execution": { "iopub.execute_input": "2024-05-23T16:34:49.007157Z", "iopub.status.busy": "2024-05-23T16:34:49.006603Z", "iopub.status.idle": "2024-05-23T16:34:49.014441Z", "shell.execute_reply": "2024-05-23T16:34:49.013695Z", "shell.execute_reply.started": "2024-05-23T16:34:49.007131Z" } }, "outputs": [], "source": [ "testdatacsvfilepath = \"Datasets/SP_Final_set/FULL_SP.fasta\"\n", "testcsvfilepath = \"Datasets/SP_Final_set/SP_True_Label.csv\"\n", "\n", "X_test_S = []\n", "y_test_S = []\n", "\n", "with open(testdatacsvfilepath) as testcsvfile:\n", " csvreader = csv.reader(testcsvfile)\n", " for row in csvreader:\n", " #print(row)\n", " if '>' not in row[0]:\n", " X_test_S.append(row[0])\n", " else:\n", " pass\n", " \n", "with open(testcsvfilepath) as testcsvfile:\n", " csvreader = csv.reader(testcsvfile)\n", " for row in csvreader:\n", " #print(row)\n", " if '1' in row:\n", " y_test_S.append(1)\n", " elif '0' in row:\n", " y_test_S.append(0)\n", " else:\n", " pass" ] }, { "cell_type": "code", "execution_count": 92, "id": "b4827864-c65d-4b98-8e15-8c4582436251", "metadata": { "execution": { "iopub.execute_input": "2024-05-23T16:34:53.261113Z", "iopub.status.busy": "2024-05-23T16:34:53.260231Z", "iopub.status.idle": "2024-05-23T16:34:53.267569Z", "shell.execute_reply": "2024-05-23T16:34:53.266331Z", "shell.execute_reply.started": "2024-05-23T16:34:53.261084Z" } }, "outputs": [ { "data": { 
"text/plain": [ "('MVDMQSLDEEDFSVSKSSDADAEFDIVIGNIEDIIMEDEFQHLQQSFMEKYYLEFDDSEENKLSYTPIFNEYIEILEKHLEQQLVERIPGFNMDAFTHSLKQHKDEVSGDILDMLLTFTDFMAFKEMFTDYRAEKEGRGLDLSTGLVVKSLNSSSASPLTPSMASQSI',\n", " 168,\n", " 237,\n", " 1,\n", " 237)" ] }, "execution_count": 92, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_test_S[0], len(X_test_S[0]), len(X_test_S), y_test_S[0], len(y_test_S)" ] }, { "cell_type": "code", "execution_count": 93, "id": "c2d8ff10-751e-4751-a8f4-0e2974b2c647", "metadata": { "execution": { "iopub.execute_input": "2024-05-23T16:35:30.446881Z", "iopub.status.busy": "2024-05-23T16:35:30.445742Z", "iopub.status.idle": "2024-05-23T16:35:30.458738Z", "shell.execute_reply": "2024-05-23T16:35:30.457975Z", "shell.execute_reply.started": "2024-05-23T16:35:30.446841Z" } }, "outputs": [], "source": [ "testdatacsvfilepath = \"Datasets/TR_Final_set/FULL_TR.fasta\"\n", "testcsvfilepath = \"Datasets/TR_Final_set/TR_True_Label.csv\"\n", "\n", "X_test_T = []\n", "y_test_T = []\n", "\n", "with open(testdatacsvfilepath) as testcsvfile:\n", " csvreader = csv.reader(testcsvfile)\n", " for row in csvreader:\n", " #print(row)\n", " if '>' not in row[0]:\n", " X_test_T.append(row[0])\n", " else:\n", " pass\n", " \n", "with open(testcsvfilepath) as testcsvfile:\n", " csvreader = csv.reader(testcsvfile)\n", " for row in csvreader:\n", " #print(row)\n", " if '1' in row:\n", " y_test_T.append(1)\n", " elif '0' in row:\n", " y_test_T.append(0)\n", " else:\n", " pass" ] }, { "cell_type": "code", "execution_count": 94, "id": "1b96e5c6-26cc-4659-8f2b-3cff26550262", "metadata": { "execution": { "iopub.execute_input": "2024-05-23T16:35:32.595636Z", "iopub.status.busy": "2024-05-23T16:35:32.595055Z", "iopub.status.idle": "2024-05-23T16:35:32.601035Z", "shell.execute_reply": "2024-05-23T16:35:32.600211Z", "shell.execute_reply.started": "2024-05-23T16:35:32.595610Z" } }, "outputs": [ { "data": { "text/plain": [ 
"('MRVLFIGDVFGQPGRRVLQNHLPTIRPQFDFVIVNMENSAGGFGMHRDAARGALEAGAGCLTLGNHAWHHKDIYPMLSEDTYPIVRPLNYADPGTPGVGWRTFDVNGEKLTVVNLLGRVFMEAVDNPFRTMDALLERDDLGTVFVDFHAEATSEKEAMGWHLAGRVAAVIGTHTHVPTADTRILKGGTAYQTDAGFTGPHDSIIGSAIEGPLQRFLTERPHRYGVAEGRAELNGVALHFEGGKATAAERYRFIED',\n", " 255,\n", " 1012,\n", " 1,\n", " 1012)" ] }, "execution_count": 94, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_test_T[0], len(X_test_T[0]), len(X_test_T), y_test_T[0], len(y_test_T)" ] }, { "cell_type": "code", "execution_count": 95, "id": "ca0e0eb8-6593-4e3f-9175-c557be423fbc", "metadata": { "execution": { "iopub.execute_input": "2024-05-23T16:41:50.836321Z", "iopub.status.busy": "2024-05-23T16:41:50.835798Z", "iopub.status.idle": "2024-05-23T16:42:48.961559Z", "shell.execute_reply": "2024-05-23T16:42:48.960734Z", "shell.execute_reply.started": "2024-05-23T16:41:50.836248Z" } }, "outputs": [], "source": [ "predictions_D = pipeline(X_test)\n", "predictions_B = pipeline(X_test_B)\n", "predictions_S = pipeline(X_test_S)\n", "predictions_T = pipeline(X_test_T)\n", "\n", "preds_D = [int(x[\"label\"].split(\"_\")[1]) for x in predictions_D]\n", "predictions_D_proba = np.array([x[\"score\"] if int(x[\"label\"].split(\"_\")[1]) == 1 else (1.0 - x[\"score\"]) for x in predictions_D])\n", "predictions_D_probs = np.array([np.array([x[\"score\"], (1.0 - x[\"score\"])]) if int(x[\"label\"].split(\"_\")[1]) == 0 else np.array([(1.0 - x[\"score\"]), x[\"score\"]]) for x in predictions_D])\n", "\n", "preds_B = [int(x[\"label\"].split(\"_\")[1]) for x in predictions_B]\n", "predictions_B_proba = np.array([x[\"score\"] if int(x[\"label\"].split(\"_\")[1]) == 1 else (1.0 - x[\"score\"]) for x in predictions_B])\n", "predictions_B_probs = np.array([np.array([x[\"score\"], (1.0 - x[\"score\"])]) if int(x[\"label\"].split(\"_\")[1]) == 0 else np.array([(1.0 - x[\"score\"]), x[\"score\"]]) for x in predictions_B])\n", "\n", "preds_S = [int(x[\"label\"].split(\"_\")[1]) for x in predictions_S]\n", 
"predictions_S_proba = np.array([x[\"score\"] if int(x[\"label\"].split(\"_\")[1]) == 1 else (1.0 - x[\"score\"]) for x in predictions_S])\n", "predictions_S_probs = np.array([np.array([x[\"score\"], (1.0 - x[\"score\"])]) if int(x[\"label\"].split(\"_\")[1]) == 0 else np.array([(1.0 - x[\"score\"]), x[\"score\"]]) for x in predictions_S])\n", "\n", "preds_T = [int(x[\"label\"].split(\"_\")[1]) for x in predictions_T]\n", "predictions_T_proba = np.array([x[\"score\"] if int(x[\"label\"].split(\"_\")[1]) == 1 else (1.0 - x[\"score\"]) for x in predictions_T])\n", "predictions_T_probs = np.array([np.array([x[\"score\"], (1.0 - x[\"score\"])]) if int(x[\"label\"].split(\"_\")[1]) == 0 else np.array([(1.0 - x[\"score\"]), x[\"score\"]]) for x in predictions_T])" ] }, { "cell_type": "code", "execution_count": 96, "id": "29bbcc1a-5200-493c-9332-759e7828c1a4", "metadata": { "execution": { "iopub.execute_input": "2024-05-23T16:42:48.963024Z", "iopub.status.busy": "2024-05-23T16:42:48.962793Z", "iopub.status.idle": "2024-05-23T16:42:48.982858Z", "shell.execute_reply": "2024-05-23T16:42:48.982175Z", "shell.execute_reply.started": "2024-05-23T16:42:48.963000Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", "non-crystallizable 0.73 0.97 0.83 1000\n", " crystallizable 0.94 0.60 0.73 898\n", "\n", " accuracy 0.79 1898\n", " macro avg 0.83 0.78 0.78 1898\n", " weighted avg 0.83 0.79 0.78 1898\n", "\n", "[[966 34]\n", " [362 536]]\n", "0.9467594654788418\n" ] } ], "source": [ "print(classification_report(y_test, preds_D, target_names=target_names))\n", "print(confusion_matrix(y_test, preds_D))\n", "print(roc_auc_score(y_test, predictions_D_proba))" ] }, { "cell_type": "code", "execution_count": 97, "id": "5dc3d485-773f-416a-a50b-97e862dece61", "metadata": { "execution": { "iopub.execute_input": "2024-05-23T16:43:03.684115Z", "iopub.status.busy": "2024-05-23T16:43:03.683465Z", "iopub.status.idle": 
"2024-05-23T16:43:03.702174Z", "shell.execute_reply": "2024-05-23T16:43:03.701390Z", "shell.execute_reply.started": "2024-05-23T16:43:03.684072Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", "non-crystallizable 0.71 0.97 0.82 896\n", " crystallizable 0.94 0.60 0.73 891\n", "\n", " accuracy 0.78 1787\n", " macro avg 0.83 0.78 0.77 1787\n", " weighted avg 0.83 0.78 0.77 1787\n", "\n", "[[865 31]\n", " [360 531]]\n", "0.9465463163379829\n" ] } ], "source": [ "print(classification_report(y_test_B, preds_B, target_names=target_names))\n", "print(confusion_matrix(y_test_B, preds_B))\n", "print(roc_auc_score(y_test_B, predictions_B_proba))" ] }, { "cell_type": "code", "execution_count": 98, "id": "30b6d080-31d3-42ab-a77b-08dd624e610d", "metadata": { "execution": { "iopub.execute_input": "2024-05-23T16:43:26.062153Z", "iopub.status.busy": "2024-05-23T16:43:26.061551Z", "iopub.status.idle": "2024-05-23T16:43:26.079462Z", "shell.execute_reply": "2024-05-23T16:43:26.078773Z", "shell.execute_reply.started": "2024-05-23T16:43:26.062128Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", "non-crystallizable 0.56 0.96 0.70 89\n", " crystallizable 0.95 0.54 0.69 148\n", "\n", " accuracy 0.70 237\n", " macro avg 0.75 0.75 0.70 237\n", " weighted avg 0.80 0.70 0.69 237\n", "\n", "[[85 4]\n", " [68 80]]\n", "0.9328120255086547\n" ] } ], "source": [ "print(classification_report(y_test_S, preds_S, target_names=target_names))\n", "print(confusion_matrix(y_test_S, preds_S))\n", "print(roc_auc_score(y_test_S, predictions_S_proba))" ] }, { "cell_type": "code", "execution_count": 99, "id": "814222c8-242d-420f-a314-e210fb7a5103", "metadata": { "execution": { "iopub.execute_input": "2024-05-23T16:43:50.482502Z", "iopub.status.busy": "2024-05-23T16:43:50.481734Z", "iopub.status.idle": "2024-05-23T16:43:50.498927Z", "shell.execute_reply": 
"2024-05-23T16:43:50.498016Z", "shell.execute_reply.started": "2024-05-23T16:43:50.482475Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", "non-crystallizable 0.79 0.97 0.87 638\n", " crystallizable 0.93 0.55 0.69 374\n", "\n", " accuracy 0.82 1012\n", " macro avg 0.86 0.76 0.78 1012\n", " weighted avg 0.84 0.82 0.81 1012\n", "\n", "[[622 16]\n", " [167 207]]\n", "0.9562804888270497\n" ] } ], "source": [ "print(classification_report(y_test_T, preds_T, target_names=target_names))\n", "print(confusion_matrix(y_test_T, preds_T))\n", "print(roc_auc_score(y_test_T, predictions_T_proba))" ] }, { "cell_type": "code", "execution_count": 100, "id": "9ee5ca85-2247-40c6-90a6-4eaad2289469", "metadata": { "execution": { "iopub.execute_input": "2024-05-23T16:50:49.394652Z", "iopub.status.busy": "2024-05-23T16:50:49.393793Z", "iopub.status.idle": "2024-05-23T16:50:49.421009Z", "shell.execute_reply": "2024-05-23T16:50:49.419986Z", "shell.execute_reply.started": "2024-05-23T16:50:49.394609Z" } }, "outputs": [], "source": [ "from sklearn.metrics import roc_curve, auc\n", "n_classes = 2\n", "\n", "# Compute ROC curve and ROC area for each class\n", "fpr_D = dict()\n", "tpr_D = dict()\n", "roc_auc_D = dict()\n", "for i in range(n_classes):\n", " fpr_D[i], tpr_D[i], _ = roc_curve(y_test, predictions_D_probs[:, i])\n", " roc_auc_D[i] = auc(fpr_D[i], tpr_D[i])\n", "\n", "# Compute micro-average ROC curve and ROC area\n", "fpr_D[\"micro\"], tpr_D[\"micro\"], _ = roc_curve(y_test, predictions_D_proba)\n", "roc_auc_D[\"micro\"] = auc(fpr_D[\"micro\"], tpr_D[\"micro\"])\n", "\n", "# Compute ROC curve and ROC area for each class\n", "fpr_B = dict()\n", "tpr_B = dict()\n", "roc_auc_B = dict()\n", "for i in range(n_classes):\n", " fpr_B[i], tpr_B[i], _ = roc_curve(y_test_B, predictions_B_probs[:, i])\n", " roc_auc_B[i] = auc(fpr_B[i], tpr_B[i])\n", "\n", "# Compute micro-average ROC curve and ROC area\n", "fpr_B[\"micro\"], 
tpr_B[\"micro\"], _ = roc_curve(y_test_B, predictions_B_proba)\n", "roc_auc_B[\"micro\"] = auc(fpr_B[\"micro\"], tpr_B[\"micro\"])\n", "\n", "# Compute ROC curve and ROC area for each class\n", "fpr_S = dict()\n", "tpr_S = dict()\n", "roc_auc_S = dict()\n", "for i in range(n_classes):\n", " fpr_S[i], tpr_S[i], _ = roc_curve(y_test_S, predictions_S_probs[:, i])\n", " roc_auc_S[i] = auc(fpr_S[i], tpr_S[i])\n", "\n", "# Compute micro-average ROC curve and ROC area\n", "fpr_S[\"micro\"], tpr_S[\"micro\"], _ = roc_curve(y_test_S, predictions_S_proba)\n", "roc_auc_S[\"micro\"] = auc(fpr_S[\"micro\"], tpr_S[\"micro\"])\n", "\n", "# Compute ROC curve and ROC area for each class\n", "fpr_T = dict()\n", "tpr_T = dict()\n", "roc_auc_T = dict()\n", "for i in range(n_classes):\n", " fpr_T[i], tpr_T[i], _ = roc_curve(y_test_T, predictions_T_probs[:, i])\n", " roc_auc_T[i] = auc(fpr_T[i], tpr_T[i])\n", "\n", "# Compute micro-average ROC curve and ROC area\n", "fpr_T[\"micro\"], tpr_T[\"micro\"], _ = roc_curve(y_test_T, predictions_T_proba)\n", "roc_auc_T[\"micro\"] = auc(fpr_T[\"micro\"], tpr_T[\"micro\"])" ] }, { "cell_type": "code", "execution_count": 114, "id": "4d9e97a8-38cc-4b7c-abfc-aa0b49d22ddd", "metadata": { "execution": { "iopub.execute_input": "2024-05-23T17:21:05.329932Z", "iopub.status.busy": "2024-05-23T17:21:05.329083Z", "iopub.status.idle": "2024-05-23T17:21:05.547845Z", "shell.execute_reply": "2024-05-23T17:21:05.546666Z", "shell.execute_reply.started": "2024-05-23T17:21:05.329889Z" } }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "plt.figure()\n", "lw = 2\n", "\n", "plt.plot(\n", " fpr_D[1],\n", " tpr_D[1],\n", " color=\"navy\",\n", " lw=lw,\n", " label=\"DeepCrystal Test ROC curve (area = %0.2f)\" % roc_auc_D[1],\n", ")\n", "\n", "plt.plot(\n", " fpr_B[1],\n", " tpr_B[1],\n", " color=\"orange\",\n", " lw=lw,\n", " label=\"Balanced Test ROC curve (area = %0.2f)\" % roc_auc_B[1],\n", ")\n", "\n", "plt.plot(\n", " fpr_S[1],\n", " tpr_S[1],\n", " color=\"lightgreen\",\n", " lw=lw,\n", " label=\"SP ROC curve (area = %0.2f)\" % roc_auc_S[1],\n", ")\n", "\n", "plt.plot(\n", " fpr_T[1],\n", " tpr_T[1],\n", " color=\"red\",\n", " lw=lw,\n", " label=\"TR ROC curve (area = %0.2f)\" % roc_auc_T[1],\n", ")\n", "\n", "plt.plot([0, 1], [0, 1], color=\"grey\", lw=lw, linestyle=\"--\")\n", "plt.xlim([0.0, 1.0])\n", "plt.ylim([0.0, 1.05])\n", "plt.xlabel(\"False Positive Rate\")\n", "plt.ylabel(\"True Positive Rate\")\n", "plt.title(\"Receiver operating characteristic - ESMCrystal_t6_8M_v1\")\n", "plt.legend(loc=\"lower right\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 108, "id": "3d217ed3-a4ca-4140-aeb3-67ae71d7b157", "metadata": { "execution": { "iopub.execute_input": "2024-05-23T17:19:05.198574Z", "iopub.status.busy": "2024-05-23T17:19:05.197828Z", "iopub.status.idle": "2024-05-23T17:19:05.432540Z", "shell.execute_reply": "2024-05-23T17:19:05.431766Z", "shell.execute_reply.started": "2024-05-23T17:19:05.198548Z" } }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from sklearn.metrics import PrecisionRecallDisplay\n", "\n", "import matplotlib.pyplot as plt\n", "fig, ax = plt.subplots()\n", "\n", "display = PrecisionRecallDisplay.from_predictions(y_test, predictions_D_proba, name=\"DeepCrystal Test\", ax=ax)\n", "display = PrecisionRecallDisplay.from_predictions(y_test_B, predictions_B_proba, name=\"Balanced Test\", ax=ax)\n", "display = PrecisionRecallDisplay.from_predictions(y_test_S, predictions_S_proba, name=\"SP Test\", ax=ax)\n", "display = PrecisionRecallDisplay.from_predictions(y_test_T, predictions_T_proba, name=\"TR Test\", ax=ax)\n", "\n", "_ = display.ax_.set_title(\"Precision-Recall curve - ESMCrystal_t6_8M_v1\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "d90cc358-7f51-4c2e-962f-7f4069ffc694", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.7" } }, "nbformat": 4, "nbformat_minor": 5 }