{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "28e4c4d1-a73f-437b-a1bd-c2cc3874924a" }, "source": [ "# 강의 11주차: llama2-food-order-understanding\n", "\n", "1. llama-2-7b-chat-hf 를 주문 문장 이해에 미세 튜닝\n", "\n", "- food-order-understanding-small-3200.json (학습)\n", "- food-order-understanding-small-800.json (검증)\n", "\n", "\n", "종속적인 필요 내용\n", "- huggingface 계정 설정 및 llama-2 사용 승인\n", "- 로깅을 위한 wandb" ], "id": "28e4c4d1-a73f-437b-a1bd-c2cc3874924a" }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "nDZe_wqKU6J3", "outputId": "39634a40-9521-42e0-d048-e3f78b7a8306" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.35.2)\n", "Collecting peft\n", " Downloading peft-0.7.0-py3-none-any.whl (168 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m168.3/168.3 kB\u001b[0m \u001b[31m4.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting accelerate\n", " Downloading accelerate-0.25.0-py3-none-any.whl (265 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m265.7/265.7 kB\u001b[0m \u001b[31m10.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting optimum\n", " Downloading optimum-1.15.0-py3-none-any.whl (400 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m400.9/400.9 kB\u001b[0m \u001b[31m13.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting bitsandbytes\n", " Downloading bitsandbytes-0.41.3.post1-py3-none-any.whl (92.6 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m92.6/92.6 MB\u001b[0m \u001b[31m11.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting trl\n", " Downloading trl-0.7.4-py3-none-any.whl (133 kB)\n", "\u001b[2K 
\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m133.9/133.9 kB\u001b[0m \u001b[31m21.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting wandb\n", " Downloading wandb-0.16.1-py3-none-any.whl (2.1 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.1/2.1 MB\u001b[0m \u001b[31m88.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.13.1)\n", "Requirement already satisfied: huggingface-hub<1.0,>=0.16.4 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.19.4)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.23.5)\n", "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (23.2)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.1)\n", "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2023.6.3)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.31.0)\n", "Requirement already satisfied: tokenizers<0.19,>=0.14 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.15.0)\n", "Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.4.1)\n", "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.66.1)\n", "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from peft) (5.9.5)\n", "Requirement already satisfied: torch>=1.13.0 in /usr/local/lib/python3.10/dist-packages (from peft) (2.1.0+cu118)\n", "Collecting coloredlogs (from optimum)\n", " Downloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)\n", 
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m46.0/46.0 kB\u001b[0m \u001b[31m6.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from optimum) (1.12)\n", "Collecting datasets (from optimum)\n", " Downloading datasets-2.15.0-py3-none-any.whl (521 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m521.2/521.2 kB\u001b[0m \u001b[31m54.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting tyro>=0.5.11 (from trl)\n", " Downloading tyro-0.6.0-py3-none-any.whl (100 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m100.9/100.9 kB\u001b[0m \u001b[31m15.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: Click!=8.0.0,>=7.1 in /usr/local/lib/python3.10/dist-packages (from wandb) (8.1.7)\n", "Collecting GitPython!=3.1.29,>=1.0.0 (from wandb)\n", " Downloading GitPython-3.1.40-py3-none-any.whl (190 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m190.6/190.6 kB\u001b[0m \u001b[31m27.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting sentry-sdk>=1.0.0 (from wandb)\n", " Downloading sentry_sdk-1.38.0-py2.py3-none-any.whl (252 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m252.8/252.8 kB\u001b[0m \u001b[31m29.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting docker-pycreds>=0.4.0 (from wandb)\n", " Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)\n", "Collecting setproctitle (from wandb)\n", " Downloading setproctitle-1.3.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (30 kB)\n", "Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from wandb) (67.7.2)\n", "Requirement already satisfied: 
appdirs>=1.4.3 in /usr/local/lib/python3.10/dist-packages (from wandb) (1.4.4)\n", "Requirement already satisfied: protobuf!=4.21.0,<5,>=3.19.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (3.20.3)\n", "Requirement already satisfied: six>=1.4.0 in /usr/local/lib/python3.10/dist-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\n", "Collecting gitdb<5,>=4.0.1 (from GitPython!=3.1.29,>=1.0.0->wandb)\n", " Downloading gitdb-4.0.11-py3-none-any.whl (62 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m62.7/62.7 kB\u001b[0m \u001b[31m9.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.16.4->transformers) (2023.6.0)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.16.4->transformers) (4.5.0)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.6)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.7)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2023.11.17)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (3.2.1)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (3.1.2)\n", "Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (2.1.0)\n", "Collecting sentencepiece!=0.1.92,>=0.1.91 (from transformers)\n", " Downloading 
sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m77.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting docstring-parser>=0.14.1 (from tyro>=0.5.11->trl)\n", " Downloading docstring_parser-0.15-py3-none-any.whl (36 kB)\n", "Requirement already satisfied: rich>=11.1.0 in /usr/local/lib/python3.10/dist-packages (from tyro>=0.5.11->trl) (13.7.0)\n", "Collecting shtab>=1.5.6 (from tyro>=0.5.11->trl)\n", " Downloading shtab-1.6.5-py3-none-any.whl (13 kB)\n", "Collecting humanfriendly>=9.1 (from coloredlogs->optimum)\n", " Downloading humanfriendly-10.0-py2.py3-none-any.whl (86 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m86.8/86.8 kB\u001b[0m \u001b[31m13.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets->optimum) (9.0.0)\n", "Collecting pyarrow-hotfix (from datasets->optimum)\n", " Downloading pyarrow_hotfix-0.6-py3-none-any.whl (7.9 kB)\n", "Collecting dill<0.3.8,>=0.3.0 (from datasets->optimum)\n", " Downloading dill-0.3.7-py3-none-any.whl (115 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m115.3/115.3 kB\u001b[0m \u001b[31m19.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets->optimum) (1.5.3)\n", "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets->optimum) (3.4.1)\n", "Collecting multiprocess (from datasets->optimum)\n", " Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m20.2 MB/s\u001b[0m eta 
\u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets->optimum) (3.9.1)\n", "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->optimum) (1.3.0)\n", "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->optimum) (23.1.0)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->optimum) (6.0.4)\n", "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->optimum) (1.9.3)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->optimum) (1.4.0)\n", "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->optimum) (1.3.1)\n", "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->optimum) (4.0.3)\n", "Collecting smmap<6,>=3.0.1 (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb)\n", " Downloading smmap-5.0.1-py3-none-any.whl (24 kB)\n", "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich>=11.1.0->tyro>=0.5.11->trl) (3.0.0)\n", "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich>=11.1.0->tyro>=0.5.11->trl) (2.16.1)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.13.0->peft) (2.1.3)\n", "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets->optimum) (2.8.2)\n", "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets->optimum) (2023.3.post1)\n", "Requirement 
already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich>=11.1.0->tyro>=0.5.11->trl) (0.1.2)\n", "Installing collected packages: sentencepiece, bitsandbytes, smmap, shtab, setproctitle, sentry-sdk, pyarrow-hotfix, humanfriendly, docstring-parser, docker-pycreds, dill, multiprocess, gitdb, coloredlogs, tyro, GitPython, accelerate, wandb, datasets, trl, peft, optimum\n", "Successfully installed GitPython-3.1.40 accelerate-0.25.0 bitsandbytes-0.41.3.post1 coloredlogs-15.0.1 datasets-2.15.0 dill-0.3.7 docker-pycreds-0.4.0 docstring-parser-0.15 gitdb-4.0.11 humanfriendly-10.0 multiprocess-0.70.15 optimum-1.15.0 peft-0.7.0 pyarrow-hotfix-0.6 sentencepiece-0.1.99 sentry-sdk-1.38.0 setproctitle-1.3.3 shtab-1.6.5 smmap-5.0.1 trl-0.7.4 tyro-0.6.0 wandb-0.16.1\n" ] } ], "source": [ "pip install transformers peft accelerate optimum bitsandbytes trl wandb" ], "id": "nDZe_wqKU6J3" }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "51eb00d7-2928-41ad-9ae9-7f0da7d64d6d", "outputId": "e8550fbb-e4fc-4b70-fb40-1babe806f850" }, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "/usr/local/lib/python3.10/dist-packages/trl/trainer/ppo_config.py:141: UserWarning: The `optimize_cuda_cache` arguement will be deprecated soon, please use `optimize_device_cache` instead.\n", " warnings.warn(\n" ] } ], "source": [ "import os\n", "from dataclasses import dataclass, field\n", "from typing import Optional\n", "import re\n", "\n", "import torch\n", "import tyro\n", "from accelerate import Accelerator\n", "from datasets import load_dataset, Dataset\n", "from peft import AutoPeftModelForCausalLM, LoraConfig\n", "from tqdm import tqdm\n", "from transformers import (\n", " AutoModelForCausalLM,\n", " AutoTokenizer,\n", " BitsAndBytesConfig,\n", " TrainingArguments,\n", ")\n", "\n", "from trl import SFTTrainer\n", "\n", "from trl.trainer import 
ConstantLengthDataset" ], "id": "51eb00d7-2928-41ad-9ae9-7f0da7d64d6d" }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 145, "referenced_widgets": [ "99532ff414024109b47d52f85d12756c", "c9c40f2edc544032bd50d8020f8aa5b4", "c2bdbf8108174891a74270ba291a3519", "adac7366f35c47bd8f608d8a3986cbe7", "a0cc6856d4e947cb90001e1679abe061", "e4c56fa7facb42cdb400172737f5b66b", "939dae7eac1143c0851249ab6b239eaa", "a9c5cc0353944b2eb8649f08c51ddd26", "7b8e04ff45a34ad38e50c3f71acf31b8", "31bcd6aba0be45fd8a5d69a1cee4d1e2", "c2ed791223104173b44419665b6b611e", "ae10fc8777724a5781ad4665f868cb87", "b51df9eb8d70473b9eb828f988bc9430", "34d6b8f3945c46e7830f54f5ea30a2ae", "c558553d832347d2aef2de23bdea73ac", "3843b559ef054100bc6e4d0666ba70fe", "6d34bb28660d40b3b58a454d85e9684b", "9b2fff53449641a88210178ced5ffa3f", "1261f868f75342e394503bbd9ca8305d", "bf4a5a4f9a294a679f6725d72520d0f6", "e8983362fd0046a99608eed2d0f42f92", "9f0d306f3bdc4cc9a4cee4dc8f10ee5b", "f971e68bef0440918129e8ed6d882bb7", "ce73f650db034858890e20452f74ddaa", "ca4c88b41aea474ca29fe7b25774dae6", "798c89074789407e86f87578a3ec0cc7", "abf493394d644ed18b57066d214b13d0", "4a52e47448584335b9d143b0a96b9095", "459bd9737cbb43e898d7fa9e2f341417", "c5fa2300fc664eb0a36526b87ff53055", "2604feda7a204e6e8e17dd0ccd8649b2", "db4161ded47c4c09ab50c0a431a2831a" ] }, "id": "tX7gYxZaVhYL", "outputId": "ff182a8d-bd90-47cc-d957-2ecde7eeb44a" }, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "VBox(children=(HTML(value='
/content/wandb/run-20231211_111154-7fn0vhfg
"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"Step | \n", "Training Loss | \n", "
---|---|
100 | \n", "0.805100 | \n", "
200 | \n", "0.371400 | \n", "
300 | \n", "0.337600 | \n", "
400 | \n", "0.328100 | \n", "
500 | \n", "0.317000 | \n", "
600 | \n", "0.314700 | \n", "
700 | \n", "0.312500 | \n", "
800 | \n", "0.296200 | \n", "
"
]
},
"metadata": {}
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"TrainOutput(global_step=852, training_loss=0.3792879995605755, metrics={'train_runtime': 3082.7229, 'train_samples_per_second': 1.038, 'train_steps_per_second': 0.519, 'total_flos': 3.469201579494605e+16, 'train_loss': 0.3792879995605755, 'epoch': 0.53})"
]
},
"metadata": {},
"execution_count": 26
}
],
"source": [
"trainer.train()"
],
"id": "14019fa9-0c6f-4729-ac99-0d407af375b8"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3Y4FQSyRghQt",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 36
},
"outputId": "09ff0896-0972-4f07-8a8c-89786f3c7149"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"'/gdrive/MyDrive/nlp/lora-llama-2-7b-food-order-understanding'"
],
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "string"
}
},
"metadata": {},
"execution_count": 27
}
],
"source": [
"script_args.training_args.output_dir"
],
"id": "3Y4FQSyRghQt"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "49f05450-da2a-4edd-9db2-63836a0ec73a"
},
"outputs": [],
"source": [
"trainer.save_model(script_args.training_args.output_dir)"
],
"id": "49f05450-da2a-4edd-9db2-63836a0ec73a"
},
{
"cell_type": "markdown",
"metadata": {
"id": "652f307e-e1d7-43ae-b083-dba2d94c2296"
},
"source": [
"# 추론 테스트"
],
"id": "652f307e-e1d7-43ae-b083-dba2d94c2296"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ea8a1fea-7499-4386-9dea-0509110f61af"
},
"outputs": [],
"source": [
"from transformers import pipeline, TextStreamer"
],
"id": "ea8a1fea-7499-4386-9dea-0509110f61af"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "52626888-1f6e-46b6-a8dd-836622149ff5"
},
"outputs": [],
"source": [
"instruction_prompt_template = \"\"\"###System;다음은 매장에서 고객이 음식을 주문하는 주문 문장이다. 이를 분석하여 음식명, 옵션명, 수량을 추출하여 고객의 의도를 이해하고자 한다.\n",
"분석 결과를 완성해주기 바란다.\n",
"\n",
"### 주문 문장: {0} ### 분석 결과:\n",
"\"\"\"\n",
"\n",
"prompt_template = \"\"\"###System;{System}\n",
"###User;{User}\n",
"###Midm;\"\"\"\n",
"\n",
"default_system_msg = (\n",
" \"너는 먼저 사용자가 입력한 주문 문장을 분석하는 에이전트이다. 이로부터 주문을 구성하는 음식명, 옵션명, 수량을 차례대로 추출해야 한다.\"\n",
")"
],
"id": "52626888-1f6e-46b6-a8dd-836622149ff5"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "46e844fa-8f63-4359-a4fb-df66e8171796"
},
"outputs": [],
"source": [
"evaluation_queries = [\n",
" \"오늘은 비가오니깐 이거 먹자. 삼선짬뽕 곱배기 하나하구요, 사천 탕수육 중짜 한그릇 주세요.\",\n",
" \"아이스아메리카노 톨사이즈 한잔 하고요. 딸기스무디 한잔 주세요. 또, 콜드브루라떼 하나요.\",\n",
" \"참이슬 한병, 코카콜라 1.5리터 한병, 테슬라 한병이요.\",\n",
" \"꼬막무침 1인분하고요, 닭도리탕 중자 주세요. 그리고 소주도 한병 주세요.\",\n",
" \"김치찌개 3인분하고요, 계란말이 주세요.\",\n",
" \"불고기버거세트 1개하고요 감자튀김 추가해주세요.\",\n",
" \"불닭볶음면 1개랑 사리곰탕면 2개 주세요.\",\n",
" \"카페라떼 아이스 샷추가 한잔하구요. 스콘 하나 주세요\",\n",
" \"여기요 춘천닭갈비 4인분하고요. 라면사리 추가하겠습니다. 콜라 300ml 두캔주세요.\",\n",
" \"있잖아요 조랭이떡국 3인분하고요. 떡만두 한세트 주세요.\",\n",
" \"깐풍탕수 2인분 하고요 콜라 1.5리터 한병이요.\",\n",
"]"
],
"id": "46e844fa-8f63-4359-a4fb-df66e8171796"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1919cf1f-482e-4185-9d06-e3cea1918416"
},
"outputs": [],
"source": [
"def wrapper_generate(model, input_prompt, max_new_tokens=512):\n",
"    \"\"\"Generate a completion for ``input_prompt`` with ``model`` and return only the new text.\n",
"\n",
"    Args:\n",
"        model: a causal-LM with a ``.generate`` method (expects CUDA to be available).\n",
"        input_prompt: fully formatted prompt string.\n",
"        max_new_tokens: generation cap. BUG FIX: the original passed ``float('inf')``,\n",
"            which transformers' GenerationConfig does not accept -- it must be a finite int.\n",
"\n",
"    NOTE(review): relies on the module-level ``tokenizer`` defined earlier in the notebook.\n",
"    \"\"\"\n",
"    data = tokenizer(input_prompt, return_tensors=\"pt\")\n",
"    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)\n",
"    # Drop the final input token -- presumably a trailing special token appended by the\n",
"    # tokenizer. TODO(review): confirm this matches the training-time formatting.\n",
"    input_ids = data.input_ids[..., :-1]\n",
"    with torch.no_grad():\n",
"        pred = model.generate(\n",
"            input_ids=input_ids.cuda(),\n",
"            streamer=streamer,\n",
"            use_cache=True,\n",
"            max_new_tokens=max_new_tokens,\n",
"            # FIX: without do_sample=True, generation is greedy and `temperature`\n",
"            # is silently ignored (transformers emits a warning). Enable sampling\n",
"            # so the configured temperature actually takes effect.\n",
"            do_sample=True,\n",
"            temperature=0.5,\n",
"        )\n",
"    decoded_text = tokenizer.batch_decode(pred, skip_special_tokens=True)\n",
"    # Strip the echoed prompt so only the model's answer is returned.\n",
"    return decoded_text[0][len(input_prompt):]"
],
"id": "1919cf1f-482e-4185-9d06-e3cea1918416"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "eaac1f6f-c823-4488-8edb-2f931ddf0daa",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "13193a97-1b7b-4aed-90a1-b6c0e5a0a15b"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
";- 분석 결과 0: 음식명:삼선짬뽕,옵션:곱배기,수량:하나\n",
"- 분석 결과 1: 음식명:사천 탕수육,옵션:중짜,수량:한그릇\n",
";- 분석 결과 0: 음식명:아이스아메리카노,옵션:톨,수량:한잔\n",
"- 분석 결과 1: 음식명:딸기스무디,수량:한잔\n",
"- 분석 결과 2: 음식명:콜드브루라떼,수량:하나\n",
";- 분석 결과 0: 음식명:참이슬, 수량:한병\n",
"- 분석 결과 1: 음식명:코카콜라, 옵션:1.5리터, 수량:한병\n",
"- 분석 결과 2: 음식명:테슬라, 수량:한병\n",
";- 분석 결과 0: 음식명:꼬막무침,수량:1인분\n",
"- 분석 결과 1: 음식명:닭도리탕,옵션:중자\n",
"- 분석 결과 2: 음식명:소주,수량:한병\n",
";- 분석 결과 0: 음식명:김치찌개,수량:3인분\n",
"- 분석 결과 1: 음식명:계란말이\n",
";- 분석 결과 0: 음식명:불고기버거세트,수량:1개\n",
"- 분석 결과 1: 음식명:감자튀김\n",
";- 분석 결과 0: 음식명:불닭볶음면,수량:1개\n",
"- 분석 결과 1: 음식명:사리곰탕면,수량:2개\n",
";- 분석 결과 0: 음식명:카페라떼, 옵션:아이스 샷추가, 수량:한잔\n",
"- 분석 결과 1: 음식명:스콘, 수량:하나\n",
";- 분석 결과 0: 음식명:춘천닭갈비,수량:4인분\n",
"- 분석 결과 1: 음식명:라면사리\n",
"- 분석 결과 2: 음식명:콜라,옵션:300ml,수량:두캔\n",
";- 분석 결과 0: 음식명:조랭이떡국, 수량:3인분\n",
"- 분석 결과 1: 음식명:떡만두, 수량:한세트\n",
";- 분석 결과 0: 음식명:깐풍탕수,수량:2인분\n",
"- 분석 결과 1: 음식명:콜라,옵션:1.5리터,수량:한병\n"
]
}
],
"source": [
"# Run every evaluation query through the base model; results keyed by query index.\n",
"eval_dic = {\n",
"    idx: wrapper_generate(\n",
"        model=base_model,\n",
"        input_prompt=prompt_template.format(System=default_system_msg, User=query),\n",
"    )\n",
"    for idx, query in enumerate(evaluation_queries)\n",
"}"
],
"id": "eaac1f6f-c823-4488-8edb-2f931ddf0daa"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fefd04ba-2ed8-4f84-bdd0-86d52b3f39f6",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "e65f131a-0e86-46be-8671-d88f782b0e04"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"- 분석 결과 0: 음식명:삼선짬뽕,옵션:곱배기,수량:하나\n",
"- 분석 결과 1: 음식명:사천 탕수육,옵션:중짜,수량:한그릇\n"
]
}
],
"source": [
"print(eval_dic[0])"
],
"id": "fefd04ba-2ed8-4f84-bdd0-86d52b3f39f6"
},
{
"cell_type": "markdown",
"metadata": {
"id": "3f471e3a-723b-4df5-aa72-46f571f6bab6"
},
"source": [
"# 미세튜닝된 모델 로딩 후 테스트"
],
"id": "3f471e3a-723b-4df5-aa72-46f571f6bab6"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "a43bdd07-7555-42b2-9888-a614afec892f"
},
"outputs": [],
"source": [
"bnb_config = BitsAndBytesConfig(\n",
" load_in_4bit=True,\n",
" bnb_4bit_quant_type=\"nf4\",\n",
" bnb_4bit_compute_dtype=torch.bfloat16,\n",
")"
],
"id": "a43bdd07-7555-42b2-9888-a614afec892f"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "39db2ee4-23c8-471f-89b2-bca34964bf81"
},
"outputs": [],
"source": [
"trained_model = AutoPeftModelForCausalLM.from_pretrained(\n",
" script_args.training_args.output_dir,\n",
" quantization_config=bnb_config,\n",
" device_map=\"auto\",\n",
" cache_dir=script_args.cache_dir\n",
")"
],
"id": "39db2ee4-23c8-471f-89b2-bca34964bf81"
},
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "KpCEf3DK0K6j"
},
"id": "KpCEf3DK0K6j",
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "b0b75ca4-730d-4bde-88bb-a86462a76d52",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 251
},
"outputId": "bfe7f03d-b1c2-4950-cad6-db301ed9fa48"
},
"outputs": [
{
"output_type": "error",
"ename": "NameError",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m
Copy a token from your Hugging Face\ntokens page and paste it below.
Immediately click login after copying\nyour token or it might be stored in plain text in this notebook file.