text2sql-finetune / README.md
DebeshSahoo's picture
update
d773610
---
datasets:
- wikisql
metrics:
- accuracy
pipeline_tag: text2text-generation
tags:
- code
---
Base model: `t5-small`
# Training Results
Trainer progress at the end of training: `[17610/17610 1:32:31, Epoch 9/10]`

| Step  | Training Loss | Validation Loss |
|------:|--------------:|----------------:|
| 1000  | 2.682400      | 0.829368        |
| 2000  | 0.914000      | 0.568155        |
| 3000  | 0.707700      | 0.465733        |
| 4000  | 0.613500      | 0.408758        |
| 5000  | 0.557300      | 0.374811        |
| 6000  | 0.515800      | 0.350752        |
| 7000  | 0.487000      | 0.331517        |
| 8000  | 0.466100      | 0.319071        |
| 9000  | 0.449400      | 0.309488        |
| 10000 | 0.438800      | 0.301829        |
| 11000 | 0.430000      | 0.296482        |
| 12000 | 0.420200      | 0.292672        |
| 13000 | 0.418200      | 0.290445        |
| 14000 | 0.413400      | 0.288662        |
| 15000 | 0.410100      | 0.287757        |
| 16000 | 0.412600      | 0.287280        |
| 17000 | 0.410000      | 0.287134        |
Example — input question and table header, followed by the model's output:

```
question: what is id with name jui and age equal 25
table: ['id', 'name', 'age']
SELECT ID FROM table WHEREname = jui AND age equal 25
```
# Copy the code below into your notebook to use the model
from typing import List

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Hugging Face Hub id of the fine-tuned checkpoint (base model: t5-small).
MODEL_ID = "DebeshSahoo/text2sql-finetune"

# Load the tokenizer and the model from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)

# Prefixes used to build the model prompt "question: ... table: ...".
table_prefix = "table:"
question_prefix = "question:"
def prepare_input(question: str, table: List[str]):
    """Tokenize a question and its table header into model input ids.

    The prompt has the form ``"question: <question> table: <col1>,<col2>,..."``,
    built from the module-level ``question_prefix`` and ``table_prefix``.

    Args:
        question: Natural-language question to translate into SQL.
        table: Column names of the table the question refers to.

    Returns:
        The ``input_ids`` produced by the module-level ``tokenizer`` with
        ``return_tensors="pt"`` (a PyTorch tensor with a batch dimension of 1).
    """
    # Echo the inputs so the demo output shows what is being queried.
    print("question:", question)
    print("table:", table)
    join_table = ",".join(table)
    inputs = f"{question_prefix} {question} {table_prefix} {join_table}"
    # Request truncation explicitly: passing ``max_length`` alone makes the
    # tokenizer warn and fall back to 'longest_first' truncation anyway.
    input_ids = tokenizer(
        inputs, max_length=700, truncation=True, return_tensors="pt"
    ).input_ids
    return input_ids
def inference(question: str, table: List[str]) -> str:
    """Generate a SQL query for ``question`` over a table with columns ``table``.

    Args:
        question: Natural-language question to translate into SQL.
        table: Column names of the table the question refers to.

    Returns:
        The decoded SQL string of the best beam, with special tokens removed.
    """
    input_data = prepare_input(question=question, table=table)
    # Put the ids on the same device as the model (e.g. GPU) before generating.
    input_data = input_data.to(model.device)
    # Beam search with 10 beams. ``top_k`` is not passed: it only affects
    # sampling-based generation (``do_sample=True``), so with plain beam
    # search it was ignored and only produced a warning.
    outputs = model.generate(inputs=input_data, num_beams=10, max_length=512)
    # outputs[0] is the highest-scoring returned sequence.
    result = tokenizer.decode(token_ids=outputs[0], skip_special_tokens=True)
    return result
# Compare the model against one WikiSQL test example. This requires the
# dataset to be loaded beforehand, e.g. ``from datasets import load_dataset``
# followed by ``dataset = load_dataset("wikisql")``; without it the
# comparison is skipped instead of failing with a NameError.
test_id = 1000
if "dataset" in globals():
    example = dataset["test"][test_id]
    print(
        "model result:",
        inference(example["question"], example["table"]["header"]),
    )
    print("real result:", example["sql"]["human_readable"])

# Query the model with a custom question and table header. Printing keeps the
# result visible when this runs as a script, not only as a notebook cell.
print(
    "model result:",
    inference("what is id with name jui and age equal 25", ["id", "name", "age"]),
)