arubaDev commited on
Commit
33cfdca
·
verified ·
1 Parent(s): d793e1b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +344 -0
app.py ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sqlite3
3
+ import time
4
+ from datetime import datetime
5
+ import gradio as gr
6
+ from huggingface_hub import InferenceClient
7
+
8
+ # ---------------------------
9
+ # Config
10
+ # ---------------------------
11
+ MODELS = {
12
+ "Meta LLaMA 3.1 (8B Instruct)": "meta-llama/Llama-3.1-8B-Instruct",
13
+ "Mistral 7B Instruct": "mistralai/Mistral-7B-Instruct-v0.3",
14
+ }
15
+ HF_TOKEN = os.getenv("HF_TOKEN") # set this in your Space's Secrets
16
+ DB_PATH = "history.db"
17
+
18
+ SYSTEM_DEFAULT = (
19
+ "You are a coding assistant. Respond only with clean and complete code "
20
+ "unless explanation is explicitly requested. Prefer full CRUD scaffolds, "
21
+ "with files, paths, and commands when asked."
22
+ )
23
+
24
+ # ---------------------------
25
+ # DB Setup
26
+ # ---------------------------
27
+ def db():
28
+ conn = sqlite3.connect(DB_PATH)
29
+ conn.execute("PRAGMA journal_mode=WAL;")
30
+ return conn
31
+
32
+ def init_db():
33
+ conn = db()
34
+ cur = conn.cursor()
35
+ cur.execute("""
36
+ CREATE TABLE IF NOT EXISTS sessions (
37
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
38
+ title TEXT NOT NULL,
39
+ created_at TEXT NOT NULL
40
+ )
41
+ """)
42
+ cur.execute("""
43
+ CREATE TABLE IF NOT EXISTS messages (
44
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
45
+ session_id INTEGER NOT NULL,
46
+ role TEXT NOT NULL,
47
+ content TEXT NOT NULL,
48
+ created_at TEXT NOT NULL,
49
+ FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE
50
+ )
51
+ """)
52
+ conn.commit()
53
+ conn.close()
54
+
55
+ def create_session(title: str = "New chat") -> int:
56
+ conn = db()
57
+ cur = conn.cursor()
58
+ cur.execute(
59
+ "INSERT INTO sessions (title, created_at) VALUES (?, ?)",
60
+ (title, datetime.utcnow().isoformat())
61
+ )
62
+ session_id = cur.lastrowid
63
+ conn.commit()
64
+ conn.close()
65
+ return session_id
66
+
67
+ def delete_session(session_id: int):
68
+ conn = db()
69
+ cur = conn.cursor()
70
+ cur.execute("DELETE FROM messages WHERE session_id = ?", (session_id,))
71
+ cur.execute("DELETE FROM sessions WHERE id = ?", (session_id,))
72
+ conn.commit()
73
+ conn.close()
74
+
75
+ def list_sessions():
76
+ conn = db()
77
+ cur = conn.cursor()
78
+ cur.execute("SELECT id, title FROM sessions ORDER BY id DESC")
79
+ rows = cur.fetchall()
80
+ conn.close()
81
+ labels = [f"{r[0]} • {r[1]}" for r in rows]
82
+ return labels, rows
83
+
84
+ def get_messages(session_id: int):
85
+ conn = db()
86
+ cur = conn.cursor()
87
+ cur.execute("""
88
+ SELECT role, content FROM messages
89
+ WHERE session_id = ?
90
+ ORDER BY id ASC
91
+ """, (session_id,))
92
+ rows = cur.fetchall()
93
+ conn.close()
94
+ msgs = [{"role": role, "content": content} for (role, content) in rows]
95
+ return msgs
96
+
97
+ def add_message(session_id: int, role: str, content: str):
98
+ conn = db()
99
+ cur = conn.cursor()
100
+ cur.execute(
101
+ "INSERT INTO messages (session_id, role, content, created_at) VALUES (?, ?, ?, ?)",
102
+ (session_id, role, content, datetime.utcnow().isoformat())
103
+ )
104
+ conn.commit()
105
+ conn.close()
106
+
107
+ def update_session_title_if_needed(session_id: int, first_user_text: str):
108
+ conn = db()
109
+ cur = conn.cursor()
110
+ cur.execute("SELECT COUNT(*) FROM messages WHERE session_id=? AND role='user'", (session_id,))
111
+ count_users = cur.fetchone()[0]
112
+ if count_users == 1:
113
+ title = first_user_text.strip().split("\n")[0]
114
+ title = (title[:50] + "…") if len(title) > 50 else title
115
+ cur.execute("UPDATE sessions SET title=? WHERE id=?", (title or "New chat", session_id))
116
+ conn.commit()
117
+ conn.close()
118
+
119
+ # ---------------------------
120
+ # Helpers
121
+ # ---------------------------
122
+ def label_to_id(label: str | None) -> int | None:
123
+ if not label:
124
+ return None
125
+ try:
126
+ return int(label.split("•", 1)[0].strip())
127
+ except Exception:
128
+ return None
129
+
130
+ def build_api_messages(session_id: int, system_message: str):
131
+ msgs = [{"role": "system", "content": system_message.strip()}]
132
+ msgs.extend(get_messages(session_id))
133
+ return msgs
134
+
135
+ def get_client(model_choice: str):
136
+ """Return the right InferenceClient for the chosen model."""
137
+ model_id = MODELS.get(model_choice, list(MODELS.values())[0])
138
+ return InferenceClient(model_id, token=HF_TOKEN)
139
+
140
+ # ---------------------------
141
+ # Gradio Callbacks
142
+ # ---------------------------
143
+ def refresh_sessions_cb():
144
+ labels, _ = list_sessions()
145
+ selected = labels[0] if labels else None
146
+ return gr.update(choices=labels, value=selected)
147
+
148
+ def new_chat_cb():
149
+ sid = create_session("New chat")
150
+ labels, _ = list_sessions()
151
+ selected = next((lbl for lbl in labels if lbl.startswith(f"{sid} ")), None)
152
+ return (gr.update(choices=labels, value=selected), [], "")
153
+
154
+ def load_session_cb(selected_label):
155
+ sid = label_to_id(selected_label)
156
+ if not sid:
157
+ return []
158
+ return get_messages(sid)
159
+
160
+ def delete_chat_cb(selected_label):
161
+ sid = label_to_id(selected_label)
162
+ if sid:
163
+ delete_session(sid)
164
+ labels, _ = list_sessions()
165
+ selected = labels[0] if labels else None
166
+ return gr.update(choices=labels, value=selected), []
167
+
168
+ def send_cb(user_text, selected_label, chatbot_msgs, system_message, max_tokens, temperature, top_p, model_choice):
169
+ sid = label_to_id(selected_label)
170
+ if sid is None:
171
+ sid = create_session("New chat")
172
+ labels, _ = list_sessions()
173
+ selected_label = next((lbl for lbl in labels if lbl.startswith(f"{sid} ")), None)
174
+
175
+ add_message(sid, "user", user_text)
176
+ update_session_title_if_needed(sid, user_text)
177
+
178
+ api_messages = build_api_messages(sid, system_message)
179
+ display_msgs = get_messages(sid)
180
+ display_msgs.append({"role": "assistant", "content": ""})
181
+
182
+ client = get_client(model_choice)
183
+ partial = ""
184
+ try:
185
+ for chunk in client.chat_completion(
186
+ messages=api_messages,
187
+ max_tokens=int(max_tokens),
188
+ temperature=float(temperature),
189
+ top_p=float(top_p),
190
+ stream=True,
191
+ ):
192
+ delta = chunk.choices[0].delta.content or ""
193
+ if delta:
194
+ partial += delta
195
+ display_msgs[-1]["content"] = partial
196
+ yield (display_msgs, "", selected_label)
197
+ add_message(sid, "assistant", partial)
198
+ except Exception as e:
199
+ err = f"⚠️ Error: {str(e)}"
200
+ display_msgs[-1]["content"] = err
201
+ yield (display_msgs, "", selected_label)
202
+
203
+ def regenerate_cb(selected_label, system_message, max_tokens, temperature, top_p, model_choice):
204
+ sid = label_to_id(selected_label)
205
+ if sid is None:
206
+ return [], ""
207
+
208
+ msgs = get_messages(sid)
209
+ if not msgs:
210
+ return [], ""
211
+
212
+ if msgs and msgs[-1]["role"] == "assistant":
213
+ conn = db()
214
+ cur = conn.cursor()
215
+ cur.execute("""
216
+ DELETE FROM messages
217
+ WHERE id = (
218
+ SELECT id FROM messages WHERE session_id=? ORDER BY id DESC LIMIT 1
219
+ )
220
+ """, (sid,))
221
+ conn.commit()
222
+ conn.close()
223
+ msgs = get_messages(sid)
224
+
225
+ api_messages = [{"role": "system", "content": system_message.strip()}] + msgs
226
+ display_msgs = msgs + [{"role": "assistant", "content": ""}]
227
+
228
+ client = get_client(model_choice)
229
+ partial = ""
230
+ try:
231
+ for chunk in client.chat_completion(
232
+ messages=api_messages,
233
+ max_tokens=int(max_tokens),
234
+ temperature=float(temperature),
235
+ top_p=float(top_p),
236
+ stream=True,
237
+ ):
238
+ delta = chunk.choices[0].delta.content or ""
239
+ if delta:
240
+ partial += delta
241
+ display_msgs[-1]["content"] = partial
242
+ yield display_msgs
243
+ add_message(sid, "assistant", partial)
244
+ except Exception as e:
245
+ display_msgs[-1]["content"] = f"⚠️ Error: {str(e)}"
246
+ yield display_msgs
247
+
248
+ # ---------------------------
249
+ # App UI
250
+ # ---------------------------
251
+ init_db()
252
+ labels, _ = list_sessions()
253
+ if not labels:
254
+ first_sid = create_session("New chat")
255
+ labels, _ = list_sessions()
256
+ default_selected = labels[0] if labels else None
257
+
258
+ with gr.Blocks(title="LLaMA/Mistral CRUD Automation (with History)", theme=gr.themes.Soft()) as demo:
259
+ # --- Updated CSS to make ALL buttons green ---
260
+ gr.HTML("""
261
+ <style>
262
+ button {
263
+ background-color: #22c55e !important;
264
+ color: #ffffff !important;
265
+ border: none !important;
266
+ }
267
+ button:hover {
268
+ background-color: #16a34a !important;
269
+ }
270
+ button:focus {
271
+ outline: 2px solid #166534 !important;
272
+ outline-offset: 2px;
273
+ }
274
+ </style>
275
+ """)
276
+
277
+ gr.Markdown("## 🦙🤖 LLaMA & Mistral CRUD Automation — with Persistent History")
278
+
279
+ with gr.Row():
280
+ with gr.Column(scale=1, min_width=260):
281
+ gr.Markdown("### 📁 Sessions")
282
+ session_list = gr.Radio(
283
+ choices=labels,
284
+ value=default_selected,
285
+ label="Your chats",
286
+ interactive=True
287
+ )
288
+ with gr.Row():
289
+ new_btn = gr.Button("➕ New Chat", variant="primary")
290
+ del_btn = gr.Button("🗑️ Delete", variant="stop")
291
+ refresh_btn = gr.Button("🔄 Refresh", variant="secondary")
292
+
293
+ gr.Markdown("### 🤖 Model Selection")
294
+ model_choice = gr.Dropdown(
295
+ choices=list(MODELS.keys()),
296
+ value=list(MODELS.keys())[0],
297
+ label="Choose a model",
298
+ interactive=True
299
+ )
300
+
301
+ gr.Markdown("### ⚙️ Generation Settings")
302
+ system_box = gr.Textbox(
303
+ value=SYSTEM_DEFAULT,
304
+ label="System message",
305
+ lines=4
306
+ )
307
+ max_tokens = gr.Slider(256, 4096, value=1200, step=16, label="Max tokens")
308
+ temperature = gr.Slider(0.0, 2.0, value=0.25, step=0.05, label="Temperature")
309
+ top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
310
+
311
+ with gr.Column(scale=3):
312
+ chatbot = gr.Chatbot(label="Assistant", height=520, type="messages")
313
+ with gr.Row():
314
+ user_box = gr.Textbox(placeholder="Describe your CRUD task…", lines=3, scale=5)
315
+ with gr.Row():
316
+ send_btn = gr.Button("Send ▶️", variant="primary")
317
+ regen_btn = gr.Button("Regenerate 🔁", variant="secondary")
318
+
319
+ # Interactions
320
+ refresh_btn.click(refresh_sessions_cb, outputs=session_list)
321
+ new_btn.click(new_chat_cb, outputs=[session_list, chatbot, user_box])
322
+ del_btn.click(delete_chat_cb, inputs=session_list, outputs=[session_list, chatbot])
323
+ session_list.change(load_session_cb, inputs=session_list, outputs=chatbot)
324
+
325
+ send_btn.click(
326
+ send_cb,
327
+ inputs=[user_box, session_list, chatbot, system_box, max_tokens, temperature, top_p, model_choice],
328
+ outputs=[chatbot, user_box, session_list]
329
+ )
330
+
331
+ user_box.submit(
332
+ send_cb,
333
+ inputs=[user_box, session_list, chatbot, system_box, max_tokens, temperature, top_p, model_choice],
334
+ outputs=[chatbot, user_box, session_list]
335
+ )
336
+
337
+ regen_btn.click(
338
+ regenerate_cb,
339
+ inputs=[session_list, system_box, max_tokens, temperature, top_p, model_choice],
340
+ outputs=chatbot
341
+ )
342
+
343
+ if __name__ == "__main__":
344
+ demo.launch()