{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch\n", "import torch.nn as nn\n", "torch.cuda.is_available()" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/matt/hf/sqllama-V0/.venv/lib/python3.7/site-packages/bitsandbytes/cuda_setup/main.py:136: UserWarning: /opt/conda did not contain libcudart.so as expected! Searching further paths...\n", " warn(msg)\n", "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. \n", "The tokenizer class you load from this checkpoint is 'LLaMATokenizer'. \n", "The class this function is called from is 'LlamaTokenizer'.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "===================================BUG REPORT===================================\n", "Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n", "================================================================================\n", "CUDA SETUP: CUDA runtime path found: /usr/local/cuda/lib64/libcudart.so\n", "CUDA SETUP: Highest compute capability among GPUs detected: 7.5\n", "CUDA SETUP: Detected CUDA version 113\n", "CUDA SETUP: Loading binary /home/matt/hf/sqllama-V0/.venv/lib/python3.7/site-packages/bitsandbytes/libbitsandbytes_cuda113.so...\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "0b970b9989854c33aa4d19fa9457d7a2", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/33 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from transformers import LlamaTokenizer, LlamaForCausalLM\n", "from peft import prepare_model_for_int8_training\n", "tokenizer = LlamaTokenizer.from_pretrained(\n", " \"decapoda-research/llama-7b-hf\")\n", " \n", "tokenizer.pad_token_id = 0\n", "tokenizer.padding_side = 'left'\n", "\n", "model = LlamaForCausalLM.from_pretrained(\n", " \"decapoda-research/llama-7b-hf\",\n", " load_in_8bit=True,\n", " device_map=\"auto\",\n", " torch_dtype=torch.float16\n", ")\n", "\n", "model = prepare_model_for_int8_training(model)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "table: 2-1137692-1\n", "columns: Entrant,Constructor,Chassis,Engine †,Tyre,Driver,Rounds\n", "Q: What were the rounds on the Engine † of the Ferrari 048?\n", "A: SELECT Rounds FROM 2-1137692-1 WHERE Engine † = 'ferrari 048'\n", "END\n", "\n", "\n", "table: 1-21530474-1\n", "columns: Chassis code,Model no.,Production years,Drivetrain,Transmission,Engine type,Engine code,Region(s)\n", "Q: Name the drivetrain for 1ur-fse for usf41\n", "A: SELECT Drivetrain FROM 1-21530474-1 WHERE Engine code = '1UR-FSE' AND Chassis code = 'USF41'\n", "END\n", "\n", "\n", "table: 2-14155087-1\n", "columns: Callsign,Area served,Frequency,Band,On-air ID,Purpose\n", "Q: What is the Callsign with an Area of tamworth and frequency of 0 88.9?\n", "A: SELECT Callsign FROM 2-14155087-1 WHERE Area served = 'tamworth' AND Frequency = '0 88.9'\n", "END\n", "\n", "\n", "table: 2-17580726-2\n", "columns: Date,Opponent,Venue,Score,Attendance,Scorers\n", "Q: What is the number of people in 
attendance when Tonbridge Angels is the opponent?\n", "A: SELECT Attendance FROM 2-17580726-2 WHERE Opponent = 'tonbridge angels'\n", "END\n", "\n", "\n", "table: 1-27986200-3\n", "columns: Proceed to Quarter-final,Match points,Aggregate score,Points margin,Eliminated from competition\n", "Q: What were the match points when Bordeaux-Bègles was eliminated from competition? \n", "A: SELECT Match points FROM 1-27986200-3 WHERE Eliminated from competition = 'Bordeaux-Bègles'\n", "END\n", "\n" ] } ], "source": [ "import random\n", "import json\n", "\n", "# defined by WikiSQL\n", "\n", "agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']\n", "cond_ops = ['=', '>', '<', 'OP']\n", "syms = ['SELECT', 'WHERE', 'AND', 'COL', 'TABLE', 'CAPTION', 'PAGE', 'SECTION', 'OP', 'COND', 'QUESTION', 'AGG', 'AGGOPS', 'CONDOPS']\n", "\n", "def fix_repr(d,cols,types,tid):\n", " sel_index=d['sel'] \n", " agg_index=d['agg']\n", " conditions=d['conds']\n", " col = cols[sel_index]\n", " rep = 'SELECT {agg} {sel} FROM {tid}'.format(\n", " agg=agg_ops[agg_index],\n", " sel=col,\n", " tid=tid\n", " )\n", " if conditions:\n", " cs = []\n", " for i, o, v in conditions:\n", " #print(i,cols)\n", " nm = cols[i]\n", " op = cond_ops[o]\n", " \n", " if types[i] in ['text']:\n", " val = f\"\\'{v}\\'\"\n", " else:\n", " val = v\n", " cs.append(f'{nm} {op} {val}')\n", " #print(cs)\n", "\n", " rep += ' WHERE ' + ' AND '.join(cs)\n", " \n", " return rep\n", "\n", "tbl_cols = {}\n", "tbl_types = {}\n", "tbl_str = {}\n", "\n", "prefix = 'Below is a question that describes a data request, paired with an input that describes a SQL table. Write a SQL query that retrieves the data.'\n", "\n", "def tbl_def_to_string(id, header, types):\n", " s = f'table: {id}\\ncolumns: ' + ','.join(header)\n", " return s\n", "\n", "with open('data/train.tables.jsonl') as f:\n", " for line in f:\n", " js = json.loads(line)\n", " id = js['id']\n", " hdr = js['header']\n", " ts = js['types']\n", " tbl_str[id] = tbl_def_to_string(id,hdr,ts)\n", " tbl_cols[id] = hdr\n", " tbl_types[id] = ts\n", "\n", "q_s = []\n", "a_s = []\n", "\n", "with open('data/train.jsonl') as f:\n", " for line in f:\n", " js = json.loads(line)\n", " id = js['table_id']\n", " s = tbl_str[id]\n", " qst = js['question']\n", " nl = s + '\\nQ: ' + qst + '\\nA: '\n", " q_s.append(nl)\n", "\n", " sql = js['sql']\n", " a = fix_repr(sql,tbl_cols[id],tbl_types[id],id)\n", " a = a + \"\\nEND\\n\"\n", " a_s.append(a)\n", "\n", "M = len(q_s)\n", "\n", "data_txt = [q_s[i] + a_s[i] for i in range(M)]\n", "\n", "for i in range(5):\n", " j = random.randint(0,M-1)\n", " print()\n", " print(data_txt[j]) \n", " \n", " " ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "toks = [tokenizer(s) for s in data_txt]\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "92\n", " 0\n", "count 56355.000000\n", "mean 101.219519\n", "std 21.740325\n", "min 63.000000\n", "25% 87.500000\n", "50% 97.000000\n", "75% 109.000000\n", "max 461.000000\n", "32084\n" ] } ], "source": [ "import numpy as np\n", "import pandas as pd\n", "\n", "print(len(toks[0].input_ids))\n", "lens = np.array([len(tok.input_ids) for tok in toks])\n", "print(pd.DataFrame(lens).describe())\n", "\n", "z = zip(q_s,lens)\n", "q_red = [a for a,b in z if b < 100]\n", "z = zip(a_s,lens)\n", "a_red = [a for a,b in z if b < 100]\n", "\n", "data_red = [q_red[i] + a_red[i] for i in range(len(q_red))]\n", "print(len(data_red))\n", "\n" 
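] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A quick sanity check (a minimal sketch; it assumes `tokenizer` and `data_red` from the cells above): re-tokenize a few of the filtered examples to confirm each one fits the 100-token budget used by the tokenizer map call in the next cell." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# sanity check: every filtered example should tokenize to fewer than 100 tokens\n", "for s in random.sample(data_red, 5):\n", "    n = len(tokenizer(s).input_ids)\n", "    assert n < 100, f'unexpectedly long example: {n} tokens'\n", "    print(n)"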
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "01debeeb68bb40a7b83031e88c5ace1e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/32084 [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import random, datasets\n", "#d = {'prompt': random.sample(data_red, 1000)}\n", "d = {'prompt': data_red}\n", "\n", "data = datasets.Dataset.from_dict(d)\n", "data = data.map(lambda x:\n", " tokenizer(\n", " x['prompt'],\n", " truncation=True,\n", " max_length=100,\n", " padding=\"max_length\"\n", " ))\n", "\n", "data = data.remove_columns('prompt')\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "from peft import LoraConfig, get_peft_model\n", "import transformers\n", "import datasets\n", "\n", "LORA_R = 4\n", "LORA_ALPHA = 16\n", "LORA_DROPOUT = .1\n", "CUTOFF_LEN = 256\n", "BATCH = 128\n", "MICRO_BATCH = 4\n", "N_GAS = BATCH//MICRO_BATCH\n", "EPOCHS = 1\n", "LR = 1e-4\n", "\n", "lora_cfg = LoraConfig(\n", " r = LORA_R,\n", " lora_alpha=LORA_ALPHA,\n", " lora_dropout=LORA_DROPOUT,\n", " task_type='CASUAL_LM',\n", " target_modules=['q_proj','v_proj']\n", ")\n", "\n", "model = get_peft_model(model,lora_cfg)\n", "\n", "targs = transformers.TrainingArguments(\n", " per_device_train_batch_size=MICRO_BATCH,\n", " gradient_accumulation_steps=N_GAS,\n", " warmup_steps=0,\n", " num_train_epochs=EPOCHS,\n", " learning_rate=LR,\n", " fp16=True,\n", " logging_steps=1,\n", " output_dir='sqllama-out2',\n", " save_total_limit=3,\n", " remove_unused_columns=False\n", ")\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
Training loss (selected steps of the 250 logged): step 1: 2.7488, step 25: 1.5187, step 50: 1.1292, step 100: 0.9892, step 150: 0.9267, step 200: 0.9130, step 250: 0.8822\n"
],
"text/plain": [
"