File size: 22,222 Bytes
bcffb9c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 |
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Install into the running kernel's environment (%pip rather than !pip) with pinned versions.\n",
"# trl's SFTTrainer module imports `prepare_model_for_kbit_training` from peft; with an\n",
"# outdated peft already installed, `from trl import SFTTrainer` fails with an ImportError,\n",
"# so peft is listed explicitly with a minimum version instead of relying on whatever is present.\n",
"%pip install -q \"trl==0.8.0\" \"peft>=0.4.0\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"from trl import DataCollatorForCompletionOnlyLM, SFTTrainer\n",
"\n",
"# Single source of truth for the checkpoint and dataset used below.\n",
"MODEL_NAME = \"facebook/opt-350m\"\n",
"DATASET_NAME = \"lucasmccabe-lmi/CodeAlpaca-20k\"\n",
"\n",
"dataset = load_dataset(DATASET_NAME, split=\"train\")\n",
"\n",
"# Model and tokenizer come from the same checkpoint so their vocabularies agree.\n",
"model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)\n",
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
"\n",
"\n",
"def formatting_prompts_func(example):\n",
"    \"\"\"Format a batch of rows as question/answer training strings.\n",
"\n",
"    Args:\n",
"        example: Batched mapping with parallel 'instruction' and 'output'\n",
"            sequences (assumed to be the same length, as in a `datasets` batch).\n",
"\n",
"    Returns:\n",
"        A list with one '### Question: ... ### Answer: ...' string per row.\n",
"    \"\"\"\n",
"    return [\n",
"        f\"### Question: {instruction}\\n ### Answer: {output}\"\n",
"        for instruction, output in zip(example[\"instruction\"], example[\"output\"])\n",
"    ]\n",
"\n",
"\n",
"# Must match the answer marker written by formatting_prompts_func, including the\n",
"# leading space that follows the newline in the formatted text.\n",
"response_template = \" ### Answer:\"\n",
"collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)\n",
"\n",
"# Pass the tokenizer explicitly so the trainer and the collator share one instance.\n",
"trainer = SFTTrainer(\n",
"    model,\n",
"    tokenizer=tokenizer,\n",
"    train_dataset=dataset,\n",
"    formatting_func=formatting_prompts_func,\n",
"    data_collator=collator,\n",
")\n",
"\n",
"trainer.train()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
|