File size: 22,222 Bytes
bcffb9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Collecting trl\n",
      "  Downloading trl-0.8.0-py3-none-any.whl.metadata (11 kB)\n",
      "Requirement already satisfied: torch>=1.4.0 in /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages (from trl) (2.2.1)\n",
      "Requirement already satisfied: transformers>=4.31.0 in /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages (from trl) (4.38.2)\n",
      "Requirement already satisfied: numpy>=1.18.2 in /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages (from trl) (1.26.3)\n",
      "Requirement already satisfied: accelerate in /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages (from trl) (0.27.2)\n",
      "Requirement already satisfied: datasets in /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages (from trl) (2.18.0)\n",
      "Collecting tyro>=0.5.11 (from trl)\n",
      "  Downloading tyro-0.7.3-py3-none-any.whl.metadata (7.7 kB)\n",
      "Requirement already satisfied: filelock in /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages (from torch>=1.4.0->trl) (3.13.1)\n",
      "Requirement already satisfied: typing-extensions>=4.8.0 in /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages (from torch>=1.4.0->trl) (4.10.0)\n",
      "Requirement already satisfied: sympy in /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages (from torch>=1.4.0->trl) (1.12)\n",
      "Requirement already satisfied: networkx in /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages (from torch>=1.4.0->trl) (3.2.1)\n",
      "Requirement already satisfied: jinja2 in /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages (from torch>=1.4.0->trl) (3.1.3)\n",
      "Requirement already satisfied: fsspec in /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages (from torch>=1.4.0->trl) (2024.2.0)\n",
      "Requirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages (from transformers>=4.31.0->trl) (0.21.3)\n",
      "Requirement already satisfied: packaging>=20.0 in /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages (from transformers>=4.31.0->trl) (23.2)\n",
      "Requirement already satisfied: pyyaml>=5.1 in /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages (from transformers>=4.31.0->trl) (6.0.1)\n",
      "Requirement already satisfied: regex!=2019.12.17 in /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages (from transformers>=4.31.0->trl) (2023.12.25)\n",
      "Requirement already satisfied: requests in /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages (from transformers>=4.31.0->trl) (2.31.0)\n",
      "Requirement already satisfied: tokenizers<0.19,>=0.14 in /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages (from transformers>=4.31.0->trl) (0.15.2)\n",
      "Requirement already satisfied: safetensors>=0.4.1 in /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages (from transformers>=4.31.0->trl) (0.4.2)\n",
      "Requirement already satisfied: tqdm>=4.27 in /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages (from transformers>=4.31.0->trl) (4.66.1)\n",
      "Collecting docstring-parser>=0.14.1 (from tyro>=0.5.11->trl)\n",
      "  Downloading docstring_parser-0.16-py3-none-any.whl.metadata (3.0 kB)\n",
      "Requirement already satisfied: rich>=11.1.0 in /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages (from tyro>=0.5.11->trl) (13.7.0)\n",
      "Collecting shtab>=1.5.6 (from tyro>=0.5.11->trl)\n",
      "  Downloading shtab-1.7.1-py3-none-any.whl.metadata (7.3 kB)\n",
      "Requirement already satisfied: psutil in /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages (from accelerate->trl) (5.9.8)\n",
      "Requirement already satisfied: pyarrow>=12.0.0 in /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages (from datasets->trl) (15.0.1)\n",
      "Requirement already satisfied: pyarrow-hotfix in /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages (from datasets->trl) (0.6)\n",
      "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages (from datasets->trl) (0.3.8)\n",
      "Requirement already satisfied: pandas in /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages (from datasets->trl) (2.2.1)\n",
      "Requirement already satisfied: xxhash in /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages (from datasets->trl) (3.4.1)\n",
      "Requirement already satisfied: multiprocess in /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages (from datasets->trl) (0.70.16)\n",
      "Requirement already satisfied: aiohttp in /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages (from datasets->trl) (3.9.3)\n",
      "Requirement already satisfied: aiosignal>=1.1.2 in /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages (from aiohttp->datasets->trl) (1.3.1)\n",
      "Requirement already satisfied: attrs>=17.3.0 in /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages (from aiohttp->datasets->trl) (23.2.0)\n",
      "Requirement already satisfied: frozenlist>=1.1.1 in /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages (from aiohttp->datasets->trl) (1.4.1)\n",
      "Requirement already satisfied: multidict<7.0,>=4.5 in /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages (from aiohttp->datasets->trl) (6.0.5)\n",
      "Requirement already satisfied: yarl<2.0,>=1.0 in /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages (from aiohttp->datasets->trl) (1.9.4)\n",
      "Requirement already satisfied: charset-normalizer<4,>=2 in /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages (from requests->transformers>=4.31.0->trl) (3.3.2)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages (from requests->transformers>=4.31.0->trl) (3.6)\n",
      "Requirement already satisfied: urllib3<3,>=1.21.1 in /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages (from requests->transformers>=4.31.0->trl) (2.2.1)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages (from requests->transformers>=4.31.0->trl) (2024.2.2)\n",
      "Requirement already satisfied: markdown-it-py>=2.2.0 in /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages (from rich>=11.1.0->tyro>=0.5.11->trl) (3.0.0)\n",
      "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages (from rich>=11.1.0->tyro>=0.5.11->trl) (2.17.2)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages (from jinja2->torch>=1.4.0->trl) (2.1.5)\n",
      "Requirement already satisfied: python-dateutil>=2.8.2 in /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages (from pandas->datasets->trl) (2.9.0.post0)\n",
      "Requirement already satisfied: pytz>=2020.1 in /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages (from pandas->datasets->trl) (2024.1)\n",
      "Requirement already satisfied: tzdata>=2022.7 in /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages (from pandas->datasets->trl) (2024.1)\n",
      "Requirement already satisfied: mpmath>=0.19 in /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages (from sympy->torch>=1.4.0->trl) (1.3.0)\n",
      "Requirement already satisfied: mdurl~=0.1 in /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages (from markdown-it-py>=2.2.0->rich>=11.1.0->tyro>=0.5.11->trl) (0.1.2)\n",
      "Requirement already satisfied: six>=1.5 in /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages (from python-dateutil>=2.8.2->pandas->datasets->trl) (1.16.0)\n",
      "Downloading trl-0.8.0-py3-none-any.whl (224 kB)\n",
      "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m225.0/225.0 kB\u001b[0m \u001b[31m9.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hDownloading tyro-0.7.3-py3-none-any.whl (79 kB)\n",
      "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m79.8/79.8 kB\u001b[0m \u001b[31m10.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hDownloading docstring_parser-0.16-py3-none-any.whl (36 kB)\n",
      "Downloading shtab-1.7.1-py3-none-any.whl (14 kB)\n",
      "Installing collected packages: shtab, docstring-parser, tyro, trl\n",
      "Successfully installed docstring-parser-0.16 shtab-1.7.1 trl-0.8.0 tyro-0.7.3\n"
     ]
    }
   ],
   "source": [
    "# trl 0.8.0's SFTTrainer imports prepare_model_for_kbit_training from peft;\n",
    "# a stale peft install lacks that name and breaks `from trl import SFTTrainer`,\n",
    "# so require a recent-enough peft alongside the pinned trl release.\n",
    "# %pip (not !pip) installs into the environment of the running kernel.\n",
    "%pip install -q \"trl==0.8.0\" \"peft>=0.4.0\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "'NoneType' object has no attribute 'cadam32bit_grad_fp32'\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/bitsandbytes/cextension.py:34: UserWarning: The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.\n",
      "  warn(\"The installed version of bitsandbytes was compiled without GPU support. \"\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "Failed to import trl.trainer.sft_trainer because of the following error (look up to see its traceback):\ncannot import name 'prepare_model_for_kbit_training' from 'peft' (/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/peft/__init__.py)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mImportError\u001b[0m                               Traceback (most recent call last)",
      "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/trl/import_utils.py:172\u001b[0m, in \u001b[0;36m_LazyModule._get_module\u001b[0;34m(self, module_name)\u001b[0m\n\u001b[1;32m    171\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 172\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mimportlib\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mimport_module\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m.\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mmodule_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;18;43m__name__\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m    173\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n",
      "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/importlib/__init__.py:90\u001b[0m, in \u001b[0;36mimport_module\u001b[0;34m(name, package)\u001b[0m\n\u001b[1;32m     89\u001b[0m         level \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[0;32m---> 90\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_bootstrap\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_gcd_import\u001b[49m\u001b[43m(\u001b[49m\u001b[43mname\u001b[49m\u001b[43m[\u001b[49m\u001b[43mlevel\u001b[49m\u001b[43m:\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpackage\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlevel\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m<frozen importlib._bootstrap>:1387\u001b[0m, in \u001b[0;36m_gcd_import\u001b[0;34m(name, package, level)\u001b[0m\n",
      "File \u001b[0;32m<frozen importlib._bootstrap>:1360\u001b[0m, in \u001b[0;36m_find_and_load\u001b[0;34m(name, import_)\u001b[0m\n",
      "File \u001b[0;32m<frozen importlib._bootstrap>:1331\u001b[0m, in \u001b[0;36m_find_and_load_unlocked\u001b[0;34m(name, import_)\u001b[0m\n",
      "File \u001b[0;32m<frozen importlib._bootstrap>:935\u001b[0m, in \u001b[0;36m_load_unlocked\u001b[0;34m(spec)\u001b[0m\n",
      "File \u001b[0;32m<frozen importlib._bootstrap_external>:994\u001b[0m, in \u001b[0;36mexec_module\u001b[0;34m(self, module)\u001b[0m\n",
      "File \u001b[0;32m<frozen importlib._bootstrap>:488\u001b[0m, in \u001b[0;36m_call_with_frames_removed\u001b[0;34m(f, *args, **kwds)\u001b[0m\n",
      "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/trl/trainer/sft_trainer.py:53\u001b[0m\n\u001b[1;32m     52\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m is_peft_available():\n\u001b[0;32m---> 53\u001b[0m     \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mpeft\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training\n\u001b[1;32m     56\u001b[0m \u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mSFTTrainer\u001b[39;00m(Trainer):\n",
      "\u001b[0;31mImportError\u001b[0m: cannot import name 'prepare_model_for_kbit_training' from 'peft' (/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/peft/__init__.py)",
      "\nThe above exception was the direct cause of the following exception:\n",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[3], line 3\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtransformers\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m AutoModelForCausalLM, AutoTokenizer\n\u001b[1;32m      2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mdatasets\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m load_dataset\n\u001b[0;32m----> 3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtrl\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m SFTTrainer, DataCollatorForCompletionOnlyLM\n\u001b[1;32m      5\u001b[0m dataset \u001b[38;5;241m=\u001b[39m load_dataset(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlucasmccabe-lmi/CodeAlpaca-20k\u001b[39m\u001b[38;5;124m\"\u001b[39m, split\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtrain\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m      7\u001b[0m model \u001b[38;5;241m=\u001b[39m AutoModelForCausalLM\u001b[38;5;241m.\u001b[39mfrom_pretrained(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfacebook/opt-350m\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
      "File \u001b[0;32m<frozen importlib._bootstrap>:1412\u001b[0m, in \u001b[0;36m_handle_fromlist\u001b[0;34m(module, fromlist, import_, recursive)\u001b[0m\n",
      "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/trl/import_utils.py:163\u001b[0m, in \u001b[0;36m_LazyModule.__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m    161\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_class_to_module\u001b[38;5;241m.\u001b[39mkeys():\n\u001b[1;32m    162\u001b[0m     module \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_get_module(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_class_to_module[name])\n\u001b[0;32m--> 163\u001b[0m     value \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mgetattr\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mmodule\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    164\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    165\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodule \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m has no attribute \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n",
      "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/trl/import_utils.py:162\u001b[0m, in \u001b[0;36m_LazyModule.__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m    160\u001b[0m     value \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_get_module(name)\n\u001b[1;32m    161\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_class_to_module\u001b[38;5;241m.\u001b[39mkeys():\n\u001b[0;32m--> 162\u001b[0m     module \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_get_module\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_class_to_module\u001b[49m\u001b[43m[\u001b[49m\u001b[43mname\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    163\u001b[0m     value \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mgetattr\u001b[39m(module, name)\n\u001b[1;32m    164\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n",
      "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/trl/import_utils.py:174\u001b[0m, in \u001b[0;36m_LazyModule._get_module\u001b[0;34m(self, module_name)\u001b[0m\n\u001b[1;32m    172\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m importlib\u001b[38;5;241m.\u001b[39mimport_module(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m+\u001b[39m module_name, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m)\n\u001b[1;32m    173\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[0;32m--> 174\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m    175\u001b[0m         \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFailed to import \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmodule_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m because of the following error (look up to see its\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    176\u001b[0m         \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m traceback):\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;132;01m{\u001b[39;00me\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    177\u001b[0m     ) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01me\u001b[39;00m\n",
      "\u001b[0;31mRuntimeError\u001b[0m: Failed to import trl.trainer.sft_trainer because of the following error (look up to see its traceback):\ncannot import name 'prepare_model_for_kbit_training' from 'peft' (/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/peft/__init__.py)"
     ]
    }
   ],
   "source": [
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "from datasets import load_dataset\n",
    "from trl import SFTTrainer, DataCollatorForCompletionOnlyLM\n",
    "\n",
    "# Single source of truth for the hub identifiers, so the model and the\n",
    "# tokenizer can never be loaded from different checkpoints by accident.\n",
    "DATASET_NAME = \"lucasmccabe-lmi/CodeAlpaca-20k\"\n",
    "MODEL_NAME = \"facebook/opt-350m\"\n",
    "\n",
    "# Only the train split is loaded; formatting_prompts_func reads its\n",
    "# 'instruction' and 'output' columns.\n",
    "dataset = load_dataset(DATASET_NAME, split=\"train\")\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
    "\n",
    "def formatting_prompts_func(example):\n",
    "    \"\"\"Render a batch of examples as question/answer training strings.\n",
    "\n",
    "    Args:\n",
    "        example: Mapping with parallel 'instruction' and 'output' sequences.\n",
    "\n",
    "    Returns:\n",
    "        A list with one formatted string per entry in example['instruction'].\n",
    "    \"\"\"\n",
    "    # 'output' is indexed by position so each answer pairs with its question.\n",
    "    return [\n",
    "        f\"### Question: {question}\\n ### Answer: {example['output'][idx]}\"\n",
    "        for idx, question in enumerate(example['instruction'])\n",
    "    ]\n",
    "\n",
    "# Marker separating prompt from completion. It must match the text emitted\n",
    "# by formatting_prompts_func exactly, including the leading space.\n",
    "# NOTE(review): per the TRL docs the completion-only collator uses this marker\n",
    "# to restrict the loss to the answer tokens -- confirm for the installed version.\n",
    "response_template = \" ### Answer:\"\n",
    "collator = DataCollatorForCompletionOnlyLM(\n",
    "    response_template=response_template,\n",
    "    tokenizer=tokenizer,\n",
    ")\n",
    "\n",
    "# No training arguments are passed, so the trainer runs with library defaults.\n",
    "trainer = SFTTrainer(\n",
    "    model=model,\n",
    "    data_collator=collator,\n",
    "    train_dataset=dataset,\n",
    "    formatting_func=formatting_prompts_func,\n",
    ")\n",
    "\n",
    "trainer.train()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}