wikty committed on
Commit
e1a5784
1 Parent(s): ead9ec0

update README

  1. README.md +67 -0
README.md CHANGED
@@ -1,3 +1,70 @@
  ---
  license: apache-2.0
  ---
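+
+ ## Introduction
+
+ This model generates SQL from natural language (NL2SQL/Text2SQL). It is the most basic of the NL2SQL models we have developed in-house; the more advanced versions will be open-sourced over time.
+
+ The model is based on the BART architecture, and we frame NL2SQL as a Seq2Seq problem similar to machine translation. Its main strength is a relatively small parameter count combined with fairly high SQL generation accuracy.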
+
+ ## Usage
+
+ In the NL2SQL task, the input consists of the user's query text plus the database table information. At present, the model input text is concatenated in the following format:
+
+ ```
+ Question: 名人堂一共有多少球员 <sep> Tables: hall_of_fame: player_id, yearid, votedby, ballots, needed, votes, inducted, category, needed_note ; player_award: player_id, award_id, year, league_id, tie, notes <sep>
+ ```
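A minimal sketch of how the input string above could be assembled from a question and a table schema. The helper name `build_nl2sql_input` and the dict-based schema representation are assumptions made for illustration; they are not part of the model's API. The separators (`, ` between columns, ` ; ` between tables) are copied from the format example.

```python
def build_nl2sql_input(question, tables):
    """Build the model input text from a question and a table schema.

    `tables` maps each table name to a list of its column names
    (hypothetical representation, chosen for this sketch).
    """
    # "table: col1, col2" for each table, joined with " ; "
    schema = " ; ".join(
        "{}: {}".format(name, ", ".join(columns))
        for name, columns in tables.items()
    )
    return "Question: {} <sep> Tables: {} <sep>".format(question, schema)


example_tables = {
    "hall_of_fame": ["player_id", "yearid", "votedby"],
    "player_award": ["player_id", "award_id", "year"],
}
example_input = build_nl2sql_input("名人堂一共有多少球员", example_tables)
print(example_input)
```

Dicts preserve insertion order in Python 3.7+, so the tables appear in the order they were added.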
+
+ See the following example for complete usage:
+
+ ```python
+ import torch
+ from transformers import AutoModelForSeq2SeqLM, MBartForConditionalGeneration, AutoTokenizer
+
+ device = 'cuda'
+ model_path = 'DMetaSoul/nl2sql-chinese-basic'
+ sampling = False  # True: top-k/top-p sampling; False: beam search
+ tokenizer = AutoTokenizer.from_pretrained(model_path, src_lang='zh_CN')
+ #model = MBartForConditionalGeneration.from_pretrained(model_path)
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
+ model = model.half()  # run in half precision (FP16)
+ model.to(device)
+
+
+ input_texts = [
+     "Question: 所有章节的名称和描述是什么? <sep> Tables: sections: section id , course id , section name , section description , other details <sep>",
+     "Question: 名人堂一共有多少球员 <sep> Tables: hall_of_fame: player_id, yearid, votedby, ballots, needed, votes, inducted, category, needed_note ; player_award: player_id, award_id, year, league_id, tie, notes ; player_award_vote: award_id, year, league_id, player_id, points_won, points_max, votes_first ; salary: year, team_id, league_id, player_id, salary ; player: player_id, birth_year, birth_month, birth_day, birth_country, birth_state, birth_city, death_year, death_month, death_day, death_country, death_state, death_city, name_first, name_last, name_given, weight <sep>"
+ ]
+ inputs = tokenizer(input_texts, max_length=512, return_tensors="pt",
+                    padding=True, truncation=True)
+ # Move tensors to the device and drop token_type_ids, which the model does not use
+ inputs = {k: v.to(device) for k, v in inputs.items() if k not in ["token_type_ids"]}
+
+ with torch.no_grad():
+     if sampling:
+         outputs = model.generate(**inputs, do_sample=True, top_k=50, top_p=0.95,
+                                  temperature=1.0, num_return_sequences=1,
+                                  max_length=512, return_dict_in_generate=True, output_scores=True)
+     else:
+         outputs = model.generate(**inputs, num_beams=4, num_return_sequences=1,
+                                  max_length=512, return_dict_in_generate=True, output_scores=True)
+
+ # Decode the generated token ids back into SQL text
+ output_ids = outputs.sequences
+ results = tokenizer.batch_decode(output_ids, skip_special_tokens=True,
+                                  clean_up_tokenization_spaces=True)
+
+ for question, sql in zip(input_texts, results):
+     print(question)
+     print('SQL: {}'.format(sql))
+     print()
+ ```
+
+ The output is as follows:
+
+ ```
+ Question: 所有章节的名称和描述是什么? <sep> Tables: sections: section id , course id , section name , section description , other details <sep>
+ SQL: SELECT section name, section description FROM sections
+
+ Question: 名人堂一共有多少球员 <sep> Tables: hall_of_fame: player_id, yearid, votedby, ballots, needed, votes, inducted, category, needed_note ; player_award: player_id, award_id, year, league_id, tie, notes ; player_award_vote: award_id, year, league_id, player_id, points_won, points_max, votes_first ; salary: year, team_id, league_id, player_id, salary ; player: player_id, birth_year, birth_month, birth_day, birth_country, birth_state, birth_city, death_year, death_month, death_day, death_country, death_state, death_city, name_first, name_last, name_given, weight <sep>
+ SQL: SELECT count(*) FROM hall_of_fame
+ ```
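Generated SQL like the output shown above can be sanity-checked before it is run against a real database. The sketch below is one possible approach and not part of the model or its README: it builds an empty in-memory SQLite schema from the table information and asks SQLite to compile the statement with `EXPLAIN`. The helper name `sql_compiles` is hypothetical, and SQLite is assumed here only for convenience, since the README does not state which SQL dialect the model targets.

```python
import sqlite3


def sql_compiles(sql, tables):
    """Return True if `sql` compiles against an empty SQLite schema.

    `tables` maps each table name to a list of column names
    (hypothetical representation, chosen for this sketch).
    """
    conn = sqlite3.connect(":memory:")
    try:
        for name, columns in tables.items():
            # Quote identifiers so that names containing spaces are accepted
            column_defs = ", ".join(
                '"{}"'.format(c.replace('"', '""')) for c in columns
            )
            conn.execute(
                'CREATE TABLE "{}" ({})'.format(name.replace('"', '""'), column_defs)
            )
        # EXPLAIN compiles the statement without reading any table data
        conn.execute("EXPLAIN " + sql)
        return True
    except (sqlite3.Error, sqlite3.Warning):
        return False
    finally:
        conn.close()


schema = {"hall_of_fame": ["player_id", "yearid", "votedby"]}
valid = sql_compiles("SELECT count(*) FROM hall_of_fame", schema)
invalid = sql_compiles("SELECT count(*) FROM missing_table", schema)
print(valid, invalid)
```

A check like this only catches statements that fail to compile, such as references to unknown tables or columns; it says nothing about whether the query answers the question.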