GlastonR committed on
Commit
bb1df26
·
verified ·
1 Parent(s): 578a297

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -46
app.py CHANGED
@@ -1,67 +1,84 @@
1
  import streamlit as st
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
 
4
# Load the models and tokenizers
# Question-generation model (T5 fine-tuned for question generation, per the model id).
question_tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-question-generation-ap")
question_model = AutoModelForSeq2SeqLM.from_pretrained("mrm8488/t5-base-finetuned-question-generation-ap")

# Text-to-SQL model (FLAN-T5 fine-tuned for schema-aware text-to-SQL, per the model id).
sql_tokenizer = AutoTokenizer.from_pretrained("juierror/flan-t5-text2sql-with-schema-v2")
sql_model = AutoModelForSeq2SeqLM.from_pretrained("juierror/flan-t5-text2sql-with-schema-v2")
 
11
# Function to generate a question based on a table schema
def generate_question(tables):
    """Generate a natural-language question from a {table: [columns]} schema dict."""
    # Render the schema as "name: (col1, col2)" fragments joined by commas.
    fragments = []
    for name, cols in tables.items():
        fragments.append("{}: ({})".format(name, ", ".join(cols)))
    prompt = "Generate a question based on the following table schema: " + ", ".join(fragments)

    # Tokenize the prompt and decode the beam-searched generation.
    ids = question_tokenizer(prompt, return_tensors="pt").input_ids
    generated = question_model.generate(ids, num_beams=5, max_length=50)
    return question_tokenizer.decode(generated[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
22
 
23
# Function to prepare input data for SQL generation
def prepare_sql_input(question, tables):
    """Build and tokenize the text-to-SQL prompt for *question* over *tables*."""
    pieces = ["{}({})".format(name, ", ".join(cols)) for name, cols in tables.items()]
    prompt = "Convert the question and table schema into an SQL query. Tables: {}. Question: {}".format(
        ", ".join(pieces), question
    )

    # Only the token-id tensor is needed downstream.
    input_ids = sql_tokenizer(prompt, max_length=512, return_tensors="pt").input_ids
    return input_ids
30
 
31
# Inference function for SQL generation
def generate_sql(question, tables):
    """Run the text-to-SQL model and return the decoded query string."""
    batch = prepare_sql_input(question, tables).to(sql_model.device)
    outputs = sql_model.generate(inputs=batch, num_beams=10, top_k=10, max_length=512)
    return sql_tokenizer.decode(outputs[0], skip_special_tokens=True)
38
 
39
# Streamlit UI
def main():
    """Streamlit entry point: generate a question from a table schema, then a SQL query."""
    import json  # local import keeps this fix self-contained

    st.title("Multi-Model: Text to SQL and Question Generation")

    # Input table schema
    tables_input = st.text_area("Enter table schemas (in JSON format):",
                                '{"people_name": ["id", "name"], "people_age": ["people_id", "age"]}')
    # BUG FIX: the original used eval() on user input, which executes arbitrary
    # code. The field is documented as JSON, so parse it with json.loads and
    # catch only the parse error instead of a bare except.
    try:
        tables = json.loads(tables_input)
    except json.JSONDecodeError:
        tables = {}

    # If tables are provided, generate a question
    if tables:
        generated_question = generate_question(tables)
        st.write(f"Generated Question: {generated_question}")

    # Input question manually if needed
    question = st.text_area("Enter your question (optional):", generated_question if tables else "")

    if st.button("Generate SQL Query"):
        if question and tables:
            sql_query = generate_sql(question, tables)
            st.write(f"Generated SQL Query: {sql_query}")
        else:
            st.write("Please enter both a question and table schemas.")
 
 
 
 
 
65
 
66
# Run the Streamlit app when executed as a script.
if __name__ == "__main__":
    main()
 
1
  import streamlit as st
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
 
4
+ # Load the models
5
+ tokenizer_sql = AutoTokenizer.from_pretrained("juierror/flan-t5-text2sql-with-schema-v2")
6
+ model_sql = AutoModelForSeq2SeqLM.from_pretrained("juierror/flan-t5-text2sql-with-schema-v2")
7
 
8
+ tokenizer_question = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-question-generation-ap")
9
+ model_question = AutoModelForSeq2SeqLM.from_pretrained("mrm8488/t5-base-finetuned-question-generation-ap")
10
 
11
# Function to create the prompt for SQL model
def get_prompt_sql(tables, question):
    """Return the prompt fed to the text-to-SQL model."""
    template = "convert question and table into SQL query. tables: {}. question: {}"
    return template.format(tables, question)
 
 
14
 
15
# Function to prepare input data for the SQL model
def prepare_input_sql(question: str, tables: dict):
    """Render the schema dict, build the prompt, and tokenize it for the SQL model."""
    rendered = []
    for name, columns in tables.items():
        rendered.append("{}({})".format(name, ",".join(columns)))
    schema_text = ", ".join(rendered)
    prompt = get_prompt_sql(schema_text, question)
    # Only the token-id tensor is needed for generation.
    return tokenizer_sql(prompt, max_length=512, return_tensors="pt").input_ids
22
+
23
# Inference function for the SQL model
def inference_sql(question: str, tables: dict) -> str:
    """Generate a SQL query string for *question* over the given table schemas."""
    ids = prepare_input_sql(question=question, tables=tables).to(model_sql.device)
    generated = model_sql.generate(inputs=ids, num_beams=10, top_k=10, max_length=512)
    return tokenizer_sql.decode(generated[0], skip_special_tokens=True)
29
 
30
# Function to create the prompt for Question Generation model
def get_prompt_question(context):
    """Return the prompt fed to the question-generation model."""
    return "generate a question from the following context: {}".format(context)
 
33
 
34
# Function to prepare input data for the Question Generation model
def prepare_input_question(context: str):
    """Tokenize the question-generation prompt; returns the input-id tensor."""
    encoded = tokenizer_question(
        get_prompt_question(context), max_length=512, return_tensors="pt"
    )
    return encoded.input_ids
39
 
40
# Inference function for the Question Generation model
def inference_question(context: str) -> str:
    """Generate a question from *context* and return the decoded text."""
    ids = prepare_input_question(context).to(model_question.device)
    generated = model_question.generate(inputs=ids, num_beams=10, top_k=10, max_length=512)
    return tokenizer_question.decode(generated[0], skip_special_tokens=True)
 
46
 
47
# Streamlit UI
def main():
    """Streamlit entry point: route to the text-to-SQL or question-generation flow."""
    import json  # local import keeps this fix self-contained

    st.title("Multi-Model: Text to SQL and Question Generation")

    # Model selection
    model_choice = st.selectbox("Select a model", ["Text to SQL", "Question Generation"])

    # Input question and table schema for SQL model
    if model_choice == "Text to SQL":
        st.subheader("Text to SQL Model")
        question = st.text_area("Enter your question:")
        tables_input = st.text_area("Enter table schemas (in JSON format):", '{"people_name": ["id", "name"], "people_age": ["people_id", "age"]}')
        # BUG FIX: the original used eval() on user input, which executes
        # arbitrary code. The field is documented as JSON, so parse it with
        # json.loads and catch only the parse error instead of a bare except.
        try:
            tables = json.loads(tables_input)
        except json.JSONDecodeError:
            tables = {}

        if st.button("Generate SQL Query"):
            if question and tables:
                sql_query = inference_sql(question, tables)
                st.write(f"Generated SQL Query: {sql_query}")
            else:
                st.write("Please enter both a question and table schemas.")

    # Input context for Question Generation model
    elif model_choice == "Question Generation":
        st.subheader("Question Generation Model")
        context = st.text_area("Enter context:")

        if st.button("Generate Question"):
            if context:
                generated_question = inference_question(context)
                st.write(f"Generated Question: {generated_question}")
            else:
                st.write("Please enter context for question generation.")
82
 
83
# Run the Streamlit app when executed as a script.
if __name__ == "__main__":
    main()