saadkhi committed · verified
Commit 00c8a57
1 Parent(s): 08d2633

Update app.py

Files changed (1)
  1. app.py +71 -58
app.py CHANGED
@@ -1,95 +1,108 @@
  # app.py
  import torch
  import gradio as gr
- from unsloth import FastLanguageModel
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+ from peft import PeftModel

  # ────────────────────────────────────────────────────────────────
- # Configuration - change here if needed
+ # Configuration - fastest practical settings
  # ────────────────────────────────────────────────────────────────
- MAX_NEW_TOKENS = 96
- TEMPERATURE = 0.0  # 0.0 = greedy decoding = fastest
+
  BASE_MODEL = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit"
- LORA_PATH = "saadkhi/SQL_Chat_finetuned_model"
+ LORA_PATH = "saadkhi/SQL_Chat_finetuned_model"
+
+ MAX_NEW_TOKENS = 180   # ← keep reasonable
+ TEMPERATURE = 0.0      # greedy = fastest & most deterministic
+ DO_SAMPLE = False      # no sampling = faster

  # ────────────────────────────────────────────────────────────────
- print("Loading model with Unsloth...")
- model, tokenizer = FastLanguageModel.from_pretrained(
-     model_name=BASE_MODEL,
-     max_seq_length=2048,
-     dtype=None,  # auto-detect (bf16 on GPU)
-     load_in_4bit=True,
+ # 4-bit quantization config (this is the key speedup)
+ # ────────────────────────────────────────────────────────────────
+
+ bnb_config = BitsAndBytesConfig(
+     load_in_4bit = True,
+     bnb_4bit_quant_type = "nf4",             # "nf4" usually fastest + good quality
+     bnb_4bit_use_double_quant = True,        # nested quantization → extra memory saving
+     bnb_4bit_compute_dtype = torch.bfloat16  # fastest compute type on modern GPUs
  )

- print("Loading LoRA adapters...")
- model = FastLanguageModel.get_peft_model(
-     model,
-     r=64,  # your original rank
-     target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"],
-     lora_alpha=128,
-     lora_dropout=0,
-     bias="none",
-     use_gradient_checkpointing="unsloth",
+ print("Loading quantized base model...")
+ model = AutoModelForCausalLM.from_pretrained(
+     BASE_MODEL,
+     quantization_config = bnb_config,
+     device_map = "auto",        # auto = best available (cuda > cpu)
+     trust_remote_code = True,
+     torch_dtype = torch.bfloat16
  )

- print("Merging LoRA and preparing for inference...")
- model = FastLanguageModel.for_inference(model)  # important! activates 2x faster kernels
+ print("Loading LoRA adapters...")
+ model = PeftModel.from_pretrained(model, LORA_PATH)
+
+ # Important: merge LoRA weights into base (faster inference, less overhead)
+ model = model.merge_and_unload()
+
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

- # Optional - compile can give additional 20-60% speedup (PyTorch 2.0+)
- if torch.cuda.is_available() and torch.__version__ >= "2.0":
-     print("Compiling model...")
-     model = torch.compile(model, mode="reduce-overhead")
+ # Optional: small speedup boost on supported hardware
+ if torch.cuda.is_available():
+     try:
+         import torch.backends.cuda
+         torch.backends.cuda.enable_flash_sdp(True)  # flash scaled dot product
+     except:
+         pass

+ model.eval()
  print("Model ready!")

  # ────────────────────────────────────────────────────────────────
  def generate_sql(prompt: str):
-     # Very clean chat template usage
+     # Use proper chat template (Phi-3 expects it)
      messages = [{"role": "user", "content": prompt}]

      inputs = tokenizer.apply_chat_template(
          messages,
-         tokenize=True,
-         add_generation_prompt=True,
-         return_tensors="pt"
-     ).to("cuda" if torch.cuda.is_available() else "cpu")
-
-     outputs = model.generate(
-         input_ids=inputs,
-         max_new_tokens=MAX_NEW_TOKENS,
-         temperature=TEMPERATURE,
-         do_sample=(TEMPERATURE > 0.01),
-         use_cache=True,
-         pad_token_id=tokenizer.eos_token_id,
-     )
+         tokenize = True,
+         add_generation_prompt = True,
+         return_tensors = "pt"
+     ).to(model.device)
+
+     with torch.inference_mode():
+         outputs = model.generate(
+             input_ids = inputs,
+             max_new_tokens = MAX_NEW_TOKENS,
+             temperature = TEMPERATURE,
+             do_sample = DO_SAMPLE,
+             use_cache = True,
+             pad_token_id = tokenizer.eos_token_id,
+             eos_token_id = tokenizer.eos_token_id,
+         )

      response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-     # Try to cut after assistant's answer
+
+     # Clean output - try to get only assistant's answer
      if "<|assistant|>" in response:
          response = response.split("<|assistant|>", 1)[-1].strip()
-     if "<|end|>" in response:
-         response = response.split("<|end|>")[0].strip()
+     response = response.split("<|end|>")[0].strip()

      return response

-
  # ────────────────────────────────────────────────────────────────
  demo = gr.Interface(
-     fn=generate_sql,
-     inputs=gr.Textbox(
-         label="Ask SQL related question",
-         placeholder="Show me all employees with salary > 50000...",
-         lines=3,
+     fn = generate_sql,
+     inputs = gr.Textbox(
+         label = "Ask SQL related question",
+         placeholder = "Show me all employees with salary > 50000...",
+         lines = 3
      ),
-     outputs=gr.Textbox(label="Generated SQL / Answer"),
-     title="SQL Chat Assistant (Phi-3-mini fine-tuned)",
-     description="Fast version using Unsloth",
-     examples=[
-         ["Find all duplicate emails in users table"],
-         ["Get top 5 highest paid employees"],
-         ["How many orders per customer last month?"],
+     outputs = gr.Textbox(label="Generated SQL / Answer"),
+     title = "SQL Chatbot - Fast Version",
+     description = "Phi-3-mini 4bit quantized + LoRA",
+     examples = [
+         ["Find duplicate emails in users table"],
+         ["Top 5 highest paid employees"],
+         ["Count orders per customer last month"]
      ],
-     allow_flagging="never",
+     allow_flagging = "never"
  )

  if __name__ == "__main__":