VladGeekPro commited on
Commit
82b086c
·
1 Parent(s): 144bf42

ChangedDockerFile

Browse files
Files changed (2) hide show
  1. Dockerfile +1 -1
  2. sql_generator.py +346 -0
Dockerfile CHANGED
@@ -22,7 +22,7 @@ COPY --chown=user requirements.txt .
22
  RUN pip install --upgrade pip && pip install -r requirements.txt
23
 
24
  COPY --chown=user app.py ./
25
- COPY --chown=user sql_generator.py ./
26
  COPY --chown=user extractors/ ./extractors/
27
 
28
  EXPOSE 7860
 
22
  RUN pip install --upgrade pip && pip install -r requirements.txt
23
 
24
  COPY --chown=user app.py ./
25
+ COPY --chown=user expense_predictor.py ./
26
  COPY --chown=user extractors/ ./extractors/
27
 
28
  EXPOSE 7860
sql_generator.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Template-based SQL generator for SQLite.
3
+ Deterministic schema-aware NL→SQL with minimal templates and high accuracy.
4
+ Focuses only on: users, categories, suppliers, expenses, debts.
5
+ """
6
+
7
+ import os
8
+ import re
9
+ from datetime import datetime
10
+ from typing import Any, Dict, Optional, Tuple
11
+ from dataclasses import dataclass, asdict
12
+
13
+ # Core business schema only (for reference/documentation)
14
+ DEFAULT_DB_SCHEMA = (
15
+ "users : id int , name varchar , email varchar , created_at datetime , updated_at datetime | "
16
+ "categories : id int , name varchar , slug varchar , notes text , created_at datetime , updated_at datetime | "
17
+ "suppliers : id int , name varchar , slug varchar , category_id int , created_at datetime , updated_at datetime | "
18
+ "expenses : id int , user_id int , date date , category_id int , supplier_id int , sum numeric , notes text , created_at datetime , updated_at datetime | "
19
+ "debts : id int , date date , user_id int , debt_sum numeric , payment_status varchar , partial_sum numeric , date_paid date , created_at datetime , updated_at datetime"
20
+ )
21
+
22
+ _SQL_GENERATOR: Any | None = None
23
+ _MONTHS = {
24
+ "january": 1,
25
+ "february": 2,
26
+ "march": 3,
27
+ "april": 4,
28
+ "may": 5,
29
+ "june": 6,
30
+ "july": 7,
31
+ "august": 8,
32
+ "september": 9,
33
+ "october": 10,
34
+ "november": 11,
35
+ "december": 12,
36
+ }
37
+
38
+
39
+ @dataclass(frozen=True)
40
+ class SqlGenerationRequest:
41
+ question: str
42
+ limit: int = 200
43
+
44
+
45
+ def _normalize_text(text: str) -> str:
46
+ return re.sub(r"\s+", " ", text.lower()).strip()
47
+
48
+
49
+ def _contains_any(text: str, markers: tuple[str, ...]) -> bool:
50
+ return any(marker in text for marker in markers)
51
+
52
+
53
+ def _end_of_month(year: int, month: int) -> int:
54
+ if month == 12:
55
+ return 31
56
+ next_month = date(year, month + 1, 1)
57
+ current_month = date(year, month, 1)
58
+ return (next_month - current_month).days
59
+
60
+
61
+ def _extract_month_filter(question: str) -> tuple[str, str] | None:
62
+ text = _normalize_text(question)
63
+ for month_name, month_idx in _MONTHS.items():
64
+ if month_name in text:
65
+ year_match = re.search(r"\b(20\d{2})\b", text)
66
+ if not year_match:
67
+ continue
68
+ year = int(year_match.group(1))
69
+ day_end = _end_of_month(year, month_idx)
70
+ start = f"{year:04d}-{month_idx:02d}-01"
71
+ end = f"{year:04d}-{month_idx:02d}-{day_end:02d}"
72
+ return start, end
73
+ return None
74
+
75
+
76
+ def _extract_top_limit(question: str) -> int | None:
77
+ match = re.search(r"\btop\s+(\d{1,4})\b", _normalize_text(question))
78
+ if not match:
79
+ return None
80
+ return max(1, min(1000, int(match.group(1))))
81
+
82
+
83
+ def _extract_metric(question: str) -> tuple[str, str]:
84
+ text = _normalize_text(question)
85
+ if _contains_any(text, ("count", "how many", "number of")):
86
+ return "COUNT(*)", "items_count"
87
+ if _contains_any(text, ("average", "avg", "mean")):
88
+ return "AVG(e.sum)", "avg_amount"
89
+ if _contains_any(text, ("minimum", "lowest", "min ")):
90
+ return "MIN(e.sum)", "min_amount"
91
+ if _contains_any(text, ("maximum", "highest", "max ")):
92
+ return "MAX(e.sum)", "max_amount"
93
+ return "SUM(e.sum)", "total_amount"
94
+
95
+
96
+ def _extract_dimension(question: str) -> str | None:
97
+ text = _normalize_text(question)
98
+ if _contains_any(text, ("category", "categories")):
99
+ return "category"
100
+ if _contains_any(text, ("supplier", "suppliers", "vendor", "vendors")):
101
+ return "supplier"
102
+ if _contains_any(text, ("user", "users", "person")):
103
+ return "user"
104
+ return None
105
+
106
+
107
+ def _build_expenses_aggregate_sql(payload: SqlGenerationRequest) -> str:
108
+ question = _normalize_text(payload.question)
109
+ metric_expr, metric_alias = _extract_metric(question)
110
+ dimension = _extract_dimension(question)
111
+
112
+ select_parts = []
113
+ joins = []
114
+ group_by = []
115
+
116
+ if dimension == "category":
117
+ select_parts.append("c.name AS category_name")
118
+ joins.append("JOIN categories AS c ON c.id = e.category_id")
119
+ group_by.append("c.id, c.name")
120
+ elif dimension == "supplier":
121
+ select_parts.append("s.name AS supplier_name")
122
+ joins.append("JOIN suppliers AS s ON s.id = e.supplier_id")
123
+ group_by.append("s.id, s.name")
124
+ elif dimension == "user":
125
+ select_parts.append("u.name AS user_name")
126
+ joins.append("JOIN users AS u ON u.id = e.user_id")
127
+ group_by.append("u.id, u.name")
128
+
129
+ select_parts.append(f"{metric_expr} AS {metric_alias}")
130
+
131
+ filters = []
132
+ month_filter = _extract_month_filter(question)
133
+ if month_filter:
134
+ start, end = month_filter
135
+ filters.append(f"e.date BETWEEN '{start}' AND '{end}'")
136
+
137
+ where_clause = f" WHERE {' AND '.join(filters)}" if filters else ""
138
+ join_clause = f" {' '.join(joins)}" if joins else ""
139
+ group_clause = f" GROUP BY {', '.join(group_by)}" if group_by else ""
140
+
141
+ order_direction = "ASC" if " asc" in question or "ascending" in question else "DESC"
142
+ order_clause = f" ORDER BY {metric_alias} {order_direction}"
143
+
144
+ top_limit = _extract_top_limit(question)
145
+ final_limit = top_limit if top_limit is not None else payload.limit
146
+
147
+ return (
148
+ f"SELECT {', '.join(select_parts)} "
149
+ f"FROM expenses AS e"
150
+ f"{join_clause}"
151
+ f"{where_clause}"
152
+ f"{group_clause}"
153
+ f"{order_clause}"
154
+ f" LIMIT {final_limit}"
155
+ )
156
+
157
+
158
+ def _build_expenses_detail_sql(payload: SqlGenerationRequest) -> str:
159
+ question = _normalize_text(payload.question)
160
+ include_category = _contains_any(question, ("category", "categories"))
161
+ include_supplier = _contains_any(question, ("supplier", "suppliers", "vendor", "vendors"))
162
+ include_user = _contains_any(question, ("user", "users", "person"))
163
+
164
+ select_parts = ["e.date", "e.sum", "e.notes"]
165
+ joins = []
166
+
167
+ if include_category:
168
+ select_parts.append("c.name AS category_name")
169
+ joins.append("JOIN categories AS c ON c.id = e.category_id")
170
+ if include_supplier:
171
+ select_parts.append("s.name AS supplier_name")
172
+ joins.append("JOIN suppliers AS s ON s.id = e.supplier_id")
173
+ if include_user:
174
+ select_parts.append("u.name AS user_name")
175
+ joins.append("JOIN users AS u ON u.id = e.user_id")
176
+
177
+ filters = []
178
+ month_filter = _extract_month_filter(question)
179
+ if month_filter:
180
+ start, end = month_filter
181
+ filters.append(f"e.date BETWEEN '{start}' AND '{end}'")
182
+
183
+ where_clause = f" WHERE {' AND '.join(filters)}" if filters else ""
184
+ join_clause = f" {' '.join(joins)}" if joins else ""
185
+ order_clause = " ORDER BY e.date DESC"
186
+
187
+ return (
188
+ f"SELECT {', '.join(select_parts)} "
189
+ f"FROM expenses AS e"
190
+ f"{join_clause}"
191
+ f"{where_clause}"
192
+ f"{order_clause}"
193
+ f" LIMIT {payload.limit}"
194
+ )
195
+
196
+
197
+ def _build_debts_sql(payload: SqlGenerationRequest) -> str:
198
+ question = _normalize_text(payload.question)
199
+ with_user = _contains_any(question, ("user", "users", "person", "name"))
200
+
201
+ select_parts = ["d.date", "d.debt_sum", "d.payment_status"]
202
+ joins = []
203
+ if with_user:
204
+ select_parts.append("u.name AS user_name")
205
+ joins.append("LEFT JOIN users AS u ON u.id = d.user_id")
206
+
207
+ filters = []
208
+ if _contains_any(question, ("unpaid", "not paid", "open debt", "open debts")):
209
+ filters.append("d.payment_status = 'unpaid'")
210
+ elif _contains_any(question, ("paid", "closed debt", "closed debts")):
211
+ filters.append("d.payment_status = 'paid'")
212
+ elif _contains_any(question, ("partial", "partially")):
213
+ filters.append("d.payment_status = 'partial'")
214
+
215
+ month_filter = _extract_month_filter(question)
216
+ if month_filter:
217
+ start, end = month_filter
218
+ filters.append(f"d.date BETWEEN '{start}' AND '{end}'")
219
+
220
+ where_clause = f" WHERE {' AND '.join(filters)}" if filters else ""
221
+ join_clause = f" {' '.join(joins)}" if joins else ""
222
+ order_clause = " ORDER BY d.date DESC"
223
+
224
+ return (
225
+ f"SELECT {', '.join(select_parts)} "
226
+ f"FROM debts AS d"
227
+ f"{join_clause}"
228
+ f"{where_clause}"
229
+ f"{order_clause}"
230
+ f" LIMIT {payload.limit}"
231
+ )
232
+
233
+
234
+ def _generate_template_sql(payload: SqlGenerationRequest) -> str:
235
+ question = _normalize_text(payload.question)
236
+
237
+ debt_markers = ("debt", "debts", "payment_status", "unpaid", "partial", "paid")
238
+ aggregate_markers = (
239
+ "sum",
240
+ "total",
241
+ "group",
242
+ "grouped",
243
+ "top",
244
+ "count",
245
+ "average",
246
+ "avg",
247
+ "minimum",
248
+ "maximum",
249
+ )
250
+
251
+ if _contains_any(question, debt_markers):
252
+ return _build_debts_sql(payload)
253
+
254
+ if _contains_any(question, aggregate_markers):
255
+ return _build_expenses_aggregate_sql(payload)
256
+
257
+ return _build_expenses_detail_sql(payload)
258
+
259
+
260
+ def _get_sql_generator() -> Any:
261
+ global _SQL_GENERATOR
262
+
263
+ if _SQL_GENERATOR is None:
264
+ from transformers import pipeline
265
+
266
+ model_id = os.getenv("SQL_MODEL", "gaussalgo/T5-LM-Large-text2sql-spider")
267
+ _SQL_GENERATOR = pipeline(
268
+ task="text2text-generation",
269
+ model=model_id,
270
+ tokenizer=model_id,
271
+ device=-1,
272
+ )
273
+
274
+ return _SQL_GENERATOR
275
+
276
+
277
+ def _build_prompt(payload: SqlGenerationRequest) -> str:
278
+ # Optional fallback prompt for transformer model.
279
+ return f"Question: {payload.question} | {DEFAULT_DB_SCHEMA}"
280
+
281
+
282
+ def _normalize_sql(raw_sql: str, limit: int) -> str:
283
+ sql = (raw_sql or "").strip()
284
+ if not sql:
285
+ raise ValueError("SQL model returned an empty result.")
286
+
287
+ if "```" in sql:
288
+ parts = [part.strip() for part in sql.split("```") if part.strip()]
289
+ sql = parts[-1]
290
+
291
+ upper_sql = sql.upper()
292
+ sql_start = upper_sql.find("SELECT")
293
+ if sql_start == -1:
294
+ raise ValueError("Generated SQL is not a SELECT query.")
295
+
296
+ sql = sql[sql_start:]
297
+ if ";" in sql:
298
+ sql = sql.split(";", 1)[0].strip()
299
+
300
+ upper_sql = sql.upper()
301
+ forbidden = ("INSERT ", "UPDATE ", "DELETE ", "DROP ", "ALTER ", "PRAGMA ", "ATTACH ", "CREATE ", "REPLACE ")
302
+ if any(keyword in upper_sql for keyword in forbidden):
303
+ raise ValueError("Generated SQL contains forbidden statements.")
304
+
305
+ if not upper_sql.startswith("SELECT "):
306
+ raise ValueError("Only SELECT queries are allowed.")
307
+
308
+ aggregate_markers = ("COUNT(", "SUM(", "AVG(", "MIN(", "MAX(")
309
+ has_limit = " LIMIT " in upper_sql
310
+ if not has_limit and not any(marker in upper_sql for marker in aggregate_markers):
311
+ sql = f"{sql} LIMIT {limit}"
312
+
313
+ return sql
314
+
315
+
316
+ def generate_sql(question: str, limit: int = 200) -> str:
317
+ clean_question = (question or "").strip()
318
+ if not clean_question:
319
+ raise ValueError("Field 'query' is required.")
320
+
321
+ payload = SqlGenerationRequest(
322
+ question=clean_question,
323
+ limit=max(1, min(1000, int(limit))),
324
+ )
325
+
326
+ # Primary path: deterministic template engine for core tables.
327
+ template_sql = _generate_template_sql(payload)
328
+ if template_sql:
329
+ return _normalize_sql(template_sql, limit=payload.limit)
330
+
331
+ # Secondary path: optional model fallback.
332
+ if os.getenv("SQL_USE_LLM_FALLBACK", "false").strip().lower() not in {"1", "true", "yes", "on"}:
333
+ raise ValueError("Unable to map query to a supported SQL template.")
334
+
335
+ generator = _get_sql_generator()
336
+ prompt = _build_prompt(payload)
337
+ result = generator(
338
+ prompt,
339
+ max_new_tokens=512,
340
+ do_sample=False,
341
+ num_beams=4,
342
+ truncation=True,
343
+ )
344
+
345
+ generated_text = result[0].get("generated_text", "") if result else ""
346
+ return _normalize_sql(generated_text, limit=payload.limit)