saadkhi committed
Commit 7f3026b · verified · 1 Parent(s): ab3f3df

Update app.py

Files changed (1)
  1. app.py +56 -85
app.py CHANGED
@@ -1,59 +1,61 @@
-# app.py
-# Stable CPU-only Hugging Face Space
-# Phi-3-mini + LoRA (NO bitsandbytes, NO SSR issues)
-
 import warnings
-warnings.filterwarnings("ignore", category=FutureWarning)
+warnings.filterwarnings("ignore")
 
 import torch
 import gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from peft import PeftModel
 
-# ─────────────────────────────────────────────
-# Config
-# ─────────────────────────────────────────────
+# ─────────────────────────────
 BASE_MODEL = "unsloth/Phi-3-mini-4k-instruct"
 LORA_PATH = "saadkhi/SQL_Chat_finetuned_model"
 
 MAX_NEW_TOKENS = 180
-TEMPERATURE = 0.0
-DO_SAMPLE = False
 
-# ─────────────────────────────────────────────
-# Load model & tokenizer (CPU SAFE)
-# ─────────────────────────────────────────────
-print("Loading base model on CPU...")
-
-model = AutoModelForCausalLM.from_pretrained(
-    BASE_MODEL,
-    device_map="cpu",
-    torch_dtype=torch.float32,
-    trust_remote_code=True,
-    low_cpu_mem_usage=True,
-)
-
-print("Loading LoRA adapter...")
-model = PeftModel.from_pretrained(model, LORA_PATH)
-
-print("Merging LoRA weights...")
-model = model.merge_and_unload()
-
-tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
-
-model.eval()
-print("Model & tokenizer loaded successfully")
+model = None
+tokenizer = None
+
+# ─────────────────────────────
+# Lazy load (VERY IMPORTANT)
+# ─────────────────────────────
+def load_model():
+    global model, tokenizer
+
+    if model is not None:
+        return
+
+    print("🔄 Loading model (first request only)...")
+
+    base = AutoModelForCausalLM.from_pretrained(
+        BASE_MODEL,
+        device_map="cpu",
+        torch_dtype=torch.float16,  # lighter
+        low_cpu_mem_usage=True,
+        trust_remote_code=True,
+    )
+
+    base = PeftModel.from_pretrained(base, LORA_PATH)
+
+    print("Merging LoRA...")
+    model_loaded = base.merge_and_unload()
+
+    tokenizer_loaded = AutoTokenizer.from_pretrained(BASE_MODEL)
+
+    model_loaded.eval()
+
+    model = model_loaded
+    tokenizer = tokenizer_loaded
+
+    print("✅ Model ready")
 
-# ─────────────────────────────────────────────
-# Inference
-# ─────────────────────────────────────────────
-def generate_sql(question: str) -> str:
-    if not question or not question.strip():
-        return "Please enter a SQL-related question."
+# ─────────────────────────────
+def generate_sql(question):
+    if not question.strip():
+        return "Enter a question"
 
-    messages = [
-        {"role": "user", "content": question.strip()}
-    ]
+    load_model()
+
+    messages = [{"role": "user", "content": question}]
 
     input_ids = tokenizer.apply_chat_template(
         messages,
@@ -66,60 +68,29 @@ def generate_sql(question: str) -> str:
     output_ids = model.generate(
         input_ids=input_ids,
         max_new_tokens=MAX_NEW_TOKENS,
-        temperature=TEMPERATURE,
-        do_sample=DO_SAMPLE,
+        temperature=0.0,
+        do_sample=False,
         pad_token_id=tokenizer.eos_token_id,
-        use_cache=True,
     )
 
-    response = tokenizer.decode(
-        output_ids[0],
-        skip_special_tokens=True
-    )
+    response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
 
-    # Clean Phi-3 chat artifacts
-    for token in ["<|assistant|>", "<|user|>", "<|end|>"]:
-        if token in response:
-            response = response.split(token)[-1]
+    for t in ["<|assistant|>", "<|user|>", "<|end|>"]:
+        if t in response:
+            response = response.split(t)[-1]
 
-    return response.strip() or "(empty response)"
+    return response.strip()
 
-# ─────────────────────────────────────────────
-# Gradio UI
-# ─────────────────────────────────────────────
+# ─────────────────────────────
 demo = gr.Interface(
     fn=generate_sql,
-    inputs=gr.Textbox(
-        label="SQL Question",
-        placeholder="Find duplicate emails in users table",
-        lines=3,
-    ),
-    outputs=gr.Textbox(
-        label="Generated SQL",
-        lines=8,
-    ),
-    title="SQL Chat – Phi-3-mini (CPU)",
-    description=(
-        "CPU-only Hugging Face Space.\n"
-        "First response may take 60–180 seconds. "
-        "Subsequent requests are faster."
-    ),
-    examples=[
-        ["Find duplicate emails in users table"],
-        ["Top 5 highest paid employees"],
-        ["Count orders per customer last month"],
-        ["Delete duplicate rows based on email"],
-    ],
-    cache_examples=False,
+    inputs=gr.Textbox(lines=3, label="SQL Question"),
+    outputs=gr.Textbox(lines=8, label="SQL"),
+    title="SQL Chat Phi-3 CPU",
+    description="First request loads model (60-120s)",
 )
 
-# ─────────────────────────────────────────────
-# Launch
-# ─────────────────────────────────────────────
+demo.queue(concurrency_count=1, max_size=5)
+
 if __name__ == "__main__":
-    print("Launching Gradio interface...")
-    demo.launch(
-        server_name="0.0.0.0",
-        ssr_mode=False,  # important: avoids asyncio FD bug
-        show_error=True,
-    )
+    demo.launch(server_name="0.0.0.0", show_error=True)
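
For anyone trying the updated Space programmatically, here is a minimal sketch of calling it with gradio_client; the Space id below is a placeholder assumption, and "/predict" is the default endpoint name Gradio assigns to a gr.Interface:

from gradio_client import Client

# Space id is assumed for illustration; substitute the real one.
client = Client("saadkhi/SQL_Chat")
result = client.predict(
    "Find duplicate emails in users table",  # single Textbox input passed to generate_sql
    api_name="/predict",                     # default endpoint for gr.Interface
)
print(result)

The first call triggers load_model(), so expect the 60-120 s delay mentioned in the description; later calls reuse the already-merged model.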