Johannes committed on
Commit
7b065dc
1 Parent(s): 2be3343
Files changed (2) hide show
  1. app.py +23 -16
  2. example_strings.py +12 -24
app.py CHANGED
@@ -3,14 +3,13 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
3
  from example_strings import example1, example2, example3
4
 
5
 
6
- # tokenizer6B = AutoTokenizer.from_pretrained(f"NumbersStation/nsql-6B")
7
- # model6B = AutoModelForCausalLM.from_pretrained(f"NumbersStation/nsql-6B")
8
-
9
- # tokenizer2B = AutoTokenizer.from_pretrained(f"NumbersStation/nsql-2B")
10
- # model2B = AutoModelForCausalLM.from_pretrained(f"NumbersStation/nsql-2B")
11
-
12
- # tokenizer350M = AutoTokenizer.from_pretrained(f"NumbersStation/nsql-2B")
13
- # model350M = AutoModelForCausalLM.from_pretrained(f"NumbersStation/nsql-2B")
14
 
15
 
16
  def load_model(model_name: str):
@@ -19,8 +18,15 @@ def load_model(model_name: str):
19
  return tokenizer, model
20
 
21
 
22
- def infer(input_text, model_choice):
 
 
 
 
23
  tokenizer, model = load_model(model_choice)
 
 
 
24
  input_ids = tokenizer(input_text, return_tensors="pt").input_ids
25
  generated_ids = model.generate(input_ids, max_length=500)
26
  return (tokenizer.decode(generated_ids[0], skip_special_tokens=True))
@@ -31,7 +37,7 @@ description = """The NSQL model family was published by [Numbers Station](https:
31
  - [nsql-2B](https://huggingface.co/NumbersStation/nsql-2B)
32
  - [nsql-350M](https://huggingface.co/NumbersStation/nsql-350M)
33
 
34
- This demo lets you choose which one you want to use and provides the three examples you can also find in their model cards.
35
 
36
  In general you should first provide the table schemas of the tables you have questions about and then prompt it with a natural language question.
37
  The model will then generate a SQL query that you can run against your database.
@@ -41,10 +47,11 @@ iface = gr.Interface(
41
  title="Text to SQL with NSQL",
42
  description=description,
43
  fn=infer,
44
- inputs=["text",
45
- gr.Dropdown(["nsql-6B", "nsql-2B", "nsql-350M"], value="nsql-6B")],
 
 
 
46
  outputs="text",
47
- examples=[[example1, "nsql-350M"],
48
- [example2, "nsql-2B"],
49
- [example3, "nsql-350M"]])
50
- iface.launch()
 
3
  from example_strings import example1, example2, example3
4
 
5
 
6
+ template_str = """{table_schemas}
7
+ \n \n
8
+ {task_spec}
9
+ \n \n
10
+ {prompt}
11
+ \n \n
12
+ SELECT"""
 
13
 
14
 
15
  def load_model(model_name: str):
 
18
  return tokenizer, model
19
 
20
 
21
+ def build_complete_prompt(table_schemas: str, task_spec: str, prompt: str) -> str:
22
+ return template_str.format(table_schemas=table_schemas, task_spec=task_spec, prompt=prompt)
23
+
24
+
25
+ def infer(table_schemas: str, task_spec: str, prompt: str, model_choice: str = "nsql-350M"):
26
  tokenizer, model = load_model(model_choice)
27
+
28
+ input_text = build_complete_prompt(table_schemas, task_spec, prompt)
29
+
30
  input_ids = tokenizer(input_text, return_tensors="pt").input_ids
31
  generated_ids = model.generate(input_ids, max_length=500)
32
  return (tokenizer.decode(generated_ids[0], skip_special_tokens=True))
 
37
  - [nsql-2B](https://huggingface.co/NumbersStation/nsql-2B)
38
  - [nsql-350M](https://huggingface.co/NumbersStation/nsql-350M)
39
 
40
+ For now you can only use the 350M version of the model here, as the file size of the other models exceeds the max memory available in spaces.
41
 
42
  In general you should first provide the table schemas of the tables you have questions about and then prompt it with a natural language question.
43
  The model will then generate a SQL query that you can run against your database.
 
47
  title="Text to SQL with NSQL",
48
  description=description,
49
  fn=infer,
50
+ inputs=[gr.Text(label="Table schemas", placeholder="Insert your table schemas here"),
51
+ gr.Text(label="Specify Task", value="Using valid SQLite, answer the following questions for the tables provided above."),
52
+ gr.Text(label="Prompt", placeholder="Put your natural language prompt here"),
53
+ # gr.Dropdown(["nsql-6B", "nsql-2B", "nsql-350M"], value="nsql-6B")
54
+ ],
55
  outputs="text",
56
+ examples=[example1, example2, example3])
57
+ iface.launch()
 
 
example_strings.py CHANGED
@@ -1,4 +1,4 @@
1
- example1 = """CREATE TABLE stadium (
2
  stadium_id number,
3
  location text,
4
  name text,
@@ -29,28 +29,20 @@ CREATE TABLE concert (
29
  CREATE TABLE singer_in_concert (
30
  concert_id number,
31
  singer_id text
32
- )
33
-
34
- -- Using valid SQLite, answer the following questions for the tables provided above.
35
-
36
- -- What is the maximum, the average, and the minimum capacity of stadiums ?
37
 
38
- SELECT"""
39
-
40
- example2 = """CREATE TABLE stadium (
41
  stadium_id number,
42
  location text,
43
  name text,
44
  capacity number,
45
- )
46
-
47
- -- Using valid SQLite, answer the following questions for the tables provided above.
48
 
49
- -- how many stadiums in total?
50
-
51
- SELECT"""
52
-
53
- example3 = """CREATE TABLE work_orders (
54
  ID NUMBER,
55
  CREATED_AT TEXT,
56
  COST FLOAT,
@@ -59,10 +51,6 @@ example3 = """CREATE TABLE work_orders (
59
  IS_OPEN BOOLEAN,
60
  IS_OVERDUE BOOLEAN,
61
  COUNTRY_NAME TEXT,
62
- )
63
-
64
- -- Using valid SQLite, answer the following questions for the tables provided above.
65
-
66
- -- how many work orders are open?
67
-
68
- SELECT"""
 
1
+ example1 = ["""CREATE TABLE stadium (
2
  stadium_id number,
3
  location text,
4
  name text,
 
29
  CREATE TABLE singer_in_concert (
30
  concert_id number,
31
  singer_id text
32
+ )""",
33
+ "Using valid SQLite, answer the following questions for the tables provided above.",
34
+ "What is the maximum, the average, and the minimum capacity of stadiums ?"]
 
 
35
 
36
+ example2 = ["""CREATE TABLE stadium (
 
 
37
  stadium_id number,
38
  location text,
39
  name text,
40
  capacity number,
41
+ )""",
42
+ "Using valid SQLite, answer the following questions for the tables provided above.",
43
+ "how many stadiums in total?"]
44
 
45
+ example3 = ["""CREATE TABLE work_orders (
 
 
 
 
46
  ID NUMBER,
47
  CREATED_AT TEXT,
48
  COST FLOAT,
 
51
  IS_OPEN BOOLEAN,
52
  IS_OVERDUE BOOLEAN,
53
  COUNTRY_NAME TEXT,
54
+ )""",
55
+ "Using valid SQLite, answer the following questions for the tables provided above.",
56
+ "how many work orders are open?"]