saadkhi commited on
Commit
17bc164
Β·
verified Β·
1 Parent(s): 7ea64d1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -47
app.py CHANGED
@@ -2,13 +2,14 @@ import warnings
2
  warnings.filterwarnings("ignore")
3
 
4
  import torch
5
- import gradio as gr
 
6
  from transformers import AutoTokenizer, AutoModelForCausalLM
7
 
8
- # Reduce CPU pressure
9
  torch.set_num_threads(1)
10
 
11
- # βœ… Use lightweight model (IMPORTANT)
 
12
  BASE_MODEL = "distilgpt2"
13
 
14
  print("Loading model...")
@@ -20,6 +21,12 @@ model.eval()
20
 
21
  print("Model ready")
22
 
 
 
 
 
 
 
23
  # ─────────────────────────
24
  # SQL FILTER
25
  # ─────────────────────────
@@ -34,28 +41,26 @@ def is_sql_related(text):
34
  return any(k in text for k in SQL_KEYWORDS)
35
 
36
  # ─────────────────────────
37
- # PROMPT
38
  # ─────────────────────────
39
- SYSTEM_PROMPT = """
40
- You are an expert SQL generator.
41
- Rules:
42
- - Only respond to SQL or database related questions.
43
- - Output ONLY SQL query.
44
- - No explanation.
45
- """
46
 
47
- # ─────────────────────────
48
- # GENERATION
49
- # ─────────────────────────
50
- def generate_sql(user_input):
51
 
52
  if not user_input.strip():
53
- return "Enter SQL question."
54
 
55
  if not is_sql_related(user_input):
56
- return "Only SQL/database questions are allowed."
 
 
 
 
57
 
58
- prompt = f"{SYSTEM_PROMPT}\nUser: {user_input}\nSQL:"
 
 
59
 
60
  inputs = tokenizer(prompt, return_tensors="pt")
61
 
@@ -70,33 +75,6 @@ def generate_sql(user_input):
70
 
71
  text = tokenizer.decode(output[0], skip_special_tokens=True)
72
 
73
- result = text.split("SQL:")[-1].strip()
74
- result = result.split("\n")[0]
75
 
76
- return result
77
-
78
- # ─────────────────────────
79
- # UI
80
- # ─────────────────────────
81
- demo = gr.Interface(
82
- fn=generate_sql,
83
- inputs=gr.Textbox(
84
- lines=3,
85
- label="SQL Question",
86
- placeholder="Find duplicate emails in users table"
87
- ),
88
- outputs=gr.Textbox(
89
- lines=6,
90
- label="Generated SQL"
91
- ),
92
- title="AI SQL Generator (Portfolio Project)",
93
- description="Only SQL/database queries are supported.",
94
- examples=[
95
- ["Find duplicate emails in users table"],
96
- ["Top 5 highest paid employees"],
97
- ["Count orders per customer last month"],
98
- ["Write a joke about cats"]
99
- ],
100
- )
101
-
102
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
2
  warnings.filterwarnings("ignore")
3
 
4
  import torch
5
+ from fastapi import FastAPI
6
+ from pydantic import BaseModel
7
  from transformers import AutoTokenizer, AutoModelForCausalLM
8
 
 
9
  torch.set_num_threads(1)
10
 
11
+ app = FastAPI()
12
+
13
  BASE_MODEL = "distilgpt2"
14
 
15
  print("Loading model...")
 
21
 
22
  print("Model ready")
23
 
24
+ # ─────────────────────────
25
+ # Request schema
26
+ # ─────────────────────────
27
+ class Query(BaseModel):
28
+ question: str
29
+
30
  # ─────────────────────────
31
  # SQL FILTER
32
  # ─────────────────────────
 
41
  return any(k in text for k in SQL_KEYWORDS)
42
 
43
  # ─────────────────────────
44
+ # Endpoint
45
  # ─────────────────────────
46
+ @app.post("/generate-sql")
47
+ def generate_sql(data: Query):
 
 
 
 
 
48
 
49
+ user_input = data.question
 
 
 
50
 
51
  if not user_input.strip():
52
+ return {"error": "Empty input"}
53
 
54
  if not is_sql_related(user_input):
55
+ return {"error": "Only SQL-related queries allowed"}
56
+
57
+ prompt = f"""
58
+ You are an expert SQL generator.
59
+ Only output SQL query.
60
 
61
+ User: {user_input}
62
+ SQL:
63
+ """
64
 
65
  inputs = tokenizer(prompt, return_tensors="pt")
66
 
 
75
 
76
  text = tokenizer.decode(output[0], skip_special_tokens=True)
77
 
78
+ result = text.split("SQL:")[-1].strip().split("\n")[0]
 
79
 
80
+ return {"sql": result}