Update app.py
app.py CHANGED
@@ -24,17 +24,15 @@ def generate_sql(nl_query, schema):
 SQL Query:
 """
 
-    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
-
-    # Ensure dtype consistency to avoid errors
-    if model.dtype == torch.float16:
-        input_ids = input_ids.half()  # Convert to float16
-    elif model.dtype == torch.bfloat16:
-        input_ids = input_ids.bfloat16()  # Convert to bfloat16
+    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device, dtype=torch.long)  # Ensure Long dtype
 
     with torch.no_grad():
         output_ids = model.generate(input_ids, max_length=128, pad_token_id=tokenizer.eos_token_id)
 
+    # If model outputs in float16 or bfloat16, convert back to long/int
+    if output_ids.dtype in [torch.float16, torch.bfloat16]:
+        output_ids = output_ids.to(dtype=torch.long)
+
     # Decode and clean the output
     output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
     sql_start = output_text.find("SQL Query:") + len("SQL Query:")
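Note on the hunk above: tokenizer(...).input_ids is already an integer (torch.long) tensor, and model.generate() likewise returns integer token IDs, so casting the IDs to the model's float16/bfloat16 dtype is what broke the embedding lookup. A minimal standalone sketch of that behaviour (my own illustration, not part of this commit; "gpt2" is only a stand-in checkpoint):

# Illustration only (not from app.py): token IDs must stay torch.long even when the
# model weights are float16/bfloat16, and generate() already returns integer IDs.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")          # stand-in checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer("SQL Query:", return_tensors="pt").input_ids.to(model.device, dtype=torch.long)
print(input_ids.dtype)   # torch.int64 -- casting this to half() is what triggered the error

with torch.no_grad():
    output_ids = model.generate(input_ids, max_length=32, pad_token_id=tokenizer.eos_token_id)
print(output_ids.dtype)  # torch.int64 -- so the post-generate dtype check is only a safeguard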
@@ -75,7 +73,7 @@ def get_schema(db_path):
     return schema
 
 # --- Streamlit UI ---
-st.title("
+st.title("AI-Powered Text-to-SQL Generator")
 st.write("Convert natural language questions into SQL queries and execute them.")
 
 # Database selection
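The hunk header above references get_schema(db_path), whose body is outside this diff. For orientation only, a hypothetical sketch of what such a helper could look like for a SQLite file (not the app's actual implementation):

# Hypothetical sketch of a get_schema-style helper; the real get_schema(db_path)
# in app.py is not shown in this diff and may differ.
import sqlite3

def get_schema_sketch(db_path: str) -> str:
    conn = sqlite3.connect(db_path)
    try:
        rows = conn.execute(
            "SELECT sql FROM sqlite_master WHERE type = 'table' AND sql IS NOT NULL"
        ).fetchall()
    finally:
        conn.close()
    return "\n".join(sql for (sql,) in rows)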
@@ -107,7 +105,7 @@ elif db_option == "Enter schema manually":
     schema = st.text_area("Enter your schema:")
 
     if schema:
-        st.subheader("
+        st.subheader("Extracted/Provided Schema:")
         st.code(schema, language="sql")
 
 # Query input
@@ -118,7 +116,7 @@ if st.button("Generate SQL Query"):
         st.error("❌ Please provide a database or schema first.")
     else:
         sql_query = generate_sql(user_query, schema)
-        st.subheader("
+        st.subheader("Generated SQL Query:")
         st.code(sql_query, language="sql")
 
         # Execute SQL if database exists
@@ -137,4 +135,4 @@ if st.button("Generate SQL Query"):
             else:
                 st.error(f"❌ SQL Execution Error: {result}")
         else:
-            st.info("
+            st.info("No database provided, only SQL query was generated.")
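The last hunk sits in the branch that reports SQL execution results; the execution helper itself is not part of this diff. A hypothetical sketch of how the generated query might be run against the selected SQLite database, returning either rows or an error string for st.error (names are illustrative, not app.py's real ones):

# Hypothetical sketch only; app.py's real execution code is not part of this diff.
import sqlite3

def execute_sql_sketch(db_path: str, sql_query: str):
    """Run sql_query against db_path; return (True, rows) or (False, error message)."""
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.execute(sql_query)
        return True, cursor.fetchall()
    except sqlite3.Error as exc:
        return False, str(exc)
    finally:
        conn.close()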