nilotpaldhar2004 committed on
Commit
4f43f55
·
unverified ·
1 Parent(s): 49d3371

Enhance SQL generation in generate_sql function

Browse files

Refactor SQL generation logic with enhanced regex and improved fallback handling.

Files changed (1) hide show
  1. app.py +50 -52
app.py CHANGED
@@ -81,75 +81,73 @@ def get_schema(db_bytes: bytes) -> str:
81
 
82
 
83
def generate_sql(question: str, schema: str) -> str:
    """Run T5 inference to produce SQL, with rule-based shortcuts first.

    Parameters
    ----------
    question : natural-language question from the user.
    schema   : CREATE TABLE statement(s) describing the database.

    Returns
    -------
    str
        A SQL string targeting the (single) table found in *schema*.
    """
    # Extract table name from schema
    table_match = re.search(r'CREATE TABLE\s+"?(\w+)"?', schema, re.IGNORECASE)
    table_name = table_match.group(1) if table_match else "data"
    quoted = f'"{table_name}"'
    # NOTE(review): this captures every quoted identifier, so the table name
    # itself ends up in the list alongside the columns — confirm intent.
    col_match = re.findall(r'"(\w+)"', schema)

    # ── Rule-based shortcuts (fast + accurate) ────────────────────────────────
    q = question.lower().strip()
    if re.search(r'show.*(first|top).*\d+|first.*\d+.*row|top.*\d+', q):
        n = re.search(r'\d+', q)
        return f'SELECT * FROM {quoted} LIMIT {n.group() if n else 10}'
    if re.search(r'(show|display|get|give).*(first|all).*row|first.*row|show.*row', q):
        return f'SELECT * FROM {quoted} LIMIT 10'
    if re.search(r'count.*(total|all|record|row)|total.*(record|row|count)|how many', q):
        return f'SELECT COUNT(*) FROM {quoted}'
    if re.search(r'show.*(all|every).*row|all.*row|select all', q):
        return f'SELECT * FROM {quoted} LIMIT 50'
    if re.search(r'average|avg', q) and col_match:
        # FIX: the default used to be an unconditional col_match[2], which
        # raised IndexError for schemas with fewer than three quoted
        # identifiers when no pollutant-style column name matched.
        num_col = next(
            (c for c in col_match
             if re.search(r'pm|aqi|no|co|so|o3|benzene|toluene|xylene', c, re.I)),
            col_match[2] if len(col_match) > 2 else col_match[0],
        )
        return f'SELECT AVG("{num_col}") FROM {quoted}'
    if re.search(r'unique|distinct', q) and col_match:
        return f'SELECT COUNT(DISTINCT "{col_match[0]}") FROM {quoted}'
    if re.search(r'group by', q) and col_match:
        return f'SELECT "{col_match[0]}", COUNT(*) FROM {quoted} GROUP BY "{col_match[0]}"'
    if re.search(r'max|maximum|highest', q) and col_match:
        num_col = col_match[1] if len(col_match) > 1 else col_match[0]
        return f'SELECT MAX("{num_col}") FROM {quoted}'
    if re.search(r'min|minimum|lowest', q) and col_match:
        num_col = col_match[1] if len(col_match) > 1 else col_match[0]
        return f'SELECT MIN("{num_col}") FROM {quoted}'

    # ── T5 model fallback ─────────────────────────────────────────────────────
    col_hint = ", ".join(col_match) if col_match else ""
    prompt = f"tables:\n{schema}\ncolumns: {col_hint}\nquery for: {question}"
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    ).to(DEVICE)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            num_beams=4,
            early_stopping=True,
        )
    sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

    # Fix 1: replace any FROM/JOIN table reference (quoted or unquoted) with correct table
    sql = re.sub(r'\bFROM\s+("?\w+"?)', f'FROM {quoted}', sql, flags=re.IGNORECASE)
    sql = re.sub(r'\bJOIN\s+("?\w+"?)', f'JOIN {quoted}', sql, flags=re.IGNORECASE)

    # Fix 2: strip junk tokens after table name before LIMIT/WHERE/ORDER etc.
    # e.g. FROM "city_day" Datetime LIMIT 10 → FROM "city_day" LIMIT 10
    sql = re.sub(
        r'(FROM\s+"?\w+"?)\s+(?!WHERE|LIMIT|ORDER|GROUP|HAVING|JOIN|LEFT|RIGHT|INNER|ON|AND|OR|\d)(\w+)',
        r'\1',
        sql, flags=re.IGNORECASE
    )

    # Fix 3: fallback if no SELECT at all
    if not re.search(r'\bSELECT\b', sql, re.IGNORECASE):
        sql = f'SELECT * FROM {quoted} LIMIT 10'

    return sql
151
 
152
-
153
  def execute_sql(sql: str, db_bytes: bytes) -> list[dict]:
154
  """Run SQL against the in-memory SQLite DB."""
155
  with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
 
81
 
82
 
83
def generate_sql(question: str, schema: str) -> str:
    """
    Enhanced Hybrid SQL Engine.

    Priority 1: Smart Regex (Deterministic & Instant)
    Priority 2: T5 Transformer (Probabilistic Fallback)

    Parameters
    ----------
    question : natural-language question from the user.
    schema   : CREATE TABLE statement(s) describing the database.

    Returns
    -------
    str
        A SQL string targeting the (single) table found in *schema*.
    """
    # 1. Context Extraction
    table_match = re.search(r'CREATE TABLE\s+"?(\w+)"?', schema, re.IGNORECASE)
    table_name = table_match.group(1) if table_match else "data"
    quoted = f'"{table_name}"'
    # FIX: exclude the table name itself — the findall captures every quoted
    # identifier, and previously the table leaked into the "column" list and
    # could be picked by the DISTINCT/GROUP BY shortcuts.
    col_match = [c for c in re.findall(r'"(\w+)"', schema) if c != table_name]

    q = question.lower().strip()

    # 2. Smart Column Detection
    # Whole-word match so a short column name (e.g. "no", "co") cannot fire
    # on a substring of an unrelated word in the question.
    target_col = None
    for col in col_match:
        if re.search(rf'\b{re.escape(col.lower())}\b', q):
            target_col = col
            break

    # 3. Enhanced Rule-Based Shortcuts
    # Keyword patterns are word-bounded to avoid substring hits such as
    # "per" inside "temperature" or "top" inside "stop".

    # DISTINCT/UNIQUE COUNT
    if re.search(r'\b(unique|distinct)\b', q):
        col = target_col if target_col else (col_match[0] if col_match else "*")
        return f'SELECT COUNT(DISTINCT "{col}") FROM {quoted}'

    # GROUP BY — only when a real column is known (the old fallback grouped
    # by a literal, nonexistent column called "data").
    if re.search(r'\bgroup\b.*\bby\b|\bper\b|\beach\b', q) and col_match:
        col = target_col if target_col else col_match[0]
        return f'SELECT "{col}", COUNT(*) FROM {quoted} GROUP BY "{col}"'

    # AVERAGE (with semantic fallback for pollution-style numeric columns).
    # FIX: guarded on col_match — previously an empty column list raised
    # IndexError via col_match[0] inside the next() default.
    if re.search(r'\b(average|avg|mean)\b', q) and col_match:
        num_col = target_col or next(
            (c for c in col_match
             if re.search(r'pm|aqi|no|co|so|o3|benzene|val|amt', c, re.I)),
            col_match[2] if len(col_match) > 2 else col_match[0],
        )
        return f'SELECT AVG("{num_col}") FROM {quoted}'

    # TOTAL RECORDS
    if re.search(r'count.*(total|all|record|row)|total.*(record|row|count)|how many', q):
        return f'SELECT COUNT(*) FROM {quoted}'

    # LIMIT/TOP ROWS
    if re.search(r'\b(show|display|get|first|top)\b', q):
        n_match = re.search(r'\d+', q)
        limit = n_match.group() if n_match else 10
        return f'SELECT * FROM {quoted} LIMIT {limit}'

    # 4. T5 Model Fallback
    col_hint = ", ".join(col_match) if col_match else ""
    prompt = f"Translate English to SQL: {question} | Table: {table_name} | Columns: {col_hint}"

    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(DEVICE)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS, num_beams=4, early_stopping=True)

    sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

    # Post-inference cleaning (Crucial for SQLite stability):
    # force the model's FROM target onto the real table, then strip stray
    # tokens the model sometimes emits right after the table name.
    sql = re.sub(r'\bFROM\s+("?\w+"?)', f'FROM {quoted}', sql, flags=re.IGNORECASE)
    sql = re.sub(r'(FROM\s+"?\w+"?)\s+(?!WHERE|LIMIT|ORDER|GROUP|HAVING|JOIN|ON|AND|OR)(\w+)',
                 r'\1', sql, flags=re.IGNORECASE)

    # Last-resort fallback when the model produced no SELECT at all.
    if not re.search(r'\bSELECT\b', sql, re.IGNORECASE):
        sql = f'SELECT * FROM {quoted} LIMIT 10'

    return sql
150
 
 
151
  def execute_sql(sql: str, db_bytes: bytes) -> list[dict]:
152
  """Run SQL against the in-memory SQLite DB."""
153
  with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: