Shodnotantelope2 commited on
Commit
67ee210
·
verified ·
1 Parent(s): 2b32bf4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -11
app.py CHANGED
@@ -24,17 +24,15 @@ def generate_sql(nl_query, schema):
24
  SQL Query:
25
  """
26
 
27
- input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
28
-
29
- # Ensure dtype consistency to avoid errors
30
- if model.dtype == torch.float16:
31
- input_ids = input_ids.half() # Convert to float16
32
- elif model.dtype == torch.bfloat16:
33
- input_ids = input_ids.bfloat16() # Convert to bfloat16
34
 
35
  with torch.no_grad():
36
  output_ids = model.generate(input_ids, max_length=128, pad_token_id=tokenizer.eos_token_id)
37
 
 
 
 
 
38
  # Decode and clean the output
39
  output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
40
  sql_start = output_text.find("SQL Query:") + len("SQL Query:")
@@ -75,7 +73,7 @@ def get_schema(db_path):
75
  return schema
76
 
77
  # --- Streamlit UI ---
78
- st.title("🔹 AI-Powered Text-to-SQL Generator")
79
  st.write("Convert natural language questions into SQL queries and execute them.")
80
 
81
  # Database selection
@@ -107,7 +105,7 @@ elif db_option == "Enter schema manually":
107
  schema = st.text_area("Enter your schema:")
108
 
109
  if schema:
110
- st.subheader("📌 Extracted/Provided Schema:")
111
  st.code(schema, language="sql")
112
 
113
  # Query input
@@ -118,7 +116,7 @@ if st.button("Generate SQL Query"):
118
  st.error("❌ Please provide a database or schema first.")
119
  else:
120
  sql_query = generate_sql(user_query, schema)
121
- st.subheader("🔹 Generated SQL Query:")
122
  st.code(sql_query, language="sql")
123
 
124
  # Execute SQL if database exists
@@ -137,4 +135,4 @@ if st.button("Generate SQL Query"):
137
  else:
138
  st.error(f"❌ SQL Execution Error: {result}")
139
  else:
140
- st.info("📌 No database provided, only SQL query was generated.")
 
24
  SQL Query:
25
  """
26
 
27
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device, dtype=torch.long) # Ensure Long dtype
 
 
 
 
 
 
28
 
29
  with torch.no_grad():
30
  output_ids = model.generate(input_ids, max_length=128, pad_token_id=tokenizer.eos_token_id)
31
 
32
+ # If model outputs in float16 or bfloat16, convert back to long/int
33
+ if output_ids.dtype in [torch.float16, torch.bfloat16]:
34
+ output_ids = output_ids.to(dtype=torch.long)
35
+
36
  # Decode and clean the output
37
  output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
38
  sql_start = output_text.find("SQL Query:") + len("SQL Query:")
 
73
  return schema
74
 
75
  # --- Streamlit UI ---
76
+ st.title("AI-Powered Text-to-SQL Generator")
77
  st.write("Convert natural language questions into SQL queries and execute them.")
78
 
79
  # Database selection
 
105
  schema = st.text_area("Enter your schema:")
106
 
107
  if schema:
108
+ st.subheader("Extracted/Provided Schema:")
109
  st.code(schema, language="sql")
110
 
111
  # Query input
 
116
  st.error("❌ Please provide a database or schema first.")
117
  else:
118
  sql_query = generate_sql(user_query, schema)
119
+ st.subheader("Generated SQL Query:")
120
  st.code(sql_query, language="sql")
121
 
122
  # Execute SQL if database exists
 
135
  else:
136
  st.error(f"❌ SQL Execution Error: {result}")
137
  else:
138
+ st.info("No database provided, only SQL query was generated.")