saadkhi commited on
Commit
ac4a697
Β·
verified Β·
1 Parent(s): ce4befd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -21
app.py CHANGED
@@ -6,9 +6,10 @@ from fastapi import FastAPI
6
  from pydantic import BaseModel
7
  from transformers import AutoTokenizer, AutoModelForCausalLM
8
 
 
9
  torch.set_num_threads(1)
10
 
11
- app = FastAPI()
12
 
13
  BASE_MODEL = "distilgpt2"
14
 
@@ -22,10 +23,10 @@ model.eval()
22
  print("Model ready")
23
 
24
  # ─────────────────────────
25
- # Request schema
26
  # ─────────────────────────
27
  class Query(BaseModel):
28
- question: str
29
 
30
  # ─────────────────────────
31
  # SQL FILTER
@@ -37,30 +38,24 @@ SQL_KEYWORDS = [
37
  ]
38
 
39
  def is_sql_related(text):
40
- text = text.lower()
41
- return any(k in text for k in SQL_KEYWORDS)
42
 
43
  # ─────────────────────────
44
- # Endpoint
45
  # ─────────────────────────
46
- @app.post("/generate-sql")
47
- def generate_sql(data: Query):
48
-
49
- user_input = data.question
50
 
 
51
  if not user_input.strip():
52
- return {"error": "Empty input"}
53
 
54
  if not is_sql_related(user_input):
55
- return {"error": "Only SQL-related queries allowed"}
56
-
57
- prompt = f"""
58
- You are an expert SQL generator.
59
- Only output SQL query.
60
 
61
- User: {user_input}
62
- SQL:
63
- """
64
 
65
  inputs = tokenizer(prompt, return_tensors="pt")
66
 
@@ -74,7 +69,18 @@ SQL:
74
  )
75
 
76
  text = tokenizer.decode(output[0], skip_special_tokens=True)
77
-
78
  result = text.split("SQL:")[-1].strip().split("\n")[0]
79
 
80
- return {"sql": result}
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  from pydantic import BaseModel
7
  from transformers import AutoTokenizer, AutoModelForCausalLM
8
 
9
+ # Optimize CPU
10
  torch.set_num_threads(1)
11
 
12
+ app = FastAPI(title="SQL Generator API")
13
 
14
  BASE_MODEL = "distilgpt2"
15
 
 
23
  print("Model ready")
24
 
25
  # ─────────────────────────
26
+ # Request Schema
27
  # ─────────────────────────
28
  class Query(BaseModel):
29
+ text: str
30
 
31
  # ─────────────────────────
32
  # SQL FILTER
 
38
  ]
39
 
40
  def is_sql_related(text):
41
+ return any(k in text.lower() for k in SQL_KEYWORDS)
 
42
 
43
  # ─────────────────────────
44
+ # Generator
45
  # ─────────────────────────
46
+ SYSTEM_PROMPT = """
47
+ You are an expert SQL generator.
48
+ Only output SQL query.
49
+ """
50
 
51
+ def generate_sql(user_input: str):
52
  if not user_input.strip():
53
+ return "Empty input."
54
 
55
  if not is_sql_related(user_input):
56
+ return "Only SQL-related queries allowed."
 
 
 
 
57
 
58
+ prompt = f"{SYSTEM_PROMPT}\nUser: {user_input}\nSQL:"
 
 
59
 
60
  inputs = tokenizer(prompt, return_tensors="pt")
61
 
 
69
  )
70
 
71
  text = tokenizer.decode(output[0], skip_special_tokens=True)
 
72
  result = text.split("SQL:")[-1].strip().split("\n")[0]
73
 
74
+ return result
75
+
76
+ # ─────────────────────────
77
+ # Routes
78
+ # ─────────────────────────
79
+ @app.get("/")
80
+ def root():
81
+ return {"status": "API is running"}
82
+
83
+ @app.post("/generate")
84
+ def generate(query: Query):
85
+ result = generate_sql(query.text)
86
+ return {"result": result}