saadkhi commited on
Commit
818d4d0
Β·
verified Β·
1 Parent(s): a729178

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -40
app.py CHANGED
@@ -2,23 +2,21 @@ import warnings
2
  warnings.filterwarnings("ignore")
3
 
4
  import torch
5
- from fastapi import FastAPI
6
- from pydantic import BaseModel
7
- from transformers import AutoTokenizer, AutoModelForCausalLM
8
-
9
  import gradio as gr
10
- import threading
11
 
12
  torch.set_num_threads(1)
13
 
14
- app = FastAPI(title="SQL Generator API")
15
-
16
  BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
17
 
18
  print("Loading model...")
19
 
 
 
 
 
 
20
  tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
21
- model = AutoModelForCausalLM.from_pretrained(BASE_MODEL)
22
 
23
  model.eval()
24
 
@@ -34,21 +32,34 @@ SQL_KEYWORDS = [
34
  ]
35
 
36
  def is_sql_related(text):
37
- return any(k in text.lower() for k in SQL_KEYWORDS)
 
38
 
 
 
 
39
  SYSTEM_PROMPT = """
40
  You are an expert SQL generator.
41
- Only output SQL query.
 
 
 
 
42
  """
43
 
44
- def generate_sql(user_input: str):
 
45
  if not user_input.strip():
46
  return "Enter SQL question."
47
 
48
  if not is_sql_related(user_input):
49
- return "Only SQL/database questions allowed."
50
 
51
- prompt = f"{SYSTEM_PROMPT}\nUser: {user_input}\nSQL:"
 
 
 
 
52
 
53
  inputs = tokenizer(prompt, return_tensors="pt")
54
 
@@ -62,35 +73,34 @@ def generate_sql(user_input: str):
62
  )
63
 
64
  text = tokenizer.decode(output[0], skip_special_tokens=True)
65
- result = text.split("SQL:")[-1].strip().split("\n")[0]
66
 
67
- return result
 
68
 
69
- # ─────────────────────────
70
- # API
71
- # ─────────────────────────
72
- class Query(BaseModel):
73
- text: str
74
-
75
- @app.get("/")
76
- def root():
77
- return {"status": "API running"}
78
-
79
- @app.post("/generate")
80
- def generate(query: Query):
81
- return {"result": generate_sql(query.text)}
82
 
83
  # ─────────────────────────
84
- # Gradio UI (for testing)
85
  # ─────────────────────────
86
- def run_gradio():
87
- demo = gr.Interface(
88
- fn=generate_sql,
89
- inputs=gr.Textbox(lines=3, label="SQL Question"),
90
- outputs=gr.Textbox(lines=6, label="Generated SQL"),
91
- title="SQL Generator"
92
- )
93
- demo.launch(server_name="0.0.0.0", server_port=7860)
94
-
95
- # Run UI in separate thread
96
- threading.Thread(target=run_gradio).start()
 
 
 
 
 
 
 
 
 
 
 
 
2
  warnings.filterwarnings("ignore")
3
 
4
  import torch
 
 
 
 
5
  import gradio as gr
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM
7
 
8
  torch.set_num_threads(1)
9
 
 
 
10
  BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
11
 
12
  print("Loading model...")
13
 
14
+ model = AutoModelForCausalLM.from_pretrained(
15
+ BASE_MODEL,
16
+ torch_dtype=torch.float32
17
+ )
18
+
19
  tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
 
20
 
21
  model.eval()
22
 
 
32
  ]
33
 
34
  def is_sql_related(text):
35
+ text = text.lower()
36
+ return any(k in text for k in SQL_KEYWORDS)
37
 
38
+ # ─────────────────────────
39
+ # GENERATION
40
+ # ─────────────────────────
41
  SYSTEM_PROMPT = """
42
  You are an expert SQL generator.
43
+ Rules:
44
+ - Only respond to SQL or database related questions.
45
+ - If the question is not about SQL or databases, refuse.
46
+ - Output ONLY SQL query.
47
+ - Do not explain.
48
  """
49
 
50
+ def generate_sql(user_input):
51
+
52
  if not user_input.strip():
53
  return "Enter SQL question."
54
 
55
  if not is_sql_related(user_input):
56
+ return "I only respond to SQL and database related questions. If you want, I can craft helpful database queries for you."
57
 
58
+ prompt = f"""
59
+ {SYSTEM_PROMPT}
60
+ User request: {user_input}
61
+ SQL:
62
+ """
63
 
64
  inputs = tokenizer(prompt, return_tensors="pt")
65
 
 
73
  )
74
 
75
  text = tokenizer.decode(output[0], skip_special_tokens=True)
 
76
 
77
+ result = text.split("SQL:")[-1].strip()
78
+ result = result.split("\n\n")[0]
79
 
80
+ return result
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
  # ─────────────────────────
83
+ # UI
84
  # ─────────────────────────
85
+ demo = gr.Interface(
86
+ fn=generate_sql,
87
+ inputs=gr.Textbox(
88
+ lines=3,
89
+ label="SQL Question",
90
+ placeholder="Find duplicate emails in users table"
91
+ ),
92
+ outputs=gr.Textbox(
93
+ lines=8,
94
+ label="Generated SQL"
95
+ ),
96
+ title="AI SQL Generator (Portfolio Project)",
97
+ description="This model ONLY responds to SQL/database queries.",
98
+ examples=[
99
+ ["Find duplicate emails in users table"],
100
+ ["Top 5 highest paid employees"],
101
+ ["Count orders per customer last month"],
102
+ ["Write a joke about cats"]
103
+ ],
104
+ )
105
+
106
+ demo.launch(server_name="0.0.0.0")