matt-tries-dl committed on
Commit
b444d89
1 Parent(s): 357d6d7
Files changed (3) hide show
  1. alpaca-lora +1 -0
  2. llama_test.ipynb +209 -29
  3. requirements.txt +2 -1
alpaca-lora ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit 8bb8579e403dc78e37fe81ffbb253c413007323f
llama_test.ipynb CHANGED
@@ -2,7 +2,7 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": 13,
6
  "metadata": {},
7
  "outputs": [
8
  {
@@ -11,7 +11,7 @@
11
  "True"
12
  ]
13
  },
14
- "execution_count": 13,
15
  "metadata": {},
16
  "output_type": "execute_result"
17
  }
@@ -32,7 +32,7 @@
32
  },
33
  {
34
  "cell_type": "code",
35
- "execution_count": 14,
36
  "metadata": {},
37
  "outputs": [
38
  {
@@ -47,7 +47,7 @@
47
  {
48
  "data": {
49
  "application/vnd.jupyter.widget-view+json": {
50
- "model_id": "ca1fb983d9884b91a3c0feed1e207d0e",
51
  "version_major": 2,
52
  "version_minor": 0
53
  },
@@ -83,7 +83,7 @@
83
  },
84
  {
85
  "cell_type": "code",
86
- "execution_count": 15,
87
  "metadata": {},
88
  "outputs": [
89
  {
@@ -132,7 +132,7 @@
132
  },
133
  {
134
  "cell_type": "code",
135
- "execution_count": 16,
136
  "metadata": {},
137
  "outputs": [
138
  {
@@ -168,7 +168,7 @@
168
  },
169
  {
170
  "cell_type": "code",
171
- "execution_count": 17,
172
  "metadata": {},
173
  "outputs": [
174
  {
@@ -232,7 +232,7 @@
232
  },
233
  {
234
  "cell_type": "code",
235
- "execution_count": 56,
236
  "metadata": {},
237
  "outputs": [
238
  {
@@ -240,25 +240,30 @@
240
  "output_type": "stream",
241
  "text": [
242
  "\n",
243
- "Respond to the following data request with a SQL query.\n",
244
- "Q: Table 2-16763320-1 has columns Tournament (text),Surface (text),Week (text),Winner (text),Finalist (text),Semifinalists (text). Which finalist has Semifinalists of andre agassi (1) lleyton hewitt (14)?\n",
245
- "A: SELECT Finalist FROM 2-16763320-1 WHERE Semifinalists = 'andre agassi (1) lleyton hewitt (14)'\n",
 
246
  "\n",
247
- "Respond to the following data request with a SQL query.\n",
248
- "Q: Table 1-27755784-10 has columns Game (real),Date (text),Team (text),Score (text),High points (text),High rebounds (text),High assists (text),Location Attendance (text),Record (text). What is the highest game number?\n",
249
- "A: SELECT MAX Game FROM 1-27755784-10\n",
 
250
  "\n",
251
- "Respond to the following data request with a SQL query.\n",
252
- "Q: Table 2-17231086-5 has columns Place (text),Player (text),Country (text),Score (text),To par (text). What place is the United States in that has a score of 68-73-68=209?\n",
253
- "A: SELECT Place FROM 2-17231086-5 WHERE Country = 'united states' AND Score = '68-73-68=209'\n",
 
254
  "\n",
255
- "Respond to the following data request with a SQL query.\n",
256
- "Q: Table 2-1302729-1 has columns Season (real),Overall (text),Slalom (text),Giant Slalom (text),Super G (text),Downhill (text),Combined (text). What is the combined of 2 overalls and 5 slaloms?\n",
257
- "A: SELECT Combined FROM 2-1302729-1 WHERE Overall = '2' AND Slalom = '5'\n",
 
258
  "\n",
259
- "Respond to the following data request with a SQL query.\n",
260
- "Q: Table 2-15295737-56 has columns Nation (text),Skip (text),Third (text),Second (text),Lead (text),Alternate (text). Who is the alternate for the team for which Monika Wagner is the third?\n",
261
- "A: SELECT Alternate FROM 2-15295737-56 WHERE Third = 'monika wagner'\n"
 
262
  ]
263
  }
264
  ],
@@ -303,11 +308,11 @@
303
  "tbl_types = {}\n",
304
  "tbl_str = {}\n",
305
  "\n",
306
- "prefix = 'Respond to the following data request with a SQL query.\\n'\n",
307
  "\n",
308
  "def tbl_def_to_string(id, header, types):\n",
309
  " ht = [f'{header[i]} ({types[i]})' for i in range(len(header))]\n",
310
- " s = f'Q: Table {id} has columns ' + ','.join(ht) + '. '\n",
311
  " return s\n",
312
  "\n",
313
  "with open('data/train.tables.jsonl') as f:\n",
@@ -330,26 +335,201 @@
330
  " id = js['table_id']\n",
331
  " s = tbl_str[id]\n",
332
  " qst = js['question']\n",
333
- " nl = prefix + s + qst\n",
334
  " nl_q.append(nl)\n",
335
  "\n",
336
  " sql = js['sql']\n",
337
  " a = fix_repr(sql,tbl_cols[id],tbl_types[id],id)\n",
338
- " a = 'A: ' + a\n",
339
  " sql_a.append(a)\n",
340
  "\n",
341
  "\n",
342
  "M = len(nl_q)\n",
343
  "\n",
 
344
  "\n",
345
  "for i in range(5):\n",
346
  " j = random.randint(0,M-1)\n",
347
  " print()\n",
348
- " print(nl_q[j])\n",
349
- " print(sql_a[j]) \n",
350
  " \n",
351
  " "
352
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
353
  }
354
  ],
355
  "metadata": {
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": 1,
6
  "metadata": {},
7
  "outputs": [
8
  {
 
11
  "True"
12
  ]
13
  },
14
+ "execution_count": 1,
15
  "metadata": {},
16
  "output_type": "execute_result"
17
  }
 
32
  },
33
  {
34
  "cell_type": "code",
35
+ "execution_count": 2,
36
  "metadata": {},
37
  "outputs": [
38
  {
 
47
  {
48
  "data": {
49
  "application/vnd.jupyter.widget-view+json": {
50
+ "model_id": "3ab80e2a1c0744e0af747ba63429a2af",
51
  "version_major": 2,
52
  "version_minor": 0
53
  },
 
83
  },
84
  {
85
  "cell_type": "code",
86
+ "execution_count": 3,
87
  "metadata": {},
88
  "outputs": [
89
  {
 
132
  },
133
  {
134
  "cell_type": "code",
135
+ "execution_count": 13,
136
  "metadata": {},
137
  "outputs": [
138
  {
 
168
  },
169
  {
170
  "cell_type": "code",
171
+ "execution_count": 4,
172
  "metadata": {},
173
  "outputs": [
174
  {
 
232
  },
233
  {
234
  "cell_type": "code",
235
+ "execution_count": 5,
236
  "metadata": {},
237
  "outputs": [
238
  {
 
240
  "output_type": "stream",
241
  "text": [
242
  "\n",
243
+ "Below is a question that describes a data request, paired with an input that describes a SQL table. Write a SQL query that retrieves the data.\n",
244
+ "### Question: What is the Displacement of the Iveco F1CE3481E Engine?\n",
245
+ "### Input: Table 2-1415821-6 has columns Model (text),Engine (text),Displacement (text),Valvetrain (text),Fuel system (text),Max. power at rpm (text),Max. torque at rpm (text). \n",
246
+ "### Answer: SELECT Displacement FROM 2-1415821-6 WHERE Engine = 'iveco f1ce3481e'\n",
247
  "\n",
248
+ "Below is a question that describes a data request, paired with an input that describes a SQL table. Write a SQL query that retrieves the data.\n",
249
+ "### Question: What is the record of team utah?\n",
250
+ "### Input: Table 2-17355628-9 has columns Game (real),Date (text),Team (text),Score (text),High points (text),High rebounds (text),High assists (text),Location Attendance (text),Record (text). \n",
251
+ "### Answer: SELECT Record FROM 2-17355628-9 WHERE Team = 'utah'\n",
252
  "\n",
253
+ "Below is a question that describes a data request, paired with an input that describes a SQL table. Write a SQL query that retrieves the data.\n",
254
+ "### Question: What is the home of the team with a 16-8 record?\n",
255
+ "### Input: Table 2-16188254-4 has columns Date (text),Visitor (text),Score (text),Home (text),Leading scorer (text),Attendance (text),Record (text). \n",
256
+ "### Answer: SELECT Home FROM 2-16188254-4 WHERE Record = '16-8'\n",
257
  "\n",
258
+ "Below is a question that describes a data request, paired with an input that describes a SQL table. Write a SQL query that retrieves the data.\n",
259
+ "### Question: What week did the Galaxy play the Amsterdam Admirals?\n",
260
+ "### Input: Table 1-24814477-2 has columns Week (real),Date (text),Kickoff (text),Opponent (text),Final score (text),Team record (text),Game site (text),Attendance (real). \n",
261
+ "### Answer: SELECT Week FROM 1-24814477-2 WHERE Opponent = 'Amsterdam Admirals'\n",
262
  "\n",
263
+ "Below is a question that describes a data request, paired with an input that describes a SQL table. Write a SQL query that retrieves the data.\n",
264
+ "### Question: How many caps did Mitchell Duke have overall?\n",
265
+ "### Input: Table 2-1257177-1 has columns Player (text),Country (text),Caps (real),Goals (text),Years Active (text). \n",
266
+ "### Answer: SELECT COUNT Caps FROM 2-1257177-1 WHERE Player = 'mitchell duke'\n"
267
  ]
268
  }
269
  ],
 
308
  "tbl_types = {}\n",
309
  "tbl_str = {}\n",
310
  "\n",
311
+ "prefix = 'Below is a question that describes a data request, paired with an input that describes a SQL table. Write a SQL query that retrieves the data.'\n",
312
  "\n",
313
  "def tbl_def_to_string(id, header, types):\n",
314
  " ht = [f'{header[i]} ({types[i]})' for i in range(len(header))]\n",
315
+ " s = f'\\n### Input: Table {id} has columns ' + ','.join(ht) + '. '\n",
316
  " return s\n",
317
  "\n",
318
  "with open('data/train.tables.jsonl') as f:\n",
 
335
  " id = js['table_id']\n",
336
  " s = tbl_str[id]\n",
337
  " qst = js['question']\n",
338
+ " nl = prefix + \"\\n### Question: \" + qst + s\n",
339
  " nl_q.append(nl)\n",
340
  "\n",
341
  " sql = js['sql']\n",
342
  " a = fix_repr(sql,tbl_cols[id],tbl_types[id],id)\n",
343
+ " a = '\\n### Answer: ' + a\n",
344
  " sql_a.append(a)\n",
345
  "\n",
346
  "\n",
347
  "M = len(nl_q)\n",
348
  "\n",
349
+ "data_txt = [nl_q[i] + sql_a[i] for i in range(len(nl_q))]\n",
350
  "\n",
351
  "for i in range(5):\n",
352
  " j = random.randint(0,M-1)\n",
353
  " print()\n",
354
+ " print(data_txt[j]) \n",
 
355
  " \n",
356
  " "
357
  ]
358
+ },
359
+ {
360
+ "attachments": {},
361
+ "cell_type": "markdown",
362
+ "metadata": {},
363
+ "source": [
364
+ "Set up the details for the model."
365
+ ]
366
+ },
367
+ {
368
+ "cell_type": "code",
369
+ "execution_count": 26,
370
+ "metadata": {},
371
+ "outputs": [
372
+ {
373
+ "data": {
374
+ "application/vnd.jupyter.widget-view+json": {
375
+ "model_id": "4f44918087484dd58b958a64cabdecb6",
376
+ "version_major": 2,
377
+ "version_minor": 0
378
+ },
379
+ "text/plain": [
380
+ "Map: 0%| | 0/56355 [00:00<?, ? examples/s]"
381
+ ]
382
+ },
383
+ "metadata": {},
384
+ "output_type": "display_data"
385
+ }
386
+ ],
387
+ "source": [
388
+ "from peft import LoraConfig, get_peft_model\n",
389
+ "import transformers\n",
390
+ "import datasets\n",
391
+ "\n",
392
+ "LORA_R = 4\n",
393
+ "LORA_ALPHA = 16\n",
394
+ "LORA_DROPOUT = .1\n",
395
+ "CUTOFF_LEN = 256\n",
396
+ "BATCH = 128\n",
397
+ "MICRO_BATCH = 4\n",
398
+ "N_GAS = BATCH//MICRO_BATCH\n",
399
+ "EPOCHS = 1\n",
400
+ "LR = 1e-5\n",
401
+ "\n",
402
+ "lora_cfg = LoraConfig(\n",
403
+ " r = LORA_R,\n",
404
+ " lora_alpha=LORA_ALPHA,\n",
405
+ " lora_dropout=LORA_DROPOUT,\n",
406
+ " task_type='CAUSAL_LM',\n",
407
+ " target_modules=['q_proj','v_proj']\n",
408
+ ")\n",
409
+ "\n",
410
+ "modad = get_peft_model(model,lora_cfg)\n",
411
+ "\n",
412
+ "tokenizer.pad_token_id = 0\n",
413
+ "\n",
414
+ "d = {'prompt': data_txt}\n",
415
+ "\n",
416
+ "data = datasets.Dataset.from_dict(d)\n",
417
+ "data = data.map(lambda x:\n",
418
+ " tokenizer(\n",
419
+ " x['prompt'],\n",
420
+ " truncation=True,\n",
421
+ " max_length=CUTOFF_LEN,\n",
422
+ " padding=\"max_length\"\n",
423
+ " ))\n",
424
+ "\n",
425
+ "#data.remove_columns('prompt')\n",
426
+ "\n",
427
+ "targs = transformers.TrainingArguments(\n",
428
+ " per_device_train_batch_size=MICRO_BATCH,\n",
429
+ " gradient_accumulation_steps=N_GAS,\n",
430
+ " warmup_steps=0,\n",
431
+ " num_train_epochs=EPOCHS,\n",
432
+ " learning_rate=LR,\n",
433
+ " fp16=True,\n",
434
+ " logging_steps=1,\n",
435
+ " output_dir='sqllama-out',\n",
436
+ " save_total_limit=3,\n",
437
+ " remove_unused_columns=False\n",
438
+ ")\n",
439
+ "\n",
440
+ "\n",
441
+ "modad.config.use_cache = False"
442
+ ]
443
+ },
444
+ {
445
+ "attachments": {},
446
+ "cell_type": "markdown",
447
+ "metadata": {},
448
+ "source": [
449
+ "Ignore this cell - just trying to figure out Hugging Face datasets."
450
+ ]
451
+ },
452
+ {
453
+ "cell_type": "code",
454
+ "execution_count": 27,
455
+ "metadata": {},
456
+ "outputs": [
457
+ {
458
+ "name": "stdout",
459
+ "output_type": "stream",
460
+ "text": [
461
+ "Dataset({\n",
462
+ " features: ['prompt', 'input_ids', 'attention_mask'],\n",
463
+ " num_rows: 56355\n",
464
+ "})\n",
465
+ "{'prompt': \"Below is a question that describes a data request, paired with an input that describes a SQL table. Write a SQL query that retrieves the data.\\n### Question: Tell me what the notes are for South Australia \\n### Input: Table 1-1000181-1 has columns State/territory (text),Text/background colour (text),Format (text),Current slogan (text),Current series (text),Notes (text). \\n### Answer: SELECT Notes FROM 1-1000181-1 WHERE Current slogan = 'SOUTH AUSTRALIA'\", 'input_ids': [0, 13866, 338, 263, 1139, 393, 16612, 263, 848, 2009, 29892, 3300, 2859, 411, 385, 1881, 393, 16612, 263, 3758, 1591, 29889, 29871, 14350, 263, 3758, 2346, 393, 5663, 17180, 278, 848, 29889, 13, 2277, 29937, 894, 29901, 24948, 592, 825, 278, 11486, 526, 363, 4275, 8314, 29871, 13, 2277, 29937, 10567, 29901, 6137, 29871, 29896, 29899, 29896, 29900, 29900, 29900, 29896, 29947, 29896, 29899, 29896, 756, 4341, 4306, 29914, 357, 768, 706, 313, 726, 511, 1626, 29914, 7042, 12384, 313, 726, 511, 5809, 313, 726, 511, 7583, 269, 1188, 273, 313, 726, 511, 7583, 3652, 313, 726, 511, 3664, 267, 313, 726, 467, 259, 13, 2277, 29937, 673, 29901, 5097, 29871, 8695, 3895, 29871, 29896, 29899, 29896, 29900, 29900, 29900, 29896, 29947, 29896, 29899, 29896, 5754, 9626, 269, 1188, 273, 353, 525, 6156, 2692, 29950, 319, 29965, 10810, 1964, 10764, 29915, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}\n"
466
+ ]
467
+ }
468
+ ],
469
+ "source": [
470
+ "print(data)\n",
471
+ "print(data[0])\n",
472
+ "\n",
473
+ "#from datasets import load_dataset\n",
474
+ "\n",
475
+ "\n",
476
+ "#!git clone https://github.com/tloen/alpaca-lora.git\n",
477
+ "#dalp = load_dataset(\"json\", data_files=\"alpaca-lora/alpaca_data.json\")\n",
478
+ "#print(dalp)\n",
479
+ "\n",
480
+ "#dalp = dalp.map(lambda x : {'blah':'blah'})\n",
481
+ "#print(dalp)"
482
+ ]
483
+ },
484
+ {
485
+ "cell_type": "code",
486
+ "execution_count": 25,
487
+ "metadata": {},
488
+ "outputs": [
489
+ {
490
+ "name": "stderr",
491
+ "output_type": "stream",
492
+ "text": [
493
+ "/home/matt/hf/sqllama-V0/.venv/lib/python3.7/site-packages/transformers/optimization.py:395: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
494
+ " FutureWarning,\n"
495
+ ]
496
+ },
497
+ {
498
+ "ename": "ValueError",
499
+ "evalue": "Unable to create tensor, you should probably activate truncation and/or padding with 'padding=True' 'truncation=True' to have batched tensors with the same length. Perhaps your features (`prompt` in this case) have excessive nesting (inputs type `list` where type `int` is expected).",
500
+ "output_type": "error",
501
+ "traceback": [
502
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
503
+ "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
504
+ "\u001b[0;32m~/hf/sqllama-V0/.venv/lib/python3.7/site-packages/transformers/tokenization_utils_base.py\u001b[0m in \u001b[0;36mconvert_to_tensors\u001b[0;34m(self, tensor_type, prepend_batch_axis)\u001b[0m\n\u001b[1;32m 716\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mis_tensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 717\u001b[0;31m \u001b[0mtensor\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mas_tensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 718\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
505
+ "\u001b[0;31mValueError\u001b[0m: too many dimensions 'str'",
506
+ "\nThe above exception was the direct cause of the following exception:\n",
507
+ "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
508
+ "\u001b[0;32m/var/tmp/ipykernel_2309/3549391384.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mdata_collator\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtransformers\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDataCollatorForLanguageModeling\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtokenizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmlm\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m )\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresume_from_checkpoint\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 8\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msave_pretrained\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'sqllama-out'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
509
+ "\u001b[0;32m~/hf/sqllama-V0/.venv/lib/python3.7/site-packages/transformers/trainer.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m 1664\u001b[0m \u001b[0mresume_from_checkpoint\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mresume_from_checkpoint\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1665\u001b[0m \u001b[0mtrial\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtrial\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1666\u001b[0;31m \u001b[0mignore_keys_for_eval\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mignore_keys_for_eval\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1667\u001b[0m )\n\u001b[1;32m 1668\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
510
+ "\u001b[0;32m~/hf/sqllama-V0/.venv/lib/python3.7/site-packages/transformers/trainer.py\u001b[0m in \u001b[0;36m_inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m 1897\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1898\u001b[0m \u001b[0mstep\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1899\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mstep\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mepoch_iterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1900\u001b[0m \u001b[0mtotal_batched_samples\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1901\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrng_to_sync\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
511
+ "\u001b[0;32m~/hf/sqllama-V0/.venv/lib/python3.7/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 626\u001b[0m \u001b[0;31m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 627\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_reset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[call-arg]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 628\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_next_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 629\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_num_yielded\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 630\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dataset_kind\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0m_DatasetKind\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mIterable\u001b[0m \u001b[0;32mand\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
512
+ "\u001b[0;32m~/hf/sqllama-V0/.venv/lib/python3.7/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m_next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 669\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_next_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 670\u001b[0m \u001b[0mindex\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_next_index\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# may raise StopIteration\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 671\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dataset_fetcher\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfetch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindex\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# may raise StopIteration\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 672\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_pin_memory\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 673\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_utils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_pin_memory_device\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
513
+ "\u001b[0;32m~/hf/sqllama-V0/.venv/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py\u001b[0m in \u001b[0;36mfetch\u001b[0;34m(self, possibly_batched_index)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcollate_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
514
+ "\u001b[0;32m~/hf/sqllama-V0/.venv/lib/python3.7/site-packages/transformers/data/data_collator.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, features, return_tensors)\u001b[0m\n\u001b[1;32m 43\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtf_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 44\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mreturn_tensors\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"pt\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 45\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtorch_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 46\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mreturn_tensors\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"np\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 47\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
515
+ "\u001b[0;32m~/hf/sqllama-V0/.venv/lib/python3.7/site-packages/transformers/data/data_collator.py\u001b[0m in \u001b[0;36mtorch_call\u001b[0;34m(self, examples)\u001b[0m\n\u001b[1;32m 727\u001b[0m \u001b[0;31m# Handle dict or lists with proper padding and conversion to tensor.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 728\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mexamples\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mMapping\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 729\u001b[0;31m \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtokenizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mexamples\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreturn_tensors\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"pt\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpad_to_multiple_of\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpad_to_multiple_of\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 730\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 731\u001b[0m batch = {\n",
516
+ "\u001b[0;32m~/hf/sqllama-V0/.venv/lib/python3.7/site-packages/transformers/tokenization_utils_base.py\u001b[0m in \u001b[0;36mpad\u001b[0;34m(self, encoded_inputs, padding, max_length, pad_to_multiple_of, return_attention_mask, return_tensors, verbose)\u001b[0m\n\u001b[1;32m 3033\u001b[0m \u001b[0mbatch_outputs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3034\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3035\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mBatchEncoding\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_outputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtensor_type\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mreturn_tensors\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3036\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3037\u001b[0m def create_token_type_ids_from_sequences(\n",
517
+ "\u001b[0;32m~/hf/sqllama-V0/.venv/lib/python3.7/site-packages/transformers/tokenization_utils_base.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, data, encoding, tensor_type, prepend_batch_axis, n_sequences)\u001b[0m\n\u001b[1;32m 208\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_n_sequences\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mn_sequences\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 209\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 210\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconvert_to_tensors\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor_type\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtensor_type\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprepend_batch_axis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mprepend_batch_axis\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 211\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 212\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
518
+ "\u001b[0;32m~/hf/sqllama-V0/.venv/lib/python3.7/site-packages/transformers/tokenization_utils_base.py\u001b[0m in \u001b[0;36mconvert_to_tensors\u001b[0;34m(self, tensor_type, prepend_batch_axis)\u001b[0m\n\u001b[1;32m 736\u001b[0m \u001b[0;34mf\" features (`{key}` in this case) have excessive nesting (inputs type `list` where type `int` is\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 737\u001b[0m \u001b[0;34m\" expected).\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 738\u001b[0;31m ) from e\n\u001b[0m\u001b[1;32m 739\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 740\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
519
+ "\u001b[0;31mValueError\u001b[0m: Unable to create tensor, you should probably activate truncation and/or padding with 'padding=True' 'truncation=True' to have batched tensors with the same length. Perhaps your features (`prompt` in this case) have excessive nesting (inputs type `list` where type `int` is expected)."
520
+ ]
521
+ }
522
+ ],
523
+ "source": [
524
+ "trainer = transformers.Trainer(\n",
525
+ " model = modad,\n",
526
+ " train_dataset = data,\n",
527
+ " args = targs,\n",
528
+ " data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)\n",
529
+ ")\n",
530
+ "trainer.train(resume_from_checkpoint=False)\n",
531
+ "model.save_pretrained('sqllama-out')"
532
+ ]
533
  }
534
  ],
535
  "metadata": {
requirements.txt CHANGED
@@ -5,8 +5,9 @@ torch
5
  sentencepiece
6
  transformers
7
  accelerate
8
- bitsandbytes
9
  peft
 
10
  tqdm
11
  records
12
  babel
 
5
  sentencepiece
6
  transformers
7
  accelerate
8
+ bitsandbytes==0.37.2
9
  peft
10
+ datasets
11
  tqdm
12
  records
13
  babel