wikty
commited on
Commit
•
e1a5784
1
Parent(s):
ead9ec0
update README
Browse files
README.md
CHANGED
@@ -1,3 +1,70 @@
|
|
1 |
---
|
2 |
license: apache-2.0
|
3 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
license: apache-2.0
|
3 |
---
|
4 |
+
|
5 |
+
## 简介
|
6 |
+
|
7 |
+
这是一款根据自然语言生成 SQL 的模型(NL2SQL/Text2SQL),是我们自研众多 NL2SQL 模型中最为基础的一版,其它高级版模型后续将陆续进行开源。
|
8 |
+
|
9 |
+
该模型基于 BART 架构,我们将 NL2SQL 问题建模为类似机器翻译的 Seq2Seq 形式,该模型的优势特点:参数规模较小、但 SQL 生成准确性也较高。
|
10 |
+
|
11 |
+
## 用法
|
12 |
+
|
13 |
+
NL2SQL 任务中输入参数含有用户查询文本+数据库表信息,目前按照以下格式拼接模型的输入文本:
|
14 |
+
|
15 |
+
```
|
16 |
+
Question: 名人堂一共有多少球员 <sep> Tables: hall_of_fame: player_id, yearid, votedby, ballots, needed, votes, inducted, category, needed_note ; player_award: player_id, award_id, year, league_id, tie, notes <sep>
|
17 |
+
```
|
18 |
+
|
19 |
+
具体使用方法参考以下示例:
|
20 |
+
|
21 |
+
```python
|
22 |
+
import torch
|
23 |
+
from transformers import AutoModelForSeq2SeqLM, MBartForConditionalGeneration, AutoTokenizer
|
24 |
+
|
25 |
+
device = 'cuda'
|
26 |
+
model_path = 'DMetaSoul/nl2sql-chinese-basic'
|
27 |
+
sampling = False
|
28 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, src_lang='zh_CN')
|
29 |
+
#model = MBartForConditionalGeneration.from_pretrained(model_path)
|
30 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
|
31 |
+
model = model.half()
|
32 |
+
model.to(device)
|
33 |
+
|
34 |
+
|
35 |
+
input_texts = [
|
36 |
+
"Question: 所有章节的名称和描述是什么? <sep> Tables: sections: section id , course id , section name , section description , other details <sep>",
|
37 |
+
"Question: 名人堂一共有多少球员 <sep> Tables: hall_of_fame: player_id, yearid, votedby, ballots, needed, votes, inducted, category, needed_note ; player_award: player_id, award_id, year, league_id, tie, notes ; player_award_vote: award_id, year, league_id, player_id, points_won, points_max, votes_first ; salary: year, team_id, league_id, player_id, salary ; player: player_id, birth_year, birth_month, birth_day, birth_country, birth_state, birth_city, death_year, death_month, death_day, death_country, death_state, death_city, name_first, name_last, name_given, weight <sep>"
|
38 |
+
]
|
39 |
+
inputs = tokenizer(input_texts, max_length=512, return_tensors="pt",
|
40 |
+
padding=True, truncation=True)
|
41 |
+
inputs = {k:v.to(device) for k,v in inputs.items() if k not in ["token_type_ids"]}
|
42 |
+
|
43 |
+
with torch.no_grad():
|
44 |
+
if sampling:
|
45 |
+
outputs = model.generate(**inputs, do_sample=True, top_k=50, top_p=0.95,
|
46 |
+
temperature=1.0, num_return_sequences=1,
|
47 |
+
max_length=512, return_dict_in_generate=True, output_scores=True)
|
48 |
+
else:
|
49 |
+
outputs = model.generate(**inputs, num_beams=4, num_return_sequences=1,
|
50 |
+
max_length=512, return_dict_in_generate=True, output_scores=True)
|
51 |
+
|
52 |
+
output_ids = outputs.sequences
|
53 |
+
results = tokenizer.batch_decode(output_ids, skip_special_tokens=True,
|
54 |
+
clean_up_tokenization_spaces=True)
|
55 |
+
|
56 |
+
for question, sql in zip(input_texts, results):
|
57 |
+
print(question)
|
58 |
+
print('SQL: {}'.format(sql))
|
59 |
+
print()
|
60 |
+
```
|
61 |
+
|
62 |
+
输入结果如下:
|
63 |
+
|
64 |
+
```
|
65 |
+
Question: 所有章节的名称和描述是什么? <sep> Tables: sections: section id , course id , section name , section description , other details <sep>
|
66 |
+
SQL: SELECT section name, section description FROM sections
|
67 |
+
|
68 |
+
Question: 名人堂一共有多少球员 <sep> Tables: hall_of_fame: player_id, yearid, votedby, ballots, needed, votes, inducted, category, needed_note ; player_award: player_id, award_id, year, league_id, tie, notes ; player_award_vote: award_id, year, league_id, player_id, points_won, points_max, votes_first ; salary: year, team_id, league_id, player_id, salary ; player: player_id, birth_year, birth_month, birth_day, birth_country, birth_state, birth_city, death_year, death_month, death_day, death_country, death_state, death_city, name_first, name_last, name_given, weight <sep>
|
69 |
+
SQL: SELECT count(*) FROM hall_of_fame
|
70 |
+
```
|