{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pin TRL to the release this notebook was written against (0.8.0) and require a\n",
    "# PEFT new enough to export `prepare_model_for_kbit_training`: TRL's sft_trainer\n",
    "# module imports that name whenever PEFT is installed, so a stale PEFT makes\n",
    "# `from trl import SFTTrainer` fail with an ImportError.\n",
    "# `%pip` (rather than `!pip`) installs into the environment of the running kernel.\n",
    "%pip install -q \"trl==0.8.0\" \"peft>=0.4.0\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "from datasets import load_dataset\n",
    "from trl import SFTTrainer, DataCollatorForCompletionOnlyLM\n",
    "\n",
    "# Single source of truth for the checkpoint so model and tokenizer cannot drift apart.\n",
    "MODEL_NAME = \"facebook/opt-350m\"\n",
    "\n",
    "dataset = load_dataset(\"lucasmccabe-lmi/CodeAlpaca-20k\", split=\"train\")\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
    "\n",
    "def formatting_prompts_func(example):\n",
    "    \"\"\"Render a batch of dataset rows as prompt strings.\n",
    "\n",
    "    Args:\n",
    "        example: Batch mapping that holds parallel 'instruction' and 'output'\n",
    "            sequences, one entry per row.\n",
    "\n",
    "    Returns:\n",
    "        A list with one formatted question/answer string per row, in row order.\n",
    "    \"\"\"\n",
    "    return [\n",
    "        f\"### Question: {instruction}\\n ### Answer: {output}\"\n",
    "        for instruction, output in zip(example['instruction'], example['output'])\n",
    "    ]\n",
    "\n",
    "# Same marker text that formatting_prompts_func writes before each answer\n",
    "# (leading space included); it is handed to the collator as the response template.\n",
    "response_template = \" ### Answer:\"\n",
    "collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)\n",
    "\n",
    "trainer = SFTTrainer(\n",
    "    model,\n",
    "    train_dataset=dataset,\n",
    "    formatting_func=formatting_prompts_func,\n",
    "    data_collator=collator,\n",
    ")\n",
    "\n",
    "trainer.train()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}