Update app.py
app.py CHANGED
@@ -25,6 +25,13 @@ SQL Query:
 """
 
     input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
+
+    # Ensure dtype consistency to avoid errors
+    if model.dtype == torch.float16:
+        input_ids = input_ids.half()  # Convert to float16
+    elif model.dtype == torch.bfloat16:
+        input_ids = input_ids.bfloat16()  # Convert to bfloat16
+
     with torch.no_grad():
         output_ids = model.generate(input_ids, max_length=128, pad_token_id=tokenizer.eos_token_id)
 
@@ -37,6 +44,7 @@ SQL Query:
     sql_query = re.sub(r"```sql|```", "", sql_query).split("###")[0].strip()
     return sql_query
 
+
 def execute_sql(sql_query, db_path):
     """Execute the generated SQL query on the provided database."""
     try:
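
The second hunk cuts off inside execute_sql right after the try: line. For reference, a minimal sketch of how such a helper is commonly completed, assuming a SQLite database file; everything below try: is an assumption and is not taken from app.py.

import sqlite3

def execute_sql(sql_query, db_path):
    """Execute the generated SQL query on the provided database."""
    try:
        # Assumed body: the diff shows only the signature, docstring, and `try:`.
        with sqlite3.connect(db_path) as conn:
            cursor = conn.execute(sql_query)
            return cursor.fetchall()
    except sqlite3.Error as exc:
        # Return the error text instead of raising, so a front end can display
        # failures from malformed generated queries directly.
        return f"Error executing SQL: {exc}"

Returning the error string rather than raising is a design choice suited to an interactive app, where a bad model-generated query should surface as a message instead of a crash.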