{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch\n", "import torch.nn as nn\n", "torch.cuda.is_available()" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Check that we can load the tokenizer and the model. The first time this runs it will take a while. The files go under ~/.cache/huggingface" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/matt/hf/sqllama-V0/.venv/lib/python3.7/site-packages/bitsandbytes/cuda_setup/main.py:136: UserWarning: /opt/conda did not contain libcudart.so as expected! Searching further paths...\n", " warn(msg)\n", "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. \n", "The tokenizer class you load from this checkpoint is 'LLaMATokenizer'. \n", "The class this function is called from is 'LlamaTokenizer'.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "===================================BUG REPORT===================================\n", "Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n", "================================================================================\n", "CUDA SETUP: CUDA runtime path found: /usr/local/cuda/lib64/libcudart.so\n", "CUDA SETUP: Highest compute capability among GPUs detected: 7.5\n", "CUDA SETUP: Detected CUDA version 113\n", "CUDA SETUP: Loading binary /home/matt/hf/sqllama-V0/.venv/lib/python3.7/site-packages/bitsandbytes/libbitsandbytes_cuda113.so...\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "5616e727844a4f0b9efaff97aa2f9d75", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/33 [00:00 24\n", "\n", "Below is a question that describes a data request, paired with an input that describes a SQL table. Write a SQL query that retrieves the data.\n", "### Question: What is the number of capacity at somerset park?\n", "### Input: Table 1-11206787-5 has columns Team (text),Stadium (text),Capacity (real),Highest (real),Lowest (real),Average (real). \n", "### Answer: SELECT COUNT Capacity FROM 1-11206787-5 WHERE Stadium = 'Somerset Park'\n", "\n", "Below is a question that describes a data request, paired with an input that describes a SQL table. Write a SQL query that retrieves the data.\n", "### Question: What is the number & name with an Undergoing overhaul, restoration or repairs date?\n", "### Input: Table 2-11913905-6 has columns Number & Name (text),Description (text),Livery (text),Owner(s) (text),Date (text). \n", "### Answer: SELECT Number & Name FROM 2-11913905-6 WHERE Date = 'undergoing overhaul, restoration or repairs'\n", "\n", "Below is a question that describes a data request, paired with an input that describes a SQL table. Write a SQL query that retrieves the data.\n", "### Question: What year did Orlando have a School/Club team in Clemson?\n", "### Input: Table 2-15621965-7 has columns Player (text),Nationality (text),Position (text),Years in Orlando (text),School/Club Team (text). 
\n", "### Answer: SELECT Years in Orlando FROM 2-15621965-7 WHERE School/Club Team = 'clemson'\n", "\n", "Below is a question that describes a data request, paired with an input that describes a SQL table. Write a SQL query that retrieves the data.\n", "### Question: How many Deaths have a Fate of damaged, and a Tonnage (GRT) smaller than 4,917?\n", "### Input: Table 2-18914307-1 has columns Date (text),Ship Name (text),Flag (text),Tonnage ( GRT ) (real),Fate (text),Deaths (real). \n", "### Answer: SELECT COUNT Deaths FROM 2-18914307-1 WHERE Fate = 'damaged' AND Tonnage ( GRT ) < 4,917\n" ] } ], "source": [ "import random\n", "import json\n", "\n", "# defined by WikiSQL\n", "\n", "agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']\n", "cond_ops = ['=', '>', '<', 'OP']\n", "syms = ['SELECT', 'WHERE', 'AND', 'COL', 'TABLE', 'CAPTION', 'PAGE', 'SECTION', 'OP', 'COND', 'QUESTION', 'AGG', 'AGGOPS', 'CONDOPS']\n", "\n", "def fix_repr(d,cols,types,tid):\n", " sel_index=d['sel'] \n", " agg_index=d['agg']\n", " conditions=d['conds']\n", " col = cols[sel_index]\n", " rep = 'SELECT {agg} {sel} FROM {tid}'.format(\n", " agg=agg_ops[agg_index],\n", " sel=col,\n", " tid=tid\n", " )\n", " if conditions:\n", " cs = []\n", " for i, o, v in conditions:\n", " #print(i,cols)\n", " nm = cols[i]\n", " op = cond_ops[o]\n", " \n", " if types[i] in ['text']:\n", " val = f\"\\'{v}\\'\"\n", " else:\n", " val = v\n", " cs.append(f'{nm} {op} {val}')\n", " #print(cs)\n", "\n", " rep += ' WHERE ' + ' AND '.join(cs)\n", " \n", " return rep\n", "\n", "tbl_cols = {}\n", "tbl_types = {}\n", "tbl_str = {}\n", "\n", "prefix = 'Below is a question that describes a data request, paired with an input that describes a SQL table. Write a SQL query that retrieves the data.'\n", "\n", "def tbl_def_to_string(id, header, types):\n", " ht = [f'{header[i]} ({types[i]})' for i in range(len(header))]\n", " s = f'\\n### Input: Table {id} has columns ' + ','.join(ht) + '. '\n", " return s\n", "\n", "with open('data/train.tables.jsonl') as f:\n", " for line in f:\n", " js = json.loads(line)\n", " id = js['id']\n", " hdr = js['header']\n", " ts = js['types']\n", " tbl_str[id] = tbl_def_to_string(id,hdr,ts)\n", " tbl_cols[id] = hdr\n", " tbl_types[id] = ts\n", "\n", "\n", "nl_q = []\n", "sql_a = []\n", "\n", "with open('data/train.jsonl') as f:\n", " for line in f:\n", " js = json.loads(line)\n", " id = js['table_id']\n", " s = tbl_str[id]\n", " qst = js['question']\n", " nl = prefix + \"\\n### Question: \" + qst + s\n", " nl_q.append(nl)\n", "\n", " sql = js['sql']\n", " a = fix_repr(sql,tbl_cols[id],tbl_types[id],id)\n", " a = '\\n### Answer: ' + a\n", " sql_a.append(a)\n", "\n", "\n", "M = len(nl_q)\n", "\n", "data_txt = [nl_q[i] + sql_a[i] for i in range(len(nl_q))]\n", "\n", "for i in range(5):\n", " j = random.randint(0,M-1)\n", " print()\n", " print(data_txt[j]) \n", " \n", " " ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Set up the details for the model." 
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "431b36f60a3940cf8646e1bea4324745", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/56355 [00:00\n", " \n", " \n", " [440/440 11:19:07, Epoch 0/1]\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
Step | Training Loss (selected steps; 440 total)\n",
"     1 | 2.517200\n",
"    50 | 2.274500\n",
"   100 | 1.963900\n",
"   150 | 1.669900\n",
"   200 | 1.412800\n",
"   250 | 1.207700\n",
"   300 | 1.090300\n",
"   350 | 1.012700\n",
"   400 | 0.979700\n",
"   440 | 0.965100\n",
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "\n", "\n", "trainer = transformers.Trainer(\n", " model = model,\n", " train_dataset = data,\n", " args = targs,\n", " data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)\n", ")\n", "trainer.train(resume_from_checkpoint=False)\n", "model.save_pretrained('sqllama-out')" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/matt/hf/sqllama-V0/.venv/lib/python3.7/site-packages/transformers/generation/utils.py:1220: UserWarning: You have modified the pretrained model configuration to control generation. This is a deprecated strategy to control generation and will be removed soon, in a future version. Please use a generation configuration file (see https://huggingface.co/docs/transformers/main_classes/text_generation)\n", " \"You have modified the pretrained model configuration to control generation. This is a\"\n", "/home/matt/hf/sqllama-V0/.venv/lib/python3.7/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n", " warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Below is a question that describes a data request, paired with an input that describes a SQL table. Write a SQL query that retrieves the data.\n", "### Question: What county has a CERCLIS ID of scd037405362?\n", "### Input: Table 2-11960788-1 has columns CERCLIS ID (text),Name (text),County (text),Proposed (text),Listed (text). \n", "### Answer: \n", "Below is a question that describes a data request, paired with an input that describes a SQL table. Write a SQL query that retrieves the data.\n", "### Question: What county has a CERCLIS ID of scd037405362?\n", "### Input: Table 2-11960788-1 has columns CERCLIS ID (text),Name (text),County (text),Proposed (text),Listed (text). \n", "### Answer: SELECT County FROM 2-11960788-1 WHERE CERCLIS ID = 'scd037405362' \n", "### Question: What county has a CERCLIS ID of scd037405362?\n", "### Input: Table 2-11960788-1 has columns CERCLIS ID (text),Name (text),County (text),Proposed (text),Listed (text). \n", "### Answer: SELECT County FROM 2-11960788-1 WHERE CERCLIS ID\n", "\n", "### Answer: SELECT County FROM 2-11960788-1 WHERE CERCLIS ID = 'scd037405362'\n" ] } ], "source": [ "def get_query(q):\n", " \n", " toks = tokenizer(q , return_tensors='pt')\n", " ctoks = toks.input_ids.to('cuda')\n", " gen = model.generate(ctoks, max_length=256)\n", " return tokenizer.decode(gen[0])\n", "\n", "M = len(nl_q)\n", "j = random.randint(0,M-1)\n", "qs = nl_q[j] + '\\n### Answer: '\n", "a = sql_a[j]\n", "\n", "ma = get_query(qs)\n", "\n", "#print(qs)\n", "print('from model')\n", "print(ma)\n", "print('expected answer')\n", "print(a)\n" ] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "6a381460736e8a0eabfb35eafae436ba15c06439de44e28b965ea473bd8dda90" } } }, "nbformat": 4, "nbformat_minor": 2 }