Spaces:
Running
Running
VladGeekPro committed on
Commit ·
82b086c
1
Parent(s): 144bf42
ChangedDockerFile
Browse files- Dockerfile +1 -1
- sql_generator.py +346 -0
Dockerfile
CHANGED
|
@@ -22,7 +22,7 @@ COPY --chown=user requirements.txt .
|
|
| 22 |
RUN pip install --upgrade pip && pip install -r requirements.txt
|
| 23 |
|
| 24 |
COPY --chown=user app.py ./
|
| 25 |
-
COPY --chown=user
|
| 26 |
COPY --chown=user extractors/ ./extractors/
|
| 27 |
|
| 28 |
EXPOSE 7860
|
|
|
|
| 22 |
RUN pip install --upgrade pip && pip install -r requirements.txt
|
| 23 |
|
| 24 |
COPY --chown=user app.py ./
|
| 25 |
+
COPY --chown=user expense_predictor.py ./
|
| 26 |
COPY --chown=user extractors/ ./extractors/
|
| 27 |
|
| 28 |
EXPOSE 7860
|
sql_generator.py
ADDED
|
@@ -0,0 +1,346 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Template-based SQL generator for SQLite.
|
| 3 |
+
Deterministic schema-aware NL→SQL with minimal templates and high accuracy.
|
| 4 |
+
Focuses only on: users, categories, suppliers, expenses, debts.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import calendar
import os
import re
from dataclasses import dataclass, asdict
from datetime import datetime
from typing import Any, Dict, Optional, Tuple
|
| 12 |
+
|
| 13 |
+
# Reference description of the core business tables only.
# Tables are separated by " | "; each entry reads "table : column type , ...".
DEFAULT_DB_SCHEMA = " | ".join(
    (
        "users : id int , name varchar , email varchar , created_at datetime , updated_at datetime",
        "categories : id int , name varchar , slug varchar , notes text , created_at datetime , updated_at datetime",
        "suppliers : id int , name varchar , slug varchar , category_id int , created_at datetime , updated_at datetime",
        "expenses : id int , user_id int , date date , category_id int , supplier_id int , sum numeric , notes text , created_at datetime , updated_at datetime",
        "debts : id int , date date , user_id int , debt_sum numeric , payment_status varchar , partial_sum numeric , date_paid date , created_at datetime , updated_at datetime",
    )
)
|
| 21 |
+
|
| 22 |
+
_SQL_GENERATOR: Any | None = None
|
| 23 |
+
_MONTHS = {
|
| 24 |
+
"january": 1,
|
| 25 |
+
"february": 2,
|
| 26 |
+
"march": 3,
|
| 27 |
+
"april": 4,
|
| 28 |
+
"may": 5,
|
| 29 |
+
"june": 6,
|
| 30 |
+
"july": 7,
|
| 31 |
+
"august": 8,
|
| 32 |
+
"september": 9,
|
| 33 |
+
"october": 10,
|
| 34 |
+
"november": 11,
|
| 35 |
+
"december": 12,
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@dataclass(frozen=True)
class SqlGenerationRequest:
    """Immutable input handed to the SQL builders.

    Attributes:
        question: Natural-language question to translate into SQL.
        limit: Row cap used for the generated LIMIT clause; defaults to 200.
    """

    question: str
    limit: int = 200
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _normalize_text(text: str) -> str:
    """Lower-case ``text``, collapse each whitespace run to one space, trim the ends."""
    lowered = text.lower()
    collapsed = re.sub(r"\s+", " ", lowered)
    return collapsed.strip()
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _contains_any(text: str, markers: tuple[str, ...]) -> bool:
    """Return True when at least one of ``markers`` occurs in ``text`` as a substring."""
    for candidate in markers:
        if candidate in text:
            return True
    return False
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _end_of_month(year: int, month: int) -> int:
    """Return the number of the last day of ``month`` in ``year``.

    The previous implementation called ``date(...)``, a name that was never
    imported, so every month except December failed with NameError.

    Args:
        year: Calendar year.
        month: Month number, 1 (January) through 12 (December).

    Returns:
        The number of days in that month (leap years are accounted for).

    Raises:
        ValueError: If ``month`` is outside 1..12.
    """
    # calendar.monthrange returns (weekday_of_first_day, days_in_month).
    return calendar.monthrange(year, month)[1]
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _extract_month_filter(question: str) -> tuple[str, str] | None:
    """Find an English month name plus a 20xx year and return that month's date range.

    Month names are matched as whole words, so text such as "maybe" is not
    read as the month "may". When several month names are present, the first
    one in ``_MONTHS`` order wins.

    Args:
        question: Free-form user question.

    Returns:
        ``(start, end)`` as ``YYYY-MM-DD`` strings for the first and last day
        of the month, or ``None`` when the question lacks a month name or a
        four-digit year starting with "20".
    """
    text = _normalize_text(question)

    # The year does not depend on the month, so look it up once; without a
    # year no month can produce a range.
    year_match = re.search(r"\b(20\d{2})\b", text)
    if year_match is None:
        return None
    year = int(year_match.group(1))

    for month_name, month_idx in _MONTHS.items():
        # Whole-word match instead of a plain substring test.
        if re.search(rf"\b{re.escape(month_name)}\b", text) is None:
            continue
        day_end = _end_of_month(year, month_idx)
        start = f"{year:04d}-{month_idx:02d}-01"
        end = f"{year:04d}-{month_idx:02d}-{day_end:02d}"
        return start, end
    return None
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def _extract_top_limit(question: str) -> int | None:
    """Read an explicit "top N" request and return N clamped to 1..1000.

    Returns ``None`` when the question contains no "top <number>" phrase.
    """
    found = re.search(r"\btop\s+(\d{1,4})\b", _normalize_text(question))
    if found is None:
        return None

    requested = int(found.group(1))
    if requested < 1:
        return 1
    if requested > 1000:
        return 1000
    return requested
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def _extract_metric(question: str) -> tuple[str, str]:
    """Choose the aggregate expression that the question asks for.

    Markers must start at a word boundary, so "discount" no longer selects
    COUNT and "admin " no longer selects MIN. The short forms "min" and "max"
    are matched as whole words, which also covers them at the end of the
    question or before punctuation (the old "min " / "max " markers required
    a trailing space).

    Args:
        question: Free-form user question.

    Returns:
        ``(sql_expression, column_alias)``. Checked in order: count, average,
        minimum, maximum; the default is ``("SUM(e.sum)", "total_amount")``.
    """
    text = _normalize_text(question)

    # (pattern, SQL expression, alias) - the first matching rule wins.
    rules = (
        (r"\b(?:count|how many|number of)", "COUNT(*)", "items_count"),
        (r"\b(?:average|avg|mean)", "AVG(e.sum)", "avg_amount"),
        (r"\b(?:minimum|lowest|min\b)", "MIN(e.sum)", "min_amount"),
        (r"\b(?:maximum|highest|max\b)", "MAX(e.sum)", "max_amount"),
    )
    for pattern, expression, alias in rules:
        if re.search(pattern, text):
            return expression, alias
    return "SUM(e.sum)", "total_amount"
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def _extract_dimension(question: str) -> str | None:
    """Return the grouping dimension named in the question, if any.

    Checked in order: "category", "supplier", "user". Returns ``None`` when no
    keyword of any dimension occurs in the question.
    """
    text = _normalize_text(question)

    # (dimension, keywords that select it) - the first hit wins.
    lookup = (
        ("category", ("category", "categories")),
        ("supplier", ("supplier", "suppliers", "vendor", "vendors")),
        ("user", ("user", "users", "person")),
    )
    for dimension, keywords in lookup:
        if _contains_any(text, keywords):
            return dimension
    return None
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def _build_expenses_aggregate_sql(payload: SqlGenerationRequest) -> str:
    """Build an aggregate SELECT over ``expenses``.

    The metric comes from ``_extract_metric``; an optional dimension
    (category, supplier or user) adds a label column, a JOIN and a GROUP BY.
    A month filter is applied when the question names a month and year.
    Results are ordered by the metric, descending unless the question asks
    for ascending order. An explicit "top N" overrides ``payload.limit``.

    Args:
        payload: Question text and default row limit.

    Returns:
        The SQL text as a single line.
    """
    text = _normalize_text(payload.question)
    expression, alias = _extract_metric(text)

    # dimension -> (label column, join, GROUP BY columns)
    dimension_parts = {
        "category": (
            "c.name AS category_name",
            "JOIN categories AS c ON c.id = e.category_id",
            "c.id, c.name",
        ),
        "supplier": (
            "s.name AS supplier_name",
            "JOIN suppliers AS s ON s.id = e.supplier_id",
            "s.id, s.name",
        ),
        "user": (
            "u.name AS user_name",
            "JOIN users AS u ON u.id = e.user_id",
            "u.id, u.name",
        ),
    }

    columns = []
    join_sql = ""
    group_sql = ""
    chosen = dimension_parts.get(_extract_dimension(text))
    if chosen is not None:
        label_column, join, grouping = chosen
        columns.append(label_column)
        join_sql = f" {join}"
        group_sql = f" GROUP BY {grouping}"
    columns.append(f"{expression} AS {alias}")

    where_sql = ""
    period = _extract_month_filter(text)
    if period:
        first_day, last_day = period
        where_sql = f" WHERE e.date BETWEEN '{first_day}' AND '{last_day}'"

    if " asc" in text or "ascending" in text:
        direction = "ASC"
    else:
        direction = "DESC"

    explicit_top = _extract_top_limit(text)
    row_cap = payload.limit if explicit_top is None else explicit_top

    pieces = [
        f"SELECT {', '.join(columns)} ",
        "FROM expenses AS e",
        join_sql,
        where_sql,
        group_sql,
        f" ORDER BY {alias} {direction}",
        f" LIMIT {row_cap}",
    ]
    return "".join(pieces)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def _build_expenses_detail_sql(payload: SqlGenerationRequest) -> str:
    """Build a row-level SELECT over ``expenses``, newest first.

    Date, sum and notes are always selected. Category, supplier and user
    names are joined in only when the question mentions them. A month filter
    is applied when the question names a month and year.

    Args:
        payload: Question text and row limit.

    Returns:
        The SQL text as a single line, ending in ``LIMIT <payload.limit>``.
    """
    text = _normalize_text(payload.question)

    # (keywords, extra column, join) - evaluated in this order.
    optional_joins = (
        (
            ("category", "categories"),
            "c.name AS category_name",
            "JOIN categories AS c ON c.id = e.category_id",
        ),
        (
            ("supplier", "suppliers", "vendor", "vendors"),
            "s.name AS supplier_name",
            "JOIN suppliers AS s ON s.id = e.supplier_id",
        ),
        (
            ("user", "users", "person"),
            "u.name AS user_name",
            "JOIN users AS u ON u.id = e.user_id",
        ),
    )

    columns = ["e.date", "e.sum", "e.notes"]
    join_fragments = []
    for keywords, column, join in optional_joins:
        if _contains_any(text, keywords):
            columns.append(column)
            join_fragments.append(join)

    conditions = []
    period = _extract_month_filter(text)
    if period:
        first_day, last_day = period
        conditions.append(f"e.date BETWEEN '{first_day}' AND '{last_day}'")

    where_sql = f" WHERE {' AND '.join(conditions)}" if conditions else ""
    join_sql = f" {' '.join(join_fragments)}" if join_fragments else ""

    pieces = [
        f"SELECT {', '.join(columns)} ",
        "FROM expenses AS e",
        join_sql,
        where_sql,
        " ORDER BY e.date DESC",
        f" LIMIT {payload.limit}",
    ]
    return "".join(pieces)
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def _build_debts_sql(payload: SqlGenerationRequest) -> str:
    """Build a row-level SELECT over ``debts``, newest first.

    The user name is added through a LEFT JOIN when the question mentions
    user/users/person/name. Optional filters are the payment status and a
    month range.

    Args:
        payload: Question text and row limit.

    Returns:
        The SQL text as a single line, ending in ``LIMIT <payload.limit>``.
    """
    question = _normalize_text(payload.question)
    with_user = _contains_any(question, ("user", "users", "person", "name"))

    select_parts = ["d.date", "d.debt_sum", "d.payment_status"]
    joins = []
    if with_user:
        select_parts.append("u.name AS user_name")
        joins.append("LEFT JOIN users AS u ON u.id = d.user_id")

    filters = []
    # Order matters because the markers overlap as substrings: "unpaid" and
    # "partially paid" both contain "paid", so the generic "paid" test must
    # come last. Previously "partially paid" was filtered as 'paid'.
    if _contains_any(question, ("unpaid", "not paid", "open debt", "open debts")):
        filters.append("d.payment_status = 'unpaid'")
    elif _contains_any(question, ("partial", "partially")):
        filters.append("d.payment_status = 'partial'")
    elif _contains_any(question, ("paid", "closed debt", "closed debts")):
        filters.append("d.payment_status = 'paid'")

    month_filter = _extract_month_filter(question)
    if month_filter:
        start, end = month_filter
        filters.append(f"d.date BETWEEN '{start}' AND '{end}'")

    where_clause = f" WHERE {' AND '.join(filters)}" if filters else ""
    join_clause = f" {' '.join(joins)}" if joins else ""
    order_clause = " ORDER BY d.date DESC"

    return (
        f"SELECT {', '.join(select_parts)} "
        f"FROM debts AS d"
        f"{join_clause}"
        f"{where_clause}"
        f"{order_clause}"
        f" LIMIT {payload.limit}"
    )
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def _generate_template_sql(payload: SqlGenerationRequest) -> str:
    """Route the question to one of the three SQL templates.

    Debt-related wording takes priority, then aggregate wording; everything
    else gets the expenses detail query. Because the detail query is the
    default, this function always returns SQL text.

    Args:
        payload: Question text and row limit.

    Returns:
        The SQL produced by the selected builder.
    """
    question = _normalize_text(payload.question)

    debt_markers = ("debt", "debts", "payment_status", "unpaid", "partial", "paid")
    # Includes every phrase _extract_metric reacts to ("how many", "lowest",
    # "max ", ...); without them such questions were routed to the detail
    # query and the matching metric branches could not be reached.
    aggregate_markers = (
        "sum",
        "total",
        "group",
        "grouped",
        "top",
        "count",
        "how many",
        "number of",
        "average",
        "avg",
        "mean",
        "minimum",
        "lowest",
        "min ",
        "maximum",
        "highest",
        "max ",
    )

    if _contains_any(question, debt_markers):
        return _build_debts_sql(payload)

    if _contains_any(question, aggregate_markers):
        return _build_expenses_aggregate_sql(payload)

    return _build_expenses_detail_sql(payload)
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def _get_sql_generator() -> Any:
    """Return the cached text2text pipeline, creating it on the first call.

    The model id is read from the ``SQL_MODEL`` environment variable and
    defaults to "gaussalgo/T5-LM-Large-text2sql-spider". The same id is used
    for the tokenizer. The result is stored in the module-level
    ``_SQL_GENERATOR`` so later calls reuse it.
    """
    global _SQL_GENERATOR

    if _SQL_GENERATOR is not None:
        return _SQL_GENERATOR

    # Imported here rather than at module top, so transformers is only
    # needed when this function is actually called.
    from transformers import pipeline

    model_name = os.getenv("SQL_MODEL", "gaussalgo/T5-LM-Large-text2sql-spider")
    _SQL_GENERATOR = pipeline(
        task="text2text-generation",
        model=model_name,
        tokenizer=model_name,
        device=-1,  # NOTE(review): -1 is the CPU device in the pipeline API - confirm
    )
    return _SQL_GENERATOR
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def _build_prompt(payload: SqlGenerationRequest) -> str:
    """Compose the optional fallback-model prompt: the question, then the schema."""
    return "Question: {} | {}".format(payload.question, DEFAULT_DB_SCHEMA)
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def _normalize_sql(raw_sql: str, limit: int) -> str:
    """Validate generated SQL and reduce it to a single SELECT statement.

    Steps: trim, unwrap Markdown code fences, drop any text before the first
    SELECT, keep only the first ``;``-terminated statement, reject statements
    containing write/DDL keywords, and append ``LIMIT <limit>`` when the query
    has neither a LIMIT clause nor an aggregate function.

    Args:
        raw_sql: SQL text from a template or from the model.
        limit: Row cap appended when the query has no LIMIT of its own.

    Returns:
        The cleaned SELECT statement without a trailing semicolon.

    Raises:
        ValueError: If the input is empty, contains no SELECT, contains a
            forbidden keyword, or does not start with the SELECT keyword.
    """
    sql = (raw_sql or "").strip()
    if not sql:
        raise ValueError("SQL model returned an empty result.")

    if "```" in sql:
        parts = [part.strip() for part in sql.split("```") if part.strip()]
        if not parts:
            # Input consisted of fences only; report it as an empty result
            # instead of failing with IndexError on parts[-1].
            raise ValueError("SQL model returned an empty result.")
        sql = parts[-1]

    # Search the text itself rather than an upper-cased copy: upper-casing can
    # change the length of some Unicode strings and shift the offset.
    select_match = re.search(r"SELECT", sql, re.IGNORECASE | re.ASCII)
    if select_match is None:
        raise ValueError("Generated SQL is not a SELECT query.")

    # Drop the preamble and everything after the first statement terminator.
    sql = sql[select_match.start():].split(";", 1)[0].strip()

    upper_sql = sql.upper()
    # A keyword counts only when it is not the tail of a longer identifier
    # (e.g. LAST_UPDATE) and is followed by any whitespace, not just a space.
    forbidden = re.search(
        r"(?<![A-Z0-9_])"
        r"(?:INSERT|UPDATE|DELETE|DROP|ALTER|PRAGMA|ATTACH|CREATE|REPLACE)\s",
        upper_sql,
    )
    if forbidden is not None:
        raise ValueError("Generated SQL contains forbidden statements.")

    # Accept "SELECT" followed by a newline or tab, but not e.g. "SELECTED".
    if re.match(r"SELECT\b", upper_sql) is None:
        raise ValueError("Only SELECT queries are allowed.")

    aggregate_markers = ("COUNT(", "SUM(", "AVG(", "MIN(", "MAX(")
    # Whitespace-delimited so a LIMIT on its own line is recognised and no
    # second LIMIT clause is appended.
    has_limit = re.search(r"\sLIMIT\s", upper_sql) is not None
    if not has_limit and not any(marker in upper_sql for marker in aggregate_markers):
        sql = f"{sql} LIMIT {limit}"

    return sql
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def generate_sql(question: str, limit: int = 200) -> str:
    """Translate a natural-language question into a SQL SELECT statement.

    The deterministic templates are tried first. The transformer model is
    consulted only if the templates yield nothing and the
    ``SQL_USE_LLM_FALLBACK`` environment variable is one of
    1/true/yes/on (case-insensitive).

    Args:
        question: User question; must not be empty or whitespace only.
        limit: Requested row limit, clamped to 1..1000.

    Returns:
        SQL text that has passed ``_normalize_sql``.

    Raises:
        ValueError: If the question is empty, no template applies while the
            fallback is disabled, or the SQL fails validation.
    """
    clean_question = (question or "").strip()
    if not clean_question:
        raise ValueError("Field 'query' is required.")

    bounded_limit = min(1000, int(limit))
    if bounded_limit < 1:
        bounded_limit = 1
    request = SqlGenerationRequest(question=clean_question, limit=bounded_limit)

    # Primary path: deterministic template engine for core tables.
    template_sql = _generate_template_sql(request)
    if template_sql:
        return _normalize_sql(template_sql, limit=request.limit)

    # Secondary path: optional model fallback.
    fallback_flag = os.getenv("SQL_USE_LLM_FALLBACK", "false").strip().lower()
    if fallback_flag not in {"1", "true", "yes", "on"}:
        raise ValueError("Unable to map query to a supported SQL template.")

    generator = _get_sql_generator()
    prompt = _build_prompt(request)
    outputs = generator(
        prompt,
        max_new_tokens=512,
        do_sample=False,
        num_beams=4,
        truncation=True,
    )

    if outputs:
        generated_text = outputs[0].get("generated_text", "")
    else:
        generated_text = ""
    return _normalize_sql(generated_text, limit=request.limit)
|