---
license: apache-2.0
language:
- en
pipeline_tag: text2text-generation
library_name: transformers
tags:
- text-generation-inference
widget:
- text: >
    Given a SQL table named 'price_data' with the following columns:
    
    Transaction_ID, Platform, Product_ID, User_ID, Transaction_Amount
    
    Construct a SQL query to answer the following question:
    
    Q: How many rows are there

  example_title: "How many rows are there?"
---

A text-to-SQL T5 model, fine-tuned from Flan-T5-base. Code: [Link](https://github.com/kevinng77/chat-table-t5/blob/master/prompt.py)

Further fine-tuning can significantly improve the performance of Flan-T5 on text-to-SQL tasks.
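For reference, a minimal fine-tuning sketch with the `transformers` `Seq2SeqTrainer` is shown below. The toy dataset, column names, and hyperparameters are illustrative assumptions, not the recipe used for this checkpoint; see the linked repository for the actual training code.

```python
from datasets import Dataset
from transformers import (
    T5Tokenizer, T5ForConditionalGeneration,
    DataCollatorForSeq2Seq, Seq2SeqTrainer, Seq2SeqTrainingArguments,
)

model_id = "google/flan-t5-base"
tokenizer = T5Tokenizer.from_pretrained(model_id)
model = T5ForConditionalGeneration.from_pretrained(model_id)

# Hypothetical toy data: (prompt, sql) pairs. Replace with a real text-to-SQL corpus.
raw = Dataset.from_dict({
    "prompt": ["Given a SQL table named 'my_data' with the following columns: ... Q: How many rows are there?"],
    "sql": ["SELECT COUNT(*) FROM my_data"],
})

def preprocess(batch):
    # Tokenize the prompt as the encoder input and the SQL as the target.
    inputs = tokenizer(batch["prompt"], max_length=512, truncation=True)
    labels = tokenizer(text_target=batch["sql"], max_length=128, truncation=True)
    inputs["labels"] = labels["input_ids"]
    return inputs

tokenized = raw.map(preprocess, batched=True, remove_columns=raw.column_names)

# Illustrative hyperparameters only.
args = Seq2SeqTrainingArguments(
    output_dir="flan-t5-text2sql",
    per_device_train_batch_size=8,
    learning_rate=3e-4,
    num_train_epochs=3,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=tokenized,
    data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
)
trainer.train()
```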


## Inference Example:


```python
from transformers import T5Tokenizer, T5ForConditionalGeneration, pipeline

table_columns = "Transaction_ID, Platform, Product_ID, User_ID, Transaction_Amount, Region, Transaction_Time, Transaction_Unit, User_Comments"

table_name = "my_data"

PROMPT_INPUT = f"""
Given a SQL table named '{table_name}' with the following columns:
{table_columns}

Construct a SQL query to answer the following question:
Q: {{question}}.
"""

model_id = "kevinng77/chat-table-flan-t5"
tokenizer = T5Tokenizer.from_pretrained(model_id)
model = T5ForConditionalGeneration.from_pretrained(model_id)

input_text = PROMPT_INPUT.format_map({"question": "How many rows are there in the table?"})

pipe = pipeline(
    "text2text-generation",
    model=model, tokenizer=tokenizer, max_length=512
)
```
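Calling the pipeline on the formatted prompt returns the generated SQL. A short usage sketch continuing the snippet above (the exact query text is whatever the model generates):

```python
# Run the formatted prompt through the pipeline and print the generated SQL.
result = pipe(input_text)
print(result[0]["generated_text"])
```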