|
--- |
|
datasets: |
|
- wikisql |
|
metrics: |
|
- accuracy |
|
pipeline_tag: text2text-generation |
|
tags: |
|
- code |
|
--- |
|
|
|
Base Model: t5-small |
|
|
|
# Training Results |
|
|
|
[17610/17610 1:32:31, Epoch 9/10] |
|
Step Training Loss Validation Loss |
|
1000 2.682400 0.829368 |
|
2000 0.914000 0.568155 |
|
3000 0.707700 0.465733 |
|
4000 0.613500 0.408758 |
|
5000 0.557300 0.374811 |
|
6000 0.515800 0.350752 |
|
7000 0.487000 0.331517 |
|
8000 0.466100 0.319071 |
|
9000 0.449400 0.309488 |
|
10000 0.438800 0.301829 |
|
11000 0.430000 0.296482 |
|
12000 0.420200 0.292672 |
|
13000 0.418200 0.290445 |
|
14000 0.413400 0.288662 |
|
15000 0.410100 0.287757 |
|
16000 0.412600 0.287280 |
|
17000 0.410000 0.287134 |
|
|
|
question: what is id with name jui and age equal 25 |
|
table: ['id', 'name', 'age'] |
|
SELECT ID FROM table WHEREname = jui AND age equal 25 |
|
|
|
# Copy the code below into your notebook to use the model |
|
|
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
|
# Hub repository that provides both the tokenizer files and the model weights.
checkpoint = "DebeshSahoo/text2sql-finetune"

# The tokenizer and the seq2seq model are loaded from the same repository.
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
|
|
|
# Rest of the code for preparing input, generating predictions, and decoding the output... |
|
|
|
|
|
from typing import List |
|
|
|
# Markers that label the two segments of the prompt built by prepare_input:
# the natural-language question and the comma-separated column names.
question_prefix = "question:"
table_prefix = "table:"
|
|
|
def prepare_input(question: str, table: List[str]):
    """Build the tokenized model input for a question about a table.

    The prompt has the form
    ``question: <question> table: <col1>,<col2>,...`` and is tokenized with
    the module-level ``tokenizer``. The question and the column names are
    also printed for inspection.

    Args:
        question: Natural-language question to translate.
        table: Column names of the table the question refers to.

    Returns:
        The ``input_ids`` produced by the tokenizer with
        ``return_tensors="pt"``.
    """
    print("question:", question)
    print("table:", table)

    columns = ",".join(table)
    prompt = f"{question_prefix} {question} {table_prefix} {columns}"

    # Request truncation explicitly: passing ``max_length`` alone makes the
    # tokenizer emit a warning and fall back to an implicit strategy.
    encoded = tokenizer(
        prompt,
        max_length=700,
        truncation=True,
        return_tensors="pt",
    )
    return encoded.input_ids
|
|
|
def inference(question: str, table: List[str]) -> str:
    """Generate a SQL query for ``question`` over a table with the given columns.

    The prompt is built and tokenized by ``prepare_input``, moved to the
    device the module-level ``model`` lives on, decoded with beam search,
    and converted back to text with the module-level ``tokenizer``.

    Args:
        question: Natural-language question to translate.
        table: Column names of the table the question refers to.

    Returns:
        The decoded text of the first returned sequence, with special
        tokens removed.
    """
    input_ids = prepare_input(question=question, table=table).to(model.device)

    # ``top_k`` only influences sampling-based decoding; this call uses beam
    # search without requesting sampling, so the argument is not passed.
    outputs = model.generate(inputs=input_ids, num_beams=10, max_length=512)

    return tokenizer.decode(token_ids=outputs[0], skip_special_tokens=True)
|
|
|
# NOTE(review): `dataset` is not created anywhere in this snippet; it is
# presumably the WikiSQL dataset named in the card metadata and must be
# loaded before running the lines below -- confirm.
test_id = 1000
sample = dataset["test"][test_id]

# Compare the model's prediction with the reference query for one test example.
print("model result:", inference(sample["question"], sample["table"]["header"]))
print("real result:", sample["sql"]["human_readable"])

# Ad-hoc query against a hand-written list of column names.
inference("what is id with name jui and age equal 25", ["id", "name", "age"])