{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "AjIEWbCDkanS" }, "source": [ "## Check GPU Availability" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "f4YxNfdGFfhu", "outputId": "2f8c24b0-59f6-466a-ac7a-d3cbdd4d3097" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Thu Apr 25 08:41:08 2024 \n", "+---------------------------------------------------------------------------------------+\n", "| NVIDIA-SMI 535.104.05 Driver Version: 535.104.05 CUDA Version: 12.2 |\n", "|-----------------------------------------+----------------------+----------------------+\n", "| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n", "| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n", "| | | MIG M. |\n", "|=========================================+======================+======================|\n", "| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |\n", "| N/A 68C P8 12W / 70W | 0MiB / 15360MiB | 0% Default |\n", "| | | N/A |\n", "+-----------------------------------------+----------------------+----------------------+\n", " \n", "+---------------------------------------------------------------------------------------+\n", "| Processes: |\n", "| GPU GI CI PID Type Process name GPU Memory |\n", "| ID ID Usage |\n", "|=======================================================================================|\n", "| No running processes found |\n", "+---------------------------------------------------------------------------------------+\n" ] } ], "source": [ "!nvidia-smi" ] }, { "cell_type": "markdown", "metadata": { "id": "DUwSLgpOkiFy" }, "source": [ "## Install required libraries" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "FmkS5lxikB5p", "outputId": "b29dfa0b-2189-43a6-95d2-22848c816e88" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m245.2/245.2 kB\u001b[0m \u001b[31m4.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m297.6/297.6 kB\u001b[0m \u001b[31m20.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m72.9/72.9 kB\u001b[0m \u001b[31m7.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m542.0/542.0 kB\u001b[0m \u001b[31m29.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m102.0/102.0 kB\u001b[0m \u001b[31m7.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m15.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m22.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m15.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m388.9/388.9 kB\u001b[0m \u001b[31m30.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K 
\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m486.2/486.2 kB\u001b[0m \u001b[31m10.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m92.6/92.6 MB\u001b[0m \u001b[31m9.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m44.6/44.6 kB\u001b[0m \u001b[31m5.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.1/2.1 MB\u001b[0m \u001b[31m65.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m110.5/110.5 kB\u001b[0m \u001b[31m13.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m207.3/207.3 kB\u001b[0m \u001b[31m23.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m267.1/267.1 kB\u001b[0m \u001b[31m29.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m62.7/62.7 kB\u001b[0m \u001b[31m8.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m16.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.3/134.3 kB\u001b[0m \u001b[31m15.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h Building wheel for pathtools (setup.py) ... \u001b[?25l\u001b[?25hdone\n", " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", " Building wheel for transformers (pyproject.toml) ... 
\u001b[?25l\u001b[?25hdone\n" ] } ], "source": [ "!pip install trl accelerate peft==0.4.0 -Uqqq\n", "!pip install datasets==2.13.1 bitsandbytes==0.41.3 einops==0.7.0 wandb==0.15.8 -Uqqq\n", "!pip install git+https://github.com/huggingface/transformers -Uqqq" ] }, { "cell_type": "markdown", "metadata": { "id": "toxLYDhem1xE" }, "source": [ "## Importing libraries" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "cfveWI2gdztX" }, "outputs": [], "source": [ "import torch\n", "from datasets import load_dataset\n", "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments, GenerationConfig\n", "from peft import LoraConfig, get_peft_model, PeftConfig, PeftModel, prepare_model_for_kbit_training\n", "from trl import SFTTrainer\n", "import warnings\n", "warnings.filterwarnings(\"ignore\")" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 113, "referenced_widgets": [ "e0df7902801c4dcdad3bbed886c4dfce", "401ffa071b4c419abda0505b23ca4ce6", "ea2178eb7a1d4467af8f613283f9730c", "208f4d2fc96a40138e286583b7a56602", "1d1a63df50bb47ffb5f36c918ae00978", "1dc8084935274931bc80343f957758f8", "97d8176b7fd0401c9385ae685397aaa7", "a2ed6172a5614f99a132cd9f36857838", "f83f461197f148c1bc0a67e6213d0230", "53c3dd1163ed4c0daf0adbc387eec7bd", "55135d394c96428b90339ec521740b8a", "60be617ce4674fba90dbb038ddedbfda", "dd38503bad7542a999bc888a17db7f44", "d011d3b13a8f4343bd41e67564476135", "c4a4012f1c1147e3bc4c2377d1295ff5", "07e6321ecec743379f8ea789667e1f82", "2fe9b18591d84d63a9529363c75833cc", "cae9a6f61bb9481392c7a4c3d6d34423", "d004b192f64e45518a048110123ff2d8", "d56cb5cb29834580ba7b2480c0246d3a", "b373249aa3a24cc38b6b90b25dd722f6", "db385b8020e040e4bac2ceb1f3671888", "ab590350e0d346edba69d11ae80606e2", "135063f3c04e4118b949fdbb4be24711", "d66baaae2d6849d99fc6e47bcd47bd28", "05093394bc4540dea9f5b834400b89b0", "362878cff74e49b9adb7ddd92640940d", "1b92fc16173c401d99de9025e06157b2", "347bb878a02048ceab99a8d3a4545aa4" ] }, "id": "lfznv6tCmS9t", "outputId": "72aa8a9f-07e9-4b5a-bfaa-8db98c795d8c" }, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "VBox(children=(HTML(value='
"Step | Training Loss\n",
"-----|--------------\n",
"  10 | 1.670900\n",
"  20 | 1.314300\n",
"  30 | 0.964100\n",
"  40 | 0.533800\n",
"  50 | 0.219800\n",
"  60 | 0.081400\n",
"  70 | 0.047600\n"
]
},
"metadata": {}
}
],
"source": [
"peft_model.config.use_cache = False\n",
"trainer.train()"
]
},
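{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before pushing to the Hub, the LoRA adapter can also be checkpointed locally. A minimal sketch; the output directory name below is illustrative, not from the original run:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Save only the adapter weights (a few hundred MB), not the full 7B base model.\n",
"# \"falcon-7b-therapx-adapter\" is an illustrative directory name.\n",
"peft_model.save_pretrained(\"falcon-7b-therapx-adapter\")"
]
},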
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "hBtLstduL6LI"
},
"outputs": [],
"source": [
"trainer.push_to_hub(\"therapx\")"
]
},
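{
"cell_type": "markdown",
"metadata": {},
"source": [
"A single T4 has roughly 15 GiB of memory, and the inference section below loads two 7B models side by side. One option (an optional sketch, assuming no further training is needed in this session) is to release the training objects first:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import gc\n",
"\n",
"# Drop references to the trainer and trained model, then return freed blocks to CUDA.\n",
"del trainer\n",
"del peft_model\n",
"gc.collect()\n",
"torch.cuda.empty_cache()"
]
},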
{
"cell_type": "markdown",
"metadata": {
"id": "Gsxub1fgUMxt"
},
"source": [
"## Inference Pipeline"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "VOKiDZlK6DHO"
},
"outputs": [],
"source": [
"# Loading original model\n",
"model_name = \"ybelkada/falcon-7b-sharded-bf16\"\n",
"\n",
"bnb_config = BitsAndBytesConfig(\n",
" load_in_4bit=True,\n",
" bnb_4bit_quant_type=\"nf4\",\n",
" bnb_4bit_use_double_quant=True,\n",
" bnb_4bit_compute_dtype=torch.float16,\n",
")\n",
"\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
" model_name,\n",
" quantization_config=bnb_config,\n",
" device_map=\"auto\",\n",
" trust_remote_code=True,\n",
")\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
"tokenizer.pad_token = tokenizer.eos_token"
]
},
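{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick smoke test of the base model before comparing it with the fine-tuned adapter, a short generation can be run. The prompt template and sampling settings below are illustrative assumptions, not taken from the original run:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative prompt; the <HUMAN>/<ASSISTANT> template is an assumption here.\n",
"prompt = \"<HUMAN>: How can I manage everyday stress?\\n<ASSISTANT>: \"\n",
"inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
"\n",
"generation_config = GenerationConfig(\n",
"    max_new_tokens=128,\n",
"    do_sample=True,\n",
"    temperature=0.7,\n",
"    top_p=0.9,\n",
"    pad_token_id=tokenizer.eos_token_id,\n",
")\n",
"\n",
"with torch.inference_mode():\n",
"    outputs = model.generate(**inputs, generation_config=generation_config)\n",
"print(tokenizer.decode(outputs[0], skip_special_tokens=True))"
]
},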
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qnP9C3zZ5pt1"
},
"outputs": [],
"source": [
"# Loading PEFT model\n",
"PEFT_MODEL = \"heliosbrahma/falcon-7b-sharded-bf16-finetuned-mental-health-conversational\"\n",
"\n",
"config = PeftConfig.from_pretrained(PEFT_MODEL)\n",
"peft_base_model = AutoModelForCausalLM.from_pretrained(\n",
" config.base_model_name_or_path,\n",
" return_dict=True,\n",
" quantization_config=bnb_config,\n",
" device_map=\"auto\",\n",
" trust_remote_code=True,\n",
")\n",
"\n",
"peft_model = PeftModel.from_pretrained(peft_base_model, PEFT_MODEL)\n",
"\n",
"peft_tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)\n",
"peft_tokenizer.pad_token = peft_tokenizer.eos_token"
]
},
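{
"cell_type": "markdown",
"metadata": {},
"source": [
"Both models are used purely for generation from here on, so an optional but tidy step is to put them in eval mode:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Inference only: disable dropout and other train-time behaviour.\n",
"model.eval()\n",
"peft_model.eval();  # trailing semicolon suppresses the long module repr in the notebook"
]
},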
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "TfvBS3yl54lV"
},
"outputs": [],
"source": [
"# Function to generate responses from both original model and PEFT model and compare their answers.\n",
"def generate_answer(query):\n",
" system_prompt = \"\"\"Answer the following question truthfully.\n",
" If you don't know the answer, respond 'Sorry, I don't know the answer to this question.'.\n",
" If the question is too complex, respond 'Kindly, consult a psychiatrist for further queries.'.\"\"\"\n",
"\n",
" user_prompt = f\"\"\"