Shodnotantelope2 commited on
Commit
2b32bf4
·
verified ·
1 Parent(s): 921b2e3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -0
app.py CHANGED
@@ -25,6 +25,13 @@ SQL Query:
25
  """
26
 
27
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
 
 
 
 
 
 
 
28
  with torch.no_grad():
29
  output_ids = model.generate(input_ids, max_length=128, pad_token_id=tokenizer.eos_token_id)
30
 
@@ -37,6 +44,7 @@ SQL Query:
37
  sql_query = re.sub(r"```sql|```", "", sql_query).split("###")[0].strip()
38
  return sql_query
39
 
 
40
  def execute_sql(sql_query, db_path):
41
  """Execute the generated SQL query on the provided database."""
42
  try:
 
25
  """
26
 
27
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
28
+
29
+ # Ensure dtype consistency to avoid errors
30
+ if model.dtype == torch.float16:
31
+ input_ids = input_ids.half() # Convert to float16
32
+ elif model.dtype == torch.bfloat16:
33
+ input_ids = input_ids.bfloat16() # Convert to bfloat16
34
+
35
  with torch.no_grad():
36
  output_ids = model.generate(input_ids, max_length=128, pad_token_id=tokenizer.eos_token_id)
37
 
 
44
  sql_query = re.sub(r"```sql|```", "", sql_query).split("###")[0].strip()
45
  return sql_query
46
 
47
+
48
  def execute_sql(sql_query, db_path):
49
  """Execute the generated SQL query on the provided database."""
50
  try: