{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "eBpjBBZc6IvA" }, "source": [ "# Fatima Fellowship Coding Challenge: Finetune a Generative AI Model\n", "\n", "Thank you for applying to the Fatima Fellowship. To help us select the Fellows and assess your ability to do machine learning research, we are asking that you complete a short coding challenge.\n", "\n", "**How to submit**: Please make a copy of this colab notebook, add your code and results, and submit your colab notebook along with your application. If you have never used a colab notebook, [check out this video](https://www.youtube.com/watch?v=i-HnvsehuSw)" ] }, { "cell_type": "markdown", "metadata": { "id": "lQNUZjvuRt-m" }, "source": [ "\n", "\n", "---\n", "\n", "\n", "### **Important**: Before you get started, please make a **copy of this notebook** and set sharing permissions so that **anyone with the link can view**. Otherwise, we will NOT be able to assess your application.\n", "\n", "\n", "\n", "---\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "qpFKMaPeggyX" }, "source": [ "# 0. Description\n", "\n", "The purpose of this coding challenge is to finetune a generative AI model on a dataset that *you* build.\n", "\n", "The dataset can be of any kind! For example, you could collect a dataset of football jerseys and train a machine learning model to generate jerseys for different teams. Or, you could finetune a generative model to produce accurate recipes for a particular dish specific to your cuisine.\n", "\n", "We are interested in learning more about you and your coding abilities through this short exercise." ] }, { "cell_type": "markdown", "metadata": { "id": "braBzmRpMe7_" }, "source": [ "# 1. Build a Dataset Based on Your Interests" ] }, { "cell_type": "markdown", "metadata": { "id": "1IWw-NZf5WfF" }, "source": [ "In the first step, you'll be building your OWN dataset of any kind. 
We expect that many students might build this dataset by scraping the web (e.g. Google Images) or by extracting samples from existing datasets (e.g. [from Hugging Face](https://huggingface.co/datasets)). Some suggestions:\n", "\n", "* Dataset size: although this can vary, we generally recommend that the dataset have at least 100 samples (training and validation combined).\n", "* Dataset diversity: make sure your dataset is sufficiently varied. For example, if your dataset consists of celebrity images, you probably want celebrities of different ages, ethnicities, genders, etc.\n", "\n", "You may find Python libraries that download images, such as `google_images_download`, useful.\n", "\n", "Once you have built your dataset, please upload it to Hugging Face Hub using the `datasets` library and include the link below:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "execution": { "iopub.execute_input": "2024-07-14T23:54:01.063480Z", "iopub.status.busy": "2024-07-14T23:54:01.063126Z", "iopub.status.idle": "2024-07-14T23:54:14.512034Z", "shell.execute_reply": "2024-07-14T23:54:14.510890Z", "shell.execute_reply.started": "2024-07-14T23:54:01.063450Z" }, "id": "K2GJaYBpw91T", "outputId": "3c4ec6d3-2a83-4a81-8243-29d1cd77612a" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: datasets in /opt/conda/lib/python3.10/site-packages (2.20.0)\n", "Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from datasets) (3.13.1)\n", "Requirement already satisfied: numpy>=1.17 in /opt/conda/lib/python3.10/site-packages (from datasets) (1.26.4)\n", "Requirement already satisfied: pyarrow>=15.0.0 in /opt/conda/lib/python3.10/site-packages (from datasets) (16.1.0)\n", "Requirement already satisfied: pyarrow-hotfix in /opt/conda/lib/python3.10/site-packages (from datasets) (0.6)\n", "Requirement already satisfied: dill<0.3.9,>=0.3.0 in 
/opt/conda/lib/python3.10/site-packages (from datasets) (0.3.8)\n", "Requirement already satisfied: pandas in /opt/conda/lib/python3.10/site-packages (from datasets) (2.2.2)\n", "Requirement already satisfied: requests>=2.32.2 in /opt/conda/lib/python3.10/site-packages (from datasets) (2.32.3)\n", "Requirement already satisfied: tqdm>=4.66.3 in /opt/conda/lib/python3.10/site-packages (from datasets) (4.66.4)\n", "Requirement already satisfied: xxhash in /opt/conda/lib/python3.10/site-packages (from datasets) (3.4.1)\n", "Requirement already satisfied: multiprocess in /opt/conda/lib/python3.10/site-packages (from datasets) (0.70.16)\n", "Requirement already satisfied: fsspec<=2024.5.0,>=2023.1.0 in /opt/conda/lib/python3.10/site-packages (from fsspec[http]<=2024.5.0,>=2023.1.0->datasets) (2024.5.0)\n", "Requirement already satisfied: aiohttp in /opt/conda/lib/python3.10/site-packages (from datasets) (3.9.1)\n", "Requirement already satisfied: huggingface-hub>=0.21.2 in /opt/conda/lib/python3.10/site-packages (from datasets) (0.23.4)\n", "Requirement already satisfied: packaging in /opt/conda/lib/python3.10/site-packages (from datasets) (21.3)\n", "Requirement already satisfied: pyyaml>=5.1 in /opt/conda/lib/python3.10/site-packages (from datasets) (6.0.1)\n", "Requirement already satisfied: attrs>=17.3.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets) (23.2.0)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets) (6.0.4)\n", "Requirement already satisfied: yarl<2.0,>=1.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets) (1.9.3)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets) (1.4.1)\n", "Requirement already satisfied: aiosignal>=1.1.2 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets) (1.3.1)\n", "Requirement already satisfied: async-timeout<5.0,>=4.0 in 
/opt/conda/lib/python3.10/site-packages (from aiohttp->datasets) (4.0.3)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /opt/conda/lib/python3.10/site-packages (from huggingface-hub>=0.21.2->datasets) (4.9.0)\n", "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /opt/conda/lib/python3.10/site-packages (from packaging->datasets) (3.1.1)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/lib/python3.10/site-packages (from requests>=2.32.2->datasets) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.10/site-packages (from requests>=2.32.2->datasets) (3.6)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/lib/python3.10/site-packages (from requests>=2.32.2->datasets) (1.26.18)\n", "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.10/site-packages (from requests>=2.32.2->datasets) (2024.7.4)\n", "Requirement already satisfied: python-dateutil>=2.8.2 in /opt/conda/lib/python3.10/site-packages (from pandas->datasets) (2.9.0.post0)\n", "Requirement already satisfied: pytz>=2020.1 in /opt/conda/lib/python3.10/site-packages (from pandas->datasets) (2023.3.post1)\n", "Requirement already satisfied: tzdata>=2022.7 in /opt/conda/lib/python3.10/site-packages (from pandas->datasets) (2023.4)\n", "Requirement already satisfied: six>=1.5 in /opt/conda/lib/python3.10/site-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n" ] } ], "source": [ "!pip install datasets" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "execution": { "iopub.execute_input": "2024-07-14T23:54:14.514248Z", "iopub.status.busy": "2024-07-14T23:54:14.513956Z", "iopub.status.idle": "2024-07-14T23:54:31.621227Z", "shell.execute_reply": "2024-07-14T23:54:31.620455Z", "shell.execute_reply.started": "2024-07-14T23:54:14.514223Z" }, "id": "Z0x_qADbxA5x", "outputId": "344c2357-410f-49b0-8761-3f118615a95d" 
}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2024-07-14 23:54:22.010940: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "2024-07-14 23:54:22.011038: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "2024-07-14 23:54:22.120615: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" ] } ], "source": [ "from datasets import load_dataset, DatasetDict, Dataset\n", "\n", "from transformers import (\n", " AutoTokenizer,\n", " AutoConfig,\n", " AutoModelForSequenceClassification,\n", " DataCollatorWithPadding,\n", " TrainingArguments,\n", " Trainer)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "execution": { "iopub.execute_input": "2024-07-14T23:54:31.622913Z", "iopub.status.busy": "2024-07-14T23:54:31.622369Z", "iopub.status.idle": "2024-07-14T23:54:45.202850Z", "shell.execute_reply": "2024-07-14T23:54:45.201709Z", "shell.execute_reply.started": "2024-07-14T23:54:31.622886Z" }, "id": "yFUGXvOCxKzX", "outputId": "d7d91cc7-d663-4f0d-9f79-6acd8850d92e", "scrolled": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/lib/python3.10/pty.py:89: RuntimeWarning: os.fork() was called. 
os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", " pid, fd = os.forkpty()\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting peft\n", " Downloading peft-0.11.1-py3-none-any.whl.metadata (13 kB)\n", "Requirement already satisfied: torch in /opt/conda/lib/python3.10/site-packages (2.1.2)\n", "Collecting evaluate\n", " Downloading evaluate-0.4.2-py3-none-any.whl.metadata (9.3 kB)\n", "Requirement already satisfied: numpy>=1.17 in /opt/conda/lib/python3.10/site-packages (from peft) (1.26.4)\n", "Requirement already satisfied: packaging>=20.0 in /opt/conda/lib/python3.10/site-packages (from peft) (21.3)\n", "Requirement already satisfied: psutil in /opt/conda/lib/python3.10/site-packages (from peft) (5.9.3)\n", "Requirement already satisfied: pyyaml in /opt/conda/lib/python3.10/site-packages (from peft) (6.0.1)\n", "Requirement already satisfied: transformers in /opt/conda/lib/python3.10/site-packages (from peft) (4.42.3)\n", "Requirement already satisfied: tqdm in /opt/conda/lib/python3.10/site-packages (from peft) (4.66.4)\n", "Requirement already satisfied: accelerate>=0.21.0 in /opt/conda/lib/python3.10/site-packages (from peft) (0.32.1)\n", "Requirement already satisfied: safetensors in /opt/conda/lib/python3.10/site-packages (from peft) (0.4.3)\n", "Requirement already satisfied: huggingface-hub>=0.17.0 in /opt/conda/lib/python3.10/site-packages (from peft) (0.23.4)\n", "Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from torch) (3.13.1)\n", "Requirement already satisfied: typing-extensions in /opt/conda/lib/python3.10/site-packages (from torch) (4.9.0)\n", "Requirement already satisfied: sympy in /opt/conda/lib/python3.10/site-packages (from torch) (1.13.0)\n", "Requirement already satisfied: networkx in /opt/conda/lib/python3.10/site-packages (from torch) (3.2.1)\n", "Requirement already satisfied: jinja2 in 
/opt/conda/lib/python3.10/site-packages (from torch) (3.1.2)\n", "Requirement already satisfied: fsspec in /opt/conda/lib/python3.10/site-packages (from torch) (2024.5.0)\n", "Requirement already satisfied: datasets>=2.0.0 in /opt/conda/lib/python3.10/site-packages (from evaluate) (2.20.0)\n", "Requirement already satisfied: dill in /opt/conda/lib/python3.10/site-packages (from evaluate) (0.3.8)\n", "Requirement already satisfied: pandas in /opt/conda/lib/python3.10/site-packages (from evaluate) (2.2.2)\n", "Requirement already satisfied: requests>=2.19.0 in /opt/conda/lib/python3.10/site-packages (from evaluate) (2.32.3)\n", "Requirement already satisfied: xxhash in /opt/conda/lib/python3.10/site-packages (from evaluate) (3.4.1)\n", "Requirement already satisfied: multiprocess in /opt/conda/lib/python3.10/site-packages (from evaluate) (0.70.16)\n", "Requirement already satisfied: pyarrow>=15.0.0 in /opt/conda/lib/python3.10/site-packages (from datasets>=2.0.0->evaluate) (16.1.0)\n", "Requirement already satisfied: pyarrow-hotfix in /opt/conda/lib/python3.10/site-packages (from datasets>=2.0.0->evaluate) (0.6)\n", "Requirement already satisfied: aiohttp in /opt/conda/lib/python3.10/site-packages (from datasets>=2.0.0->evaluate) (3.9.1)\n", "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /opt/conda/lib/python3.10/site-packages (from packaging>=20.0->peft) (3.1.1)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/lib/python3.10/site-packages (from requests>=2.19.0->evaluate) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.10/site-packages (from requests>=2.19.0->evaluate) (3.6)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/lib/python3.10/site-packages (from requests>=2.19.0->evaluate) (1.26.18)\n", "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.10/site-packages (from requests>=2.19.0->evaluate) (2024.7.4)\n", "Requirement already satisfied: 
MarkupSafe>=2.0 in /opt/conda/lib/python3.10/site-packages (from jinja2->torch) (2.1.3)\n", "Requirement already satisfied: python-dateutil>=2.8.2 in /opt/conda/lib/python3.10/site-packages (from pandas->evaluate) (2.9.0.post0)\n", "Requirement already satisfied: pytz>=2020.1 in /opt/conda/lib/python3.10/site-packages (from pandas->evaluate) (2023.3.post1)\n", "Requirement already satisfied: tzdata>=2022.7 in /opt/conda/lib/python3.10/site-packages (from pandas->evaluate) (2023.4)\n", "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /opt/conda/lib/python3.10/site-packages (from sympy->torch) (1.3.0)\n", "Requirement already satisfied: regex!=2019.12.17 in /opt/conda/lib/python3.10/site-packages (from transformers->peft) (2023.12.25)\n", "Requirement already satisfied: tokenizers<0.20,>=0.19 in /opt/conda/lib/python3.10/site-packages (from transformers->peft) (0.19.1)\n", "Requirement already satisfied: attrs>=17.3.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets>=2.0.0->evaluate) (23.2.0)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets>=2.0.0->evaluate) (6.0.4)\n", "Requirement already satisfied: yarl<2.0,>=1.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets>=2.0.0->evaluate) (1.9.3)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets>=2.0.0->evaluate) (1.4.1)\n", "Requirement already satisfied: aiosignal>=1.1.2 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets>=2.0.0->evaluate) (1.3.1)\n", "Requirement already satisfied: async-timeout<5.0,>=4.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets>=2.0.0->evaluate) (4.0.3)\n", "Requirement already satisfied: six>=1.5 in /opt/conda/lib/python3.10/site-packages (from python-dateutil>=2.8.2->pandas->evaluate) (1.16.0)\n", "Downloading peft-0.11.1-py3-none-any.whl (251 kB)\n", "\u001b[2K 
\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m251.6/251.6 kB\u001b[0m \u001b[31m2.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0mm\n", "\u001b[?25hDownloading evaluate-0.4.2-py3-none-any.whl (84 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m84.1/84.1 kB\u001b[0m \u001b[31m6.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hInstalling collected packages: peft, evaluate\n", "Successfully installed evaluate-0.4.2 peft-0.11.1\n" ] } ], "source": [ "!pip install peft evaluate" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "execution": { "iopub.execute_input": "2024-07-14T23:54:45.205611Z", "iopub.status.busy": "2024-07-14T23:54:45.205289Z", "iopub.status.idle": "2024-07-14T23:54:48.418952Z", "shell.execute_reply": "2024-07-14T23:54:48.418056Z", "shell.execute_reply.started": "2024-07-14T23:54:45.205583Z" }, "id": "CLX2Z20zxpc9", "outputId": "56192048-fdd7-4291-8c3f-bc201b54a214" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6b02c7d5863a45c89f774634dfec4e02", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading readme: 0%| | 0.00/592 [00:00, auto_mapping=None, base_model_name_or_path=None, revision=None, task_type='SEQ_CLS', inference_mode=False, r=4, target_modules={'q_lin'}, lora_alpha=32, lora_dropout=0.01, fan_in_fan_out=False, bias='none', use_rslora=False, modules_to_save=None, init_lora_weights=True, layers_to_transform=None, layers_pattern=None, rank_pattern={}, alpha_pattern={}, megatron_config=None, megatron_core='megatron.core', loftq_config={}, use_dora=False, layer_replication=None)" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "peft_config" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "execution": { 
"iopub.execute_input": "2024-07-14T23:54:55.251220Z", "iopub.status.busy": "2024-07-14T23:54:55.250714Z", "iopub.status.idle": "2024-07-14T23:54:55.274109Z", "shell.execute_reply": "2024-07-14T23:54:55.273307Z", "shell.execute_reply.started": "2024-07-14T23:54:55.251193Z" }, "id": "lGdxTfFv0I-9", "outputId": "987f0000-697c-458c-942c-f19af4e74181" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "trainable params: 628,994 || all params: 67,584,004 || trainable%: 0.9307\n" ] } ], "source": [ "model = get_peft_model(model, peft_config)\n", "model.print_trainable_parameters()" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "execution": { "iopub.execute_input": "2024-07-14T23:54:55.275504Z", "iopub.status.busy": "2024-07-14T23:54:55.275234Z", "iopub.status.idle": "2024-07-14T23:54:55.279915Z", "shell.execute_reply": "2024-07-14T23:54:55.279014Z", "shell.execute_reply.started": "2024-07-14T23:54:55.275481Z" }, "id": "XrMxqk_90S6Q" }, "outputs": [], "source": [ "# hyperparameters\n", "lr = 1e-3\n", "batch_size = 4\n", "num_epochs = 10" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "execution": { "iopub.execute_input": "2024-07-14T23:54:55.281084Z", "iopub.status.busy": "2024-07-14T23:54:55.280845Z", "iopub.status.idle": "2024-07-14T23:54:55.350814Z", "shell.execute_reply": "2024-07-14T23:54:55.349830Z", "shell.execute_reply.started": "2024-07-14T23:54:55.281063Z" }, "id": "kYJDZJt50S9Q", "outputId": "c7e9d13f-d31d-4724-9638-7d7351bd3d7b" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/lib/python3.10/site-packages/transformers/training_args.py:1494: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. 
Use `eval_strategy` instead\n", " warnings.warn(\n" ] } ], "source": [ "# define training arguments\n", "training_args = TrainingArguments(\n", " output_dir=model_checkpoint + \"-lora-text-classification\",\n", " learning_rate=lr,\n", " per_device_train_batch_size=batch_size,\n", " per_device_eval_batch_size=batch_size,\n", " num_train_epochs=num_epochs,\n", " weight_decay=0.01,\n", " eval_strategy=\"epoch\", # replaces the deprecated `evaluation_strategy`\n", " save_strategy=\"epoch\",\n", " load_best_model_at_end=True,\n", ")" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 75 }, "execution": { "iopub.execute_input": "2024-07-14T23:54:55.352902Z", "iopub.status.busy": "2024-07-14T23:54:55.352113Z", "iopub.status.idle": "2024-07-15T00:01:58.470190Z", "shell.execute_reply": "2024-07-15T00:01:58.469378Z", "shell.execute_reply.started": "2024-07-14T23:54:55.352863Z" }, "id": "0V25wLEN0fTk", "outputId": "830beb8b-cf8d-4b0e-f759-ffd32e3836be" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m The `run_name` is currently set to the same value as `TrainingArguments.output_dir`. If this was not intended, please specify a different run name by setting the `TrainingArguments.run_name` parameter.\n", "\u001b[34m\u001b[1mwandb\u001b[0m: Logging into wandb.ai. 
(Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)\n", "\u001b[34m\u001b[1mwandb\u001b[0m: You can find your API key in your browser here: https://wandb.ai/authorize\n", "\u001b[34m\u001b[1mwandb\u001b[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:" ] }, { "name": "stdin", "output_type": "stream", "text": [ " ········································\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc\n" ] }, { "data": { "text/html": [ "Tracking run with wandb version 0.17.4" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Run data is saved locally in /kaggle/working/wandb/run-20240714_235539-ypukib7l" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Syncing run distilbert-base-uncased-lora-text-classification to Weights & Biases (docs)
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View project at https://wandb.ai/johanneseboigbe55-octave-analytics/huggingface" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View run at https://wandb.ai/johanneseboigbe55-octave-analytics/huggingface/runs/ypukib7l" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n", " warnings.warn('Was asked to gather along dimension 0, but all '\n" ] }, { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " [1250/1250 06:01, Epoch 10/10]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
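<table>\n",
" <tr><th>Epoch</th><th>Training Loss</th><th>Validation Loss</th><th>Accuracy</th></tr>\n",
" <tr><td>1</td><td>No log</td><td>0.259819</td><td>{'accuracy': 0.896}</td></tr>\n",
" <tr><td>2</td><td>No log</td><td>0.357995</td><td>{'accuracy': 0.888}</td></tr>\n",
" <tr><td>3</td><td>No log</td><td>0.403524</td><td>{'accuracy': 0.885}</td></tr>\n",
" <tr><td>4</td><td>0.262200</td><td>0.513255</td><td>{'accuracy': 0.881}</td></tr>\n",
" <tr><td>5</td><td>0.262200</td><td>0.614568</td><td>{'accuracy': 0.886}</td></tr>\n",
" <tr><td>6</td><td>0.262200</td><td>0.757634</td><td>{'accuracy': 0.885}</td></tr>\n",
" <tr><td>7</td><td>0.262200</td><td>0.749906</td><td>{'accuracy': 0.885}</td></tr>\n",
" <tr><td>8</td><td>0.045000</td><td>0.808175</td><td>{'accuracy': 0.891}</td></tr>\n",
" <tr><td>9</td><td>0.045000</td><td>0.804488</td><td>{'accuracy': 0.89}</td></tr>\n",
" <tr><td>10</td><td>0.045000</td><td>0.800608</td><td>{'accuracy': 0.893}</td></tr>\n",
" </table>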
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Trainer is attempting to log a value of \"{'accuracy': 0.896}\" of type for key \"eval/accuracy\" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.\n", "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n", " warnings.warn('Was asked to gather along dimension 0, but all '\n", "Trainer is attempting to log a value of \"{'accuracy': 0.888}\" of type for key \"eval/accuracy\" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.\n", "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n", " warnings.warn('Was asked to gather along dimension 0, but all '\n", "Trainer is attempting to log a value of \"{'accuracy': 0.885}\" of type for key \"eval/accuracy\" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.\n", "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n", " warnings.warn('Was asked to gather along dimension 0, but all '\n", "Trainer is attempting to log a value of \"{'accuracy': 0.881}\" of type for key \"eval/accuracy\" as a scalar. 
This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.\n", "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n", " warnings.warn('Was asked to gather along dimension 0, but all '\n", "Trainer is attempting to log a value of \"{'accuracy': 0.886}\" of type for key \"eval/accuracy\" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.\n", "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n", " warnings.warn('Was asked to gather along dimension 0, but all '\n", "Trainer is attempting to log a value of \"{'accuracy': 0.885}\" of type for key \"eval/accuracy\" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.\n", "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n", " warnings.warn('Was asked to gather along dimension 0, but all '\n", "Trainer is attempting to log a value of \"{'accuracy': 0.885}\" of type for key \"eval/accuracy\" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.\n", "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n", " warnings.warn('Was asked to gather along dimension 0, but all '\n", "Trainer is attempting to log a value of \"{'accuracy': 0.891}\" of type for key \"eval/accuracy\" as a scalar. 
This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.\n", "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n", " warnings.warn('Was asked to gather along dimension 0, but all '\n", "Trainer is attempting to log a value of \"{'accuracy': 0.89}\" of type for key \"eval/accuracy\" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.\n", "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n", " warnings.warn('Was asked to gather along dimension 0, but all '\n", "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n", " warnings.warn('Was asked to gather along dimension 0, but all '\n", "Trainer is attempting to log a value of \"{'accuracy': 0.893}\" of type for key \"eval/accuracy\" as a scalar. 
This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.\n" ] }, { "data": { "text/plain": [ "TrainOutput(global_step=1250, training_loss=0.12382552947998048, metrics={'train_runtime': 421.8436, 'train_samples_per_second': 23.705, 'train_steps_per_second': 2.963, 'total_flos': 1253694805157184.0, 'train_loss': 0.12382552947998048, 'epoch': 10.0})" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# create the Trainer object\n", "trainer = Trainer(\n", " model=model,\n", " args=training_args,\n", " train_dataset=tokenized_dataset[\"train\"],\n", " eval_dataset=tokenized_dataset[\"validation\"],\n", " tokenizer=tokenizer,\n", " data_collator=data_collator, # dynamically pads the examples in each batch to equal length\n", " compute_metrics=compute_metrics,\n", ")\n", "\n", "# train the model\n", "trainer.train()" ] }, { "cell_type": "markdown", "metadata": { "id": "5lbssHsP0prm" }, "source": [ "### Generate prediction" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "execution": { "iopub.execute_input": "2024-07-15T00:01:58.471810Z", "iopub.status.busy": "2024-07-15T00:01:58.471432Z", "iopub.status.idle": "2024-07-15T00:02:14.131264Z", "shell.execute_reply": "2024-07-15T00:02:14.130128Z", "shell.execute_reply.started": "2024-07-15T00:01:58.471777Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/lib/python3.10/pty.py:89: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", " pid, fd = os.forkpty()\n", "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. 
Disabling parallelism to avoid deadlocks...\n", "To disable this warning, you can either:\n", "\t- Avoid using `tokenizers` before the fork if possible\n", "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Found existing installation: torch 2.1.2\n", "Uninstalling torch-2.1.2:\n", " Successfully uninstalled torch-2.1.2\n" ] } ], "source": [ "!pip uninstall -y torch\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "execution": { "iopub.execute_input": "2024-07-15T00:02:14.135181Z", "iopub.status.busy": "2024-07-15T00:02:14.134858Z", "iopub.status.idle": "2024-07-15T00:06:34.130741Z", "shell.execute_reply": "2024-07-15T00:06:34.129542Z", "shell.execute_reply.started": "2024-07-15T00:02:14.135154Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. 
Disabling parallelism to avoid deadlocks...\n", "To disable this warning, you can either:\n", "\t- Avoid using `tokenizers` before the fork if possible\n", "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/nightly/gpu\n", "Collecting torch\n", " Downloading torch-2.3.1-cp310-cp310-manylinux1_x86_64.whl.metadata (26 kB)\n", "Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from torch) (3.13.1)\n", "Requirement already satisfied: typing-extensions>=4.8.0 in /opt/conda/lib/python3.10/site-packages (from torch) (4.9.0)\n", "Requirement already satisfied: sympy in /opt/conda/lib/python3.10/site-packages (from torch) (1.13.0)\n", "Requirement already satisfied: networkx in /opt/conda/lib/python3.10/site-packages (from torch) (3.2.1)\n", "Requirement already satisfied: jinja2 in /opt/conda/lib/python3.10/site-packages (from torch) (3.1.2)\n", "Requirement already satisfied: fsspec in /opt/conda/lib/python3.10/site-packages (from torch) (2024.5.0)\n", "Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)\n", " Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n", "Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)\n", " Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n", "Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)\n", " Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\n", "Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch)\n", " Downloading nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\n", "Collecting nvidia-cublas-cu12==12.1.3.1 (from torch)\n", " Downloading nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n", "Collecting nvidia-cufft-cu12==11.0.2.54 
(from torch)\n", " Downloading nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n", "Collecting nvidia-curand-cu12==10.3.2.106 (from torch)\n", " Downloading nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n", "Collecting nvidia-cusolver-cu12==11.4.5.107 (from torch)\n", " Downloading nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\n", "Collecting nvidia-cusparse-cu12==12.1.0.106 (from torch)\n", " Downloading nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\n", "Collecting nvidia-nccl-cu12==2.20.5 (from torch)\n", " Downloading nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl.metadata (1.8 kB)\n", "Collecting nvidia-nvtx-cu12==12.1.105 (from torch)\n", " Downloading nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.7 kB)\n", "Collecting triton==2.3.1 (from torch)\n", " Downloading triton-2.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.4 kB)\n", "Collecting nvidia-nvjitlink-cu12 (from nvidia-cusolver-cu12==11.4.5.107->torch)\n", " Downloading nvidia_nvjitlink_cu12-12.5.82-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.10/site-packages (from jinja2->torch) (2.1.3)\n", "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /opt/conda/lib/python3.10/site-packages (from sympy->torch) (1.3.0)\n", "Downloading torch-2.3.1-cp310-cp310-manylinux1_x86_64.whl (779.1 MB)\n", "Downloading nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)\n", "Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)\n", "Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)\n", "Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)\n", "Downloading nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)\n", "Downloading nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)\n", "Downloading nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)\n", "Downloading nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)\n", "Downloading nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (196.0 MB)\n", "Downloading nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl (176.2 MB)\n", "Downloading nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (99 kB)\n", "Downloading triton-2.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (168.1 MB)\n", "Downloading nvidia_nvjitlink_cu12-12.5.82-py3-none-manylinux2014_x86_64.whl (21.3 MB)\n", "Installing collected packages: triton, nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12, torch\n", "Successfully installed nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105
nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-8.9.2.26 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 nvidia-nccl-cu12-2.20.5 nvidia-nvjitlink-cu12-12.5.82 nvidia-nvtx-cu12-12.1.105 torch-2.3.1 triton-2.3.1\n" ] } ], "source": [ "!pip install torch --pre --extra-index-url https://download.pytorch.org/whl/nightly/gpu\n", "\n" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "execution": { "iopub.execute_input": "2024-07-15T00:06:34.133020Z", "iopub.status.busy": "2024-07-15T00:06:34.132587Z", "iopub.status.idle": "2024-07-15T00:06:34.142491Z", "shell.execute_reply": "2024-07-15T00:06:34.140853Z", "shell.execute_reply.started": "2024-07-15T00:06:34.132982Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MPS backend is not available\n" ] } ], "source": [ "import torch\n", "\n", "if torch.backends.mps.is_available():\n", "    print(\"MPS backend is available\")\n", "else:\n", "    print(\"MPS backend is not available\")\n" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "execution": { "iopub.execute_input": "2024-07-15T00:09:13.790787Z", "iopub.status.busy": "2024-07-15T00:09:13.790051Z", "iopub.status.idle": "2024-07-15T00:09:13.797350Z", "shell.execute_reply": "2024-07-15T00:09:13.796206Z", "shell.execute_reply.started": "2024-07-15T00:09:13.790754Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using device: cuda\n" ] } ], "source": [ "import torch\n", "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "print(f\"Using device: {device}\")\n" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "execution": { "iopub.execute_input": "2024-07-15T00:09:37.751116Z", "iopub.status.busy": "2024-07-15T00:09:37.750202Z", "iopub.status.idle": "2024-07-15T00:09:37.880663Z", "shell.execute_reply": "2024-07-15T00:09:37.879602Z", "shell.execute_reply.started":
"2024-07-15T00:09:37.751070Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Trained model predictions:\n", "--------------------------\n", "Text: It was good.\n", "Prediction: tensor([1], device='cuda:0')\n", "\n", "Text: Not a fan, don't recommed.\n", "Prediction: tensor([0], device='cuda:0')\n", "\n", "Text: Better than the first one.\n", "Prediction: tensor([1], device='cuda:0')\n", "\n", "Text: This is not worth watching even once.\n", "Prediction: tensor([1], device='cuda:0')\n", "\n", "Text: This one is a pass.\n", "Prediction: tensor([0], device='cuda:0')\n", "\n" ] } ], "source": [ "model.to(device)  # move the model to the appropriate device\n", "model.eval()  # switch off dropout so predictions are deterministic\n", "\n", "print(\"Trained model predictions:\")\n", "print(\"--------------------------\")\n", "\n", "with torch.no_grad():  # inference only, so skip gradient tracking\n", "    for text in text_list:\n", "        inputs = tokenizer.encode(text, return_tensors=\"pt\").to(device)  # move the inputs to the same device as the model\n", "        logits = model(inputs).logits\n", "        predictions = logits.argmax(dim=1)  # index of the highest-scoring class\n", "        print(f\"Text: {text}\\nPrediction: {predictions}\\n\")\n" ] }, { "cell_type": "markdown", "metadata": { "id": "Y0iH6E6-0z-r" }, "source": [ "### Optional: push model to hub" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "execution": { "iopub.execute_input": "2024-07-15T00:10:50.355689Z", "iopub.status.busy": "2024-07-15T00:10:50.354890Z", "iopub.status.idle": "2024-07-15T00:10:50.386851Z", "shell.execute_reply": "2024-07-15T00:10:50.385997Z", "shell.execute_reply.started": "2024-07-15T00:10:50.355623Z" }, "id": "2I3YvXHo06c6" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "fca9611543134ed49ba804e1472e7b45", "version_major": 2, "version_minor": 0 }, "text/plain": [ "VBox(children=(HTML(value='