HusnaManakkot commited on
Commit
04901e7
β€’
1 Parent(s): 2ef8e41

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -11
app.py CHANGED
@@ -2,21 +2,20 @@ import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  from datasets import load_dataset
4
 
5
- # Load the Spider dataset
6
- spider_dataset = load_dataset("wikisql", split='train') # Load a subset of the dataset
7
 
8
- # Extract schema information from the Spider dataset
9
  table_names = set()
10
  column_names = set()
11
- for item in spider_dataset:
12
- for table in item['db_id']:
13
- table_names.add(table)
14
- for column in item['question']:
15
  column_names.add(column)
16
 
17
  # Load tokenizer and model
18
- tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL") # Update this to a model fine-tuned on Spider if available
19
- model = AutoModelForSeq2SeqLM.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL") # Update this to a model fine-tuned on Spider if available
20
 
21
  def generate_sql_from_user_input(query):
22
  # Generate SQL for the user's query
@@ -40,9 +39,10 @@ interface = gr.Interface(
40
  fn=generate_sql_from_user_input,
41
  inputs=gr.Textbox(label="Enter your natural language query"),
42
  outputs=gr.Textbox(label="Generated SQL Query"),
43
- title="NL to SQL with T5 using Spider Dataset",
44
- description="This model generates an SQL query for your natural language input based on the Spider dataset."
45
  )
 
46
  # Launch the app
47
  if __name__ == "__main__":
48
  interface.launch()
 
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  from datasets import load_dataset
4
 
5
+ # Load the WikiSQL dataset
6
+ wikisql_dataset = load_dataset("wikisql", split='train') # Load a subset of the dataset
7
 
8
+ # Extract schema information from the WikiSQL dataset
9
  table_names = set()
10
  column_names = set()
11
+ for item in wikisql_dataset:
12
+ table_names.add(item['table']['name'])
13
+ for column in item['table']['header']:
 
14
  column_names.add(column)
15
 
16
  # Load tokenizer and model
17
+ tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL")
18
+ model = AutoModelForSeq2SeqLM.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL")
19
 
20
  def generate_sql_from_user_input(query):
21
  # Generate SQL for the user's query
 
39
  fn=generate_sql_from_user_input,
40
  inputs=gr.Textbox(label="Enter your natural language query"),
41
  outputs=gr.Textbox(label="Generated SQL Query"),
42
+ title="NL to SQL with T5 using WikiSQL Dataset",
43
+ description="This model generates an SQL query for your natural language input based on the WikiSQL dataset."
44
  )
45
+
46
  # Launch the app
47
  if __name__ == "__main__":
48
  interface.launch()