{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch\n", "import torch.nn as nn\n", "torch.cuda.is_available()" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/matt/hf/sqllama-V0/.venv/lib/python3.7/site-packages/bitsandbytes/cuda_setup/main.py:136: UserWarning: /opt/conda did not contain libcudart.so as expected! Searching further paths...\n", " warn(msg)\n", "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. \n", "The tokenizer class you load from this checkpoint is 'LLaMATokenizer'. \n", "The class this function is called from is 'LlamaTokenizer'.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "===================================BUG REPORT===================================\n", "Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n", "================================================================================\n", "CUDA SETUP: CUDA runtime path found: /usr/local/cuda/lib64/libcudart.so\n", "CUDA SETUP: Highest compute capability among GPUs detected: 7.5\n", "CUDA SETUP: Detected CUDA version 113\n", "CUDA SETUP: Loading binary /home/matt/hf/sqllama-V0/.venv/lib/python3.7/site-packages/bitsandbytes/libbitsandbytes_cuda113.so...\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "0b970b9989854c33aa4d19fa9457d7a2", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/33 [00:00', '<', 'OP']\n", "syms = ['SELECT', 'WHERE', 'AND', 'COL', 'TABLE', 'CAPTION', 'PAGE', 'SECTION', 'OP', 'COND', 'QUESTION', 'AGG', 'AGGOPS', 'CONDOPS']\n", "\n", "def fix_repr(d,cols,types,tid):\n", " sel_index=d['sel'] \n", " agg_index=d['agg']\n", " conditions=d['conds']\n", " col = cols[sel_index]\n", " rep = 'SELECT {agg} {sel} FROM {tid}'.format(\n", " agg=agg_ops[agg_index],\n", " sel=col,\n", " tid=tid\n", " )\n", " if conditions:\n", " cs = []\n", " for i, o, v in conditions:\n", " #print(i,cols)\n", " nm = cols[i]\n", " op = cond_ops[o]\n", " \n", " if types[i] in ['text']:\n", " val = f\"\\'{v}\\'\"\n", " else:\n", " val = v\n", " cs.append(f'{nm} {op} {val}')\n", " #print(cs)\n", "\n", " rep += ' WHERE ' + ' AND '.join(cs)\n", " \n", " return rep\n", "\n", "tbl_cols = {}\n", "tbl_types = {}\n", "tbl_str = {}\n", "\n", "prefix = 'Below is a question that describes a data request, paired with an input that describes a SQL table. Write a SQL query that retrieves the data.'\n", "\n", "def tbl_def_to_string(id, header, types):\n", " s = f'table: {id}\\ncolumns: ' + ','.join(header)\n", " return s\n", "\n", "with open('data/train.tables.jsonl') as f:\n", " for line in f:\n", " js = json.loads(line)\n", " id = js['id']\n", " hdr = js['header']\n", " ts = js['types']\n", " tbl_str[id] = tbl_def_to_string(id,hdr,ts)\n", " tbl_cols[id] = hdr\n", " tbl_types[id] = ts\n", "\n", "q_s = []\n", "a_s = []\n", "\n", "with open('data/train.jsonl') as f:\n", " for line in f:\n", " js = json.loads(line)\n", " id = js['table_id']\n", " s = tbl_str[id]\n", " qst = js['question']\n", " nl = s + '\\nQ: ' + qst + '\\nA: '\n", " q_s.append(nl)\n", "\n", " sql = js['sql']\n", " a = fix_repr(sql,tbl_cols[id],tbl_types[id],id)\n", " a = a + \"\\nEND\\n\"\n", " a_s.append(a)\n", "\n", "M = len(q_s)\n", "\n", "data_txt = [q_s[i] + a_s[i] for i in range(M)]\n", "\n", "for i in range(5):\n", " j = random.randint(0,M-1)\n", " print()\n", " print(data_txt[j]) \n", " \n", " " ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "toks = [tokenizer(s) for s in data_txt]\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "92\n", " 0\n", "count 56355.000000\n", "mean 101.219519\n", "std 21.740325\n", "min 63.000000\n", "25% 87.500000\n", "50% 97.000000\n", "75% 109.000000\n", "max 461.000000\n", "32084\n" ] } ], "source": [ "import numpy as np\n", "import pandas as pd\n", "\n", "print(len(toks[0].input_ids))\n", "lens = np.array([len(tok.input_ids) for tok in toks])\n", "print(pd.DataFrame(lens).describe())\n", "\n", "z = zip(q_s,lens)\n", "q_red = [a for a,b in z if b < 100]\n", "z = zip(a_s,lens)\n", "a_red = [a for a,b in z if b < 100]\n", "\n", "data_red = [q_red[i] + a_red[i] for i in range(len(q_red))]\n", "print(len(data_red))\n", "\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "01debeeb68bb40a7b83031e88c5ace1e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/32084 [00:00\n", " \n", " \n", " [250/250 3:49:26, Epoch 0/1]\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StepTraining Loss
12.748800
22.699100
32.670200
42.600500
52.560100
62.556800
72.498100
82.515400
92.436100
102.411700
112.346400
122.276300
132.238000
142.189100
152.109200
162.058000
171.983900
181.928600
191.824100
201.794700
211.681200
221.598900
231.562000
241.527200
251.518700
261.493100
271.500500
281.464000
291.386900
301.373400
311.362200
321.360800
331.321000
341.310500
351.302600
361.256100
371.252500
381.202300
391.249100
401.188600
411.203200
421.150000
431.182000
441.192300
451.133100
461.119600
471.097000
481.142100
491.117200
501.129200
511.087300
521.098700
531.135400
541.071700
551.087300
561.051400
571.068300
581.092500
591.068600
601.072800
611.074000
621.060400
631.065800
641.075900
651.059500
661.039600
671.051400
681.049500
691.023800
701.071900
711.051000
721.034700
731.041600
741.030900
751.010800
761.019800
771.005000
781.043800
791.009200
801.017100
811.044600
821.022600
831.011400
840.996600
851.029900
860.988200
871.005600
880.986600
891.025300
901.012500
910.988100
921.001800
930.987100
941.017600
950.998500
960.966600
970.983700
980.961800
990.969000
1000.989200
1010.956400
1020.976000
1031.000100
1041.001500
1050.995900
1060.989700
1070.965700
1080.968400
1091.019600
1101.000100
1110.978500
1120.978900
1130.952600
1140.975400
1150.989400
1160.968500
1170.960100
1180.979100
1190.955100
1200.934800
1210.943600
1220.976700
1230.998700
1240.930500
1250.953500
1260.978000
1270.967300
1280.929400
1290.963100
1300.961500
1310.978500
1320.937200
1330.953400
1340.962000
1350.950700
1360.925100
1370.958800
1380.926200
1390.930600
1400.968900
1410.970400
1420.927100
1430.911800
1440.953200
1450.907100
1460.935900
1470.970600
1480.920400
1490.930200
1500.926700
1510.913400
1520.926800
1530.967200
1540.939500
1550.910600
1560.926400
1570.935400
1580.967700
1590.899000
1600.916600
1610.961600
1620.898200
1630.944600
1640.935700
1650.922500
1660.897600
1670.968600
1680.927400
1690.910900
1700.904700
1710.899800
1720.896400
1730.862100
1740.909100
1750.903200
1760.958600
1770.902500
1780.894900
1790.937900
1800.900700
1810.922300
1820.939300
1830.932600
1840.913300
1850.941700
1860.886300
1870.918000
1880.884000
1890.947400
1900.894500
1910.929300
1920.877300
1930.894300
1940.867800
1950.913500
1960.908100
1970.931200
1980.911000
1990.941800
2000.913000
2010.921800
2020.921700
2030.914500
2040.910500
2050.906600
2060.915100
2070.881600
2080.884700
2090.902900
2100.882600
2110.891000
2120.914400
2130.930400
2140.891100
2150.859300
2160.891800
2170.873000
2180.925900
2190.905700
2200.921200
2210.890200
2220.915800
2230.887300
2240.898300
2250.865600
2260.873900
2270.904800
2280.917900
2290.923400
2300.939700
2310.913400
2320.873100
2330.896700
2340.892100
2350.902100
2360.927200
2370.912900
2380.872900
2390.904700
2400.879600
2410.879800
2420.908800
2430.909800
2440.838400
2450.889200
2460.912900
2470.879700
2480.910700
2490.845400
2500.882200

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "trainer = transformers.Trainer(\n", " model = model,\n", " train_dataset = data,\n", " args = targs,\n", " data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)\n", ")\n", "trainer.train(resume_from_checkpoint=False)\n", "model.save_pretrained('sqllama-out2')" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/matt/hf/sqllama-V0/.venv/lib/python3.7/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n", " warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "from model\n", "table: 1-12028543-3\n", "columns: Season,Cup FinalDate,WinningTeam,Score,LosingTeam,Location,Cup Final Attendance\n", "Q: Who was the winning team in the 1989 season?\n", "A: SELECT WinningTeam FROM 1-12028543-3 WHERE Season = '1989'\n", "END\n", "END\n", "END\n", "END\n", "\n", "expected answer\n", "SELECT WinningTeam FROM 1-12028543-3 WHERE Season = '1989'\n", "END\n", "\n", "from model\n", "table: 2-18096431-5\n", "columns: Place,Player,Country,Score,To par\n", "Q: What is To par, when Country is \"United States\", and when Player is \"Mark Brooks\"?\n", "A: 18-1\n", "END\n", "\n", "\n", "expected answer\n", "SELECT To par FROM 2-18096431-5 WHERE Country = 'united states' AND Player = 'mark brooks'\n", "END\n", "\n", "from model\n", "table: 2-10701914-2\n", "columns: Home team,Home team score,Away team,Away team score,Venue,Crowd,Date\n", "Q: What home team played at the western oval?\n", "A: Western Bulldogs\n", "END\n", "END\n", "END\n", "END\n", "END\n", "END\n", "END\n", "END\n", "END\n", "END\n", "END\n", "END\n", "END\n", "END\n", "END\n", "END\n", "END\n", "\n", "\n", "expected answer\n", "SELECT Home team FROM 2-10701914-2 WHERE Venue = 'western oval'\n", "END\n", "\n", "from model\n", "table: 1-29598261-1\n", "columns: Name,Number,Position,Height,Weight,Year,Hometown,Last School/College\n", "Q: what is the school for chris mcnamara?\n", "A: SELECT Last School/College FROM 1-29598261-1 WHERE Name = 'chris mcnamara'\n", "END\n", "END\n", "END\n", "END\n", "\n", "\n", "expected answer\n", "SELECT Last School/College FROM 1-29598261-1 WHERE Name = 'Chris McNamara'\n", "END\n", "\n", "from model\n", "table: 1-27722408-11\n", "columns: Game,Date,Team,Score,High points,High rebounds,High assists,Location Attendance,Record\n", "Q: Who had the most assists and how many did they have on April 8?\n", "A: SELECT High assists FROM 1-27722408-11 WHERE Date = 'april 8'\n", "END\n", "\n", "\n", "expected answer\n", "SELECT High assists FROM 1-27722408-11 WHERE Date = 'April 8'\n", "END\n", "\n", "from model\n", "table: 1-21378339-5\n", "columns: Draw,Song,Artist,Panel Points,Televotes,Televote Points,Score,Placing\n", "Q: Name the number of artists for panel points being 5\n", "A: SELECT COUNT Artist FROM 1-21378339-5 WHERE Panel Points = 5\n", "END\n", "END\n", "END\n", "END\n", "END\n", "\n", "expected answer\n", "SELECT COUNT Artist FROM 1-21378339-5 WHERE Panel Points = 5\n", "END\n", "\n", "from model\n", "table: 2-11545282-17\n", "columns: Player,Nationality,Position,Years for Jazz,School/Club Team\n", "Q: What position does Michael Ruffin play?\n", "A: SELECT Position FROM 2-11545282-17 WHERE Player = 'michael ruffin'\n", "END\n", "END\n", "END\n", "END\n", "END\n", "END\n", "END\n", "END\n", "\n", "\n", "expected answer\n", "SELECT Position FROM 2-11545282-17 WHERE Player = 'michael ruffin'\n", "END\n", "\n", "from model\n", "table: 1-17801022-1\n", "columns: Year,Date,Driver,Manufacturer,Laps,Miles (km),Race Time,Average Speed (mph)\n", "Q: What manufacturer won the race on November 2?\n", "A: SELECT Manufacturer FROM 1-17801022-1 WHERE Date = 'november 2'\n", "END\n", "END\n", "END\n", "\n", "expected answer\n", "SELECT Manufacturer FROM 1-17801022-1 WHERE Date = 'November 2'\n", "END\n", "\n", "from model\n", "table: 2-10806592-14\n", "columns: Home team,Home team score,Away team,Away team score,Venue,Crowd,Date\n", "Q: What was the away score when the home team was Melbourne?\n", "A: SELECT Away team score FROM 2-10806592-14 WHERE Home team = 'melbourne'\n", "END\n", "END\n", "END\n", "\n", "\n", "expected answer\n", "SELECT Away team score FROM 2-10806592-14 WHERE Home team = 'melbourne'\n", "END\n", "\n", "from model\n", "table: 2-17978030-6\n", "columns: Date,Time,Score,Set 1,Set 2,Set 3,Total\n", "Q: What is the score when the set 3 is 26–28?\n", "A: SELECT Score FROM 2-17978030-6 WHERE Set 3 = '26–28'\n", "END\n", "END\n", "Q: What\n", "\n", "expected answer\n", "SELECT Score FROM 2-17978030-6 WHERE Set 3 = '26–28'\n", "END\n", "\n" ] } ], "source": [ "def get_query(q):\n", " \n", " toks = tokenizer(q , return_tensors='pt')\n", " ctoks = toks.input_ids.to('cuda')\n", " gen = model.generate(ctoks, max_length=100)\n", " return tokenizer.decode(gen[0])\n", "\n", "M = len(q_red)\n", "\n", "for _ in range(10):\n", " j = random.randint(0,M-1)\n", " qs = q_red[j]\n", " a = a_red[j]\n", "\n", " ma = get_query(qs)\n", "\n", " #print(qs)\n", " print('from model')\n", " print(ma)\n", " print()\n", " print('expected answer')\n", " print(a)\n" ] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "6a381460736e8a0eabfb35eafae436ba15c06439de44e28b965ea473bd8dda90" } } }, "nbformat": 4, "nbformat_minor": 2 }