{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch\n", "import torch.nn as nn\n", "torch.cuda.is_available()" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/matt/hf/sqllama-V0/.venv/lib/python3.7/site-packages/bitsandbytes/cuda_setup/main.py:136: UserWarning: /opt/conda did not contain libcudart.so as expected! Searching further paths...\n", " warn(msg)\n", "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. \n", "The tokenizer class you load from this checkpoint is 'LLaMATokenizer'. \n", "The class this function is called from is 'LlamaTokenizer'.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "===================================BUG REPORT===================================\n", "Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n", "================================================================================\n", "CUDA SETUP: CUDA runtime path found: /usr/local/cuda/lib64/libcudart.so\n", "CUDA SETUP: Highest compute capability among GPUs detected: 7.5\n", "CUDA SETUP: Detected CUDA version 113\n", "CUDA SETUP: Loading binary /home/matt/hf/sqllama-V0/.venv/lib/python3.7/site-packages/bitsandbytes/libbitsandbytes_cuda113.so...\n", "PeftConfig(peft_type='LORA', base_model_name_or_path='decapoda-research/llama-7b-hf', task_type='CASUAL_LM', inference_mode=True)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "4581e257066e463d9638f5d3a407a70e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/33 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from transformers import LlamaTokenizer, LlamaForCausalLM\n", "from peft import get_peft_model, PeftConfig, PeftModel\n", "\n", "loc = 'sqllama-out3'\n", "\n", "config = PeftConfig.from_pretrained(loc)\n", "print(config)\n", "\n", "\n", "tokenizer = LlamaTokenizer.from_pretrained(\n", " \"decapoda-research/llama-7b-hf\")\n", " \n", "tokenizer.pad_token_id = 0\n", "tokenizer.padding_side = 'left'\n", "\n", "model = LlamaForCausalLM.from_pretrained(\n", " \"decapoda-research/llama-7b-hf\",\n", " load_in_8bit=True,\n", " device_map=\"auto\",\n", " torch_dtype=torch.float16\n", ")\n", "\n", "model = PeftModel.from_pretrained(\n", " model, loc,\n", " torch_dtype=torch.float16,\n", " device_map=\"auto\"\n", " )\n", "\n", "#model.push_to_hub('LlamaSQL-3')\n", "\n", "#model = prepare_model_for_int8_training(model)\n", "\n", "#model = get_peft_model(model,config)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "table: 1-20246201-9\n", "columns: Candidate,Office,Home state,Popular vote,States – first place,States – second place,States – third place\n", "Q: How many states-first place are there for the office of Governor?\n", "A: SELECT COUNT States – first place FROM 1-20246201-9 WHERE Office = 'Governor'\n", "END\n", "\n", "\n", "table: 2-17429402-7\n", "columns: School,Years of Participation,OCC Championships,Last OCC Championship,Last Outright OCC Championship\n", "Q: How many times have 
Central Crossing won the OCC Championship?\n", "A: SELECT SUM OCC Championships FROM 2-17429402-7 WHERE School = 'central crossing'\n", "END\n", "\n", "\n", "table: 1-11677691-10\n", "columns: Player,Position,School,Hometown,College\n", "Q: What town is Muscle Shoals High School located in?\n", "A: SELECT Hometown FROM 1-11677691-10 WHERE School = 'Muscle Shoals High School'\n", "END\n", "\n", "\n", "table: 2-10701914-2\n", "columns: Home team,Home team score,Away team,Away team score,Venue,Crowd,Date\n", "Q: What is the largest crowd when st kilda was the away team?\n", "A: SELECT MAX Crowd FROM 2-10701914-2 WHERE Away team = 'st kilda'\n", "END\n", "\n", "\n", "table: 2-1122152-1\n", "columns: Driver,Constructor,Laps,Time/Retired,Grid\n", "Q: What is the lap total for the grid under 15 that retired due to transmission?\n", "A: SELECT SUM Laps FROM 2-1122152-1 WHERE Grid < 15 AND Time/Retired = 'transmission'\n", "END\n", "\n" ] } ], "source": [ "import random\n", "import json\n", "\n", "# defined by WikiSQL\n", "\n", "agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']\n", "cond_ops = ['=', '>', '<', 'OP']\n", "syms = ['SELECT', 'WHERE', 'AND', 'COL', 'TABLE', 'CAPTION', 'PAGE', 'SECTION', 'OP', 'COND', 'QUESTION', 'AGG', 'AGGOPS', 'CONDOPS']\n", "\n", "def fix_repr(d,cols,types,tid):\n", " sel_index=d['sel'] \n", " agg_index=d['agg']\n", " conditions=d['conds']\n", " col = cols[sel_index]\n", " rep = 'SELECT {agg} {sel} FROM {tid}'.format(\n", " agg=agg_ops[agg_index],\n", " sel=col,\n", " tid=tid\n", " )\n", " if conditions:\n", " cs = []\n", " for i, o, v in conditions:\n", " #print(i,cols)\n", " nm = cols[i]\n", " op = cond_ops[o]\n", " \n", " if types[i] in ['text']:\n", " val = f\"\\'{v}\\'\"\n", " else:\n", " val = v\n", " cs.append(f'{nm} {op} {val}')\n", " #print(cs)\n", "\n", " rep += ' WHERE ' + ' AND '.join(cs)\n", " \n", " return rep\n", "\n", "tbl_cols = {}\n", "tbl_types = {}\n", "tbl_str = {}\n", "\n", "prefix = 'Below is a question that describes a data request, paired with an input that describes a SQL table. 
Write a SQL query that retrieves the data.'\n", "\n", "def tbl_def_to_string(id, header, types):\n", " s = f'table: {id}\\ncolumns: ' + ','.join(header)\n", " return s\n", "\n", "with open('data/train.tables.jsonl') as f:\n", " for line in f:\n", " js = json.loads(line)\n", " id = js['id']\n", " hdr = js['header']\n", " ts = js['types']\n", " tbl_str[id] = tbl_def_to_string(id,hdr,ts)\n", " tbl_cols[id] = hdr\n", " tbl_types[id] = ts\n", "\n", "q_s = []\n", "a_s = []\n", "\n", "with open('data/train.jsonl') as f:\n", " for line in f:\n", " js = json.loads(line)\n", " id = js['table_id']\n", " s = tbl_str[id]\n", " qst = js['question']\n", " nl = s + '\\nQ: ' + qst + '\\nA: '\n", " q_s.append(nl)\n", "\n", " sql = js['sql']\n", " a = fix_repr(sql,tbl_cols[id],tbl_types[id],id)\n", " a = a + \"\\nEND\\n\"\n", " a_s.append(a)\n", "\n", "M = len(q_s)\n", "\n", "data_txt = [q_s[i] + a_s[i] for i in range(M)]\n", "\n", "for i in range(5):\n", " j = random.randint(0,M-1)\n", " print()\n", " print(data_txt[j]) \n", " \n", " " ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "toks = [tokenizer(s) for s in data_txt]\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "92\n", " 0\n", "count 56355.000000\n", "mean 101.219519\n", "std 21.740325\n", "min 63.000000\n", "25% 87.500000\n", "50% 97.000000\n", "75% 109.000000\n", "max 461.000000\n", "32084\n" ] } ], "source": [ "import numpy as np\n", "import pandas as pd\n", "\n", "print(len(toks[0].input_ids))\n", "lens = np.array([len(tok.input_ids) for tok in toks])\n", "print(pd.DataFrame(lens).describe())\n", "\n", "z = zip(q_s,lens)\n", "q_red = [a for a,b in z if b < 100]\n", "z = zip(a_s,lens)\n", "a_red = [a for a,b in z if b < 100]\n", "\n", "data_red = [q_red[i] + a_red[i] for i in range(len(q_red))]\n", "print(len(data_red))\n", "\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "3215cace7aef45b1b040a12f11509a7d", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/32084 [00:00, ? 
examples/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import random, datasets\n", "#d = {'prompt': random.sample(data_red, 1000)}\n", "d = {'prompt': data_red}\n", "\n", "data = datasets.Dataset.from_dict(d)\n", "\n", "# tokenize each prompt to a fixed 100-token window; the tokenizer was set to left padding above\n", "data = data.map(lambda x:\n", " tokenizer(\n", " x['prompt'],\n", " truncation=True,\n", " max_length=100,\n", " padding=\"max_length\"\n", " ))\n", "\n", "data = data.remove_columns('prompt')\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "#from peft import get_peft_model,PrefixTuningConfig, TaskType, PeftType\n", "from peft import get_peft_config, get_peft_model, PromptTuningInit, PromptTuningConfig, TaskType, PeftType\n", "import torch\n", "import transformers\n", "import datasets\n", "\n", "# effective batch of 128 sequences: micro-batches of 4 with 128//4 = 32 gradient accumulation steps\n", "BATCH = 128\n", "MICRO_BATCH = 4\n", "N_GAS = BATCH//MICRO_BATCH\n", "EPOCHS = 1\n", "LR = 1e-6\n", "\n", "#peft_cfg = PrefixTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=30)\n", "#model = get_peft_model(model,peft_config)\n", "#model = model.to(torch.device('cuda'))\n", "\n", "targs = transformers.TrainingArguments(\n", " per_device_train_batch_size=MICRO_BATCH,\n", " gradient_accumulation_steps=N_GAS,\n", " warmup_steps=20,\n", " num_train_epochs=EPOCHS,\n", " learning_rate=LR,\n", " fp16=True,\n", " logging_steps=1,\n", " output_dir='sqllama-out3-rt',\n", " save_total_limit=3,\n", " remove_unused_columns=False,\n", ")\n" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
Step | Training Loss\n",
"--- | ---\n",
"1 | 1.121900\n",
"2 | 1.103000\n",
"3 | 1.128600\n",
"... | ...\n",
"248 | 1.108800\n",
"249 | 1.080800\n",
"250 | 1.119200\n",
"\n",
"(loss holds roughly flat between 1.05 and 1.17 across the 250 logged steps)\n", "
"
],
"text/plain": [
"