{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "3890292a-c99e-4367-955d-5883b93dba36", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0mRequirement already satisfied: flash-attn in /opt/conda/lib/python3.10/site-packages (2.5.9.post1)\n", "Requirement already satisfied: torch in /opt/conda/lib/python3.10/site-packages (from flash-attn) (2.2.0)\n", "Requirement already satisfied: einops in /opt/conda/lib/python3.10/site-packages (from flash-attn) (0.8.0)\n", "Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (3.13.1)\n", "Requirement already satisfied: typing-extensions>=4.8.0 in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (4.9.0)\n", "Requirement already satisfied: sympy in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (1.12)\n", "Requirement already satisfied: networkx in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (3.1)\n", "Requirement already satisfied: jinja2 in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (3.1.2)\n", "Requirement already satisfied: fsspec in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (2023.12.2)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.10/site-packages (from jinja2->torch->flash-attn) (2.1.3)\n", "Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.10/site-packages (from sympy->torch->flash-attn) (1.3.0)\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0m" ] } ], "source": [ "!pip install -q peft transformers datasets huggingface_hub\n", "!pip install flash-attn --no-build-isolation" ] }, { "cell_type": "code", "execution_count": 2, "id": "f1cc378f-afb6-441f-a4c6-2ec427b4cd4b", "metadata": {}, "outputs": [], "source": [ "from transformers import AutoModelForCausalLM, AutoTokenizer, default_data_collator, get_linear_schedule_with_warmup\n", "from peft import get_peft_config, get_peft_model, PromptTuningInit, PromptTuningConfig, TaskType, PeftType\n", "import torch\n", "from datasets import load_dataset\n", "import os\n", "from torch.utils.data import DataLoader\n", "from tqdm import tqdm\n", "from huggingface_hub import notebook_login\n", "from huggingface_hub import HfApi" ] }, { "cell_type": "code", "execution_count": null, "id": "e4ab50d7-a4c9-4246-acd8-8875b87fe0da", "metadata": {}, "outputs": [], "source": [ "notebook_login()" ] }, { "cell_type": "code", "execution_count": 25, "id": "8a1cb1f9-b89d-4cac-a595-44e1e0ef85b2", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "CommitInfo(commit_url='https://huggingface.co/Granther/prompt-tuned-phi3/commit/7ea57da9a4eccf3794c58bb4317df1c97a0fe2c8', commit_message='Upload prompt_tune_phi3.ipynb with huggingface_hub', commit_description='', oid='7ea57da9a4eccf3794c58bb4317df1c97a0fe2c8', pr_url=None, pr_revision=None, pr_num=None)" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "api = HfApi()\n", "api.upload_file(path_or_fileobj='prompt_tune_phi3.ipynb',\n", " path_in_repo='prompt_tune_phi3.ipynb',\n", " repo_id='Granther/prompt-tuned-phi3',\n", " repo_type='model'\n", " )" ] }, { "cell_type": "code", "execution_count": 54, "id": "6cad1e5c-038f-4e75-8c3f-8ce0a43713a4", "metadata": {}, "outputs": [], "source": [ "device = 'cuda'\n", "\n", "model_id = 'microsoft/Phi-3-mini-128k-instruct'\n", "\n", "peft_conf = PromptTuningConfig(\n", " peft_type=PeftType.PROMPT_TUNING, # what kind of peft\n", " task_type=TaskType.CAUSAL_LM, # config task\n", " prompt_tuning_init=PromptTuningInit.TEXT, # Set to 'TEXT' to use prompt_tuning_init_text\n", " num_virtual_tokens=8, # x times the number of hidden transformer layers\n", " prompt_tuning_init_text=\"Classify if the tweet is a complaint or not:\",\n", " tokenizer_name_or_path=model_id\n", ")\n", "\n", "dataset_name = \"twitter_complaints\"\n", "checkpoint_name = f\"{dataset_name}_{model_id}_{peft_conf.peft_type}_{peft_conf.task_type}_v1.pt\".replace(\n", " \"/\", \"_\"\n", ")\n", "\n", "text_col = 'Tweet text'\n", "label_col = 'text_label'\n", "max_len = 64\n", "lr = 3e-2\n", "epochs = 5\n", "batch_size = 8" ] }, { "cell_type": "code", "execution_count": 28, "id": "6f677839-ef23-428a-bcfe-f596590804ca", "metadata": {}, "outputs": [], "source": [ "dataset = load_dataset('ought/raft', dataset_name)" ] }, { "cell_type": "code", "execution_count": 30, "id": "c0c05613-7941-4959-ada9-49ed1093bec4", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['Unlabeled', 'complaint', 'no complaint']" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset['train'].features['Label'].names\n", "#>>> ['Unlabeled', 'complaint', 'no complaint']" ] }, { "cell_type": "code", "execution_count": 32, "id": "14e2bc8b-b4e3-49c9-ae2b-5946e412caa5", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "11da1eb81527428a95c41816f5bf459f", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map (num_proc=10): 0%| | 0/3399 [00:00>> -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000\n", " # Pad the beginning of the sequence with n -100s (ignore tokens)\n", " model_inputs[\"attention_mask\"][i] = [1] * len(model_inputs[\"input_ids\"][i])\n", "\n", " for i in range(batch_size):\n", " sample_input_ids = model_inputs[\"input_ids\"][i]\n", " label_input_ids = labels[\"input_ids\"][i]\n", " model_inputs[\"input_ids\"][i] = [tokenizer.pad_token_id] * (target_max_len - len(sample_input_ids)) + sample_input_ids\n", " model_inputs[\"attention_mask\"][i] = [0] * (target_max_len - len(sample_input_ids)) + model_inputs[\"attention_mask\"][i]\n", " labels[\"input_ids\"][i] = [-100] * (target_max_len - len(sample_input_ids)) + label_input_ids\n", " model_inputs[\"input_ids\"][i] = torch.tensor(model_inputs[\"input_ids\"][i][:target_max_len])\n", " model_inputs[\"attention_mask\"][i] = torch.tensor(model_inputs[\"attention_mask\"][i][:target_max_len])\n", " labels[\"input_ids\"][i] = torch.tensor(labels[\"input_ids\"][i][:target_max_len])\n", " model_inputs[\"labels\"] = labels[\"input_ids\"]\n", " return model_inputs" ] }, { "cell_type": "code", "execution_count": 33, "id": "72ddca5f-7bce-4342-9414-9dd9d41d9dec", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "05958c1cf67d413b9085622ace0cb799", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Preprocessing dataset (num_proc=10): 0%| | 0/50 [00:00