VladGeekPro Copilot commited on
Commit
92ce1cc
·
1 Parent(s): 8759264

ChangedModelForQueryGeneration

Browse files

Co-authored-by: Copilot <copilot@github.com>

Files changed (1) hide show
  1. sql_generator.py +9 -9
sql_generator.py CHANGED
@@ -96,7 +96,7 @@ def _get_sql_generator() -> Any:
96
  if _SQL_GENERATOR is None:
97
  from transformers import pipeline
98
 
99
- model_id = os.getenv("SQL_MODEL", "google/flan-t5-base")
100
  _SQL_GENERATOR = pipeline(
101
  task="text2text-generation",
102
  model=model_id,
@@ -109,14 +109,13 @@ def _get_sql_generator() -> Any:
109
 
110
  def _build_prompt(payload: SqlGenerationRequest) -> str:
111
  return (
112
- "You translate user requests into SQLite SELECT queries. "
113
- "Return only SQL without explanations. "
114
- "Use only tables and columns from the schema. "
115
- "Never generate INSERT, UPDATE, DELETE, DROP, ALTER, PRAGMA, ATTACH or CREATE. "
116
- "Prefer explicit JOIN conditions using foreign keys from the schema. "
117
- f"Add LIMIT {payload.limit} when the query is not an aggregate result.\n\n"
118
- f"Schema:\n{DEFAULT_DB_SCHEMA}\n\n"
119
- f"User request:\n{payload.question}\n\n"
120
  "SQL:"
121
  )
122
 
@@ -171,6 +170,7 @@ def generate_sql(question: str, limit: int = 200) -> str:
171
  prompt,
172
  max_new_tokens=256,
173
  do_sample=False,
 
174
  truncation=True,
175
  )
176
 
 
96
  if _SQL_GENERATOR is None:
97
  from transformers import pipeline
98
 
99
+ model_id = os.getenv("SQL_MODEL", "gaussalgo/T5-LM-Large-text2sql-spider")
100
  _SQL_GENERATOR = pipeline(
101
  task="text2text-generation",
102
  model=model_id,
 
109
 
110
  def _build_prompt(payload: SqlGenerationRequest) -> str:
111
  return (
112
+ "Task: generate one SQLite SQL query for the question. "
113
+ "Return only SQL, no markdown and no explanation. "
114
+ "Use only schema objects provided below. "
115
+ "Never generate non-SELECT statements.\n\n"
116
+ f"Question: {payload.question}\n"
117
+ f"Schema: {DEFAULT_DB_SCHEMA}\n"
118
+ f"Constraints: output must start with SELECT; apply LIMIT {payload.limit} when not aggregate.\n"
 
119
  "SQL:"
120
  )
121
 
 
170
  prompt,
171
  max_new_tokens=256,
172
  do_sample=False,
173
+ num_beams=4,
174
  truncation=True,
175
  )
176