kaburia committed
Commit 1eead99 · 1 Parent(s): da4d8cf
Files changed (1)
  1. app.py +93 -140
app.py CHANGED
@@ -1,13 +1,10 @@
- # app.py
  import os
  import uuid
  import time
  import json
  import requests
  import gradio as gr
-
- # ========= Helpers & Context =========
- # Ensure your local utils module exposes: session_id, retrieve_context, log_interaction_hf, upload_log_to_hf
  import utils.helpers as helpers
  from utils.helpers import retrieve_context, log_interaction_hf, upload_log_to_hf

@@ -15,57 +12,36 @@ from utils.helpers import retrieve_context, log_interaction_hf, upload_log_to_hf
  with open("config.json") as f:
      config = json.load(f)

- DO_API_KEY = config["do_token"] # DigitalOcean Model Access Key (serverless inference)
- HF_TOKEN = "hf_" + config["token"] # Hugging Face token for dataset uploads
-
- # Stable session id for the whole app lifetime so logs land under a unique folder
  session_id = f"{int(time.time())}-{uuid.uuid4().hex[:8]}"
- helpers.session_id = session_id # used by your upload_log_to_hf implementation
-
  BASE_URL = "https://inference.do-ai.run/v1"
- UPLOAD_INTERVAL = 5 # upload logs to HF every N turns
- REQUEST_TIMEOUT = 60
- STREAM_TIMEOUT = 120

- # ========= Network Utils =========
  def _auth_headers():
-     return {
-         "Authorization": f"Bearer {DO_API_KEY}",
-         "Content-Type": "application/json",
-         "Accept": "application/json",
-     }

  def list_models():
-     """
-     Fetch live model IDs from DO; fall back to a deterministic default on failure.
-     Always return a non-empty list.
-     """
      try:
-         resp = requests.get(f"{BASE_URL}/models", headers=_auth_headers(), timeout=REQUEST_TIMEOUT)
-         resp.raise_for_status()
-         data = resp.json().get("data", [])
-         ids = [m.get("id") for m in data if m.get("id")]
          if ids:
              return ids
      except Exception as e:
          print(f"⚠️ list_models failed: {e}")
-     # Deterministic fallback
      return ["llama3.3-70b-instruct"]

- def _normalize_model_id(model_id: str | None) -> str:
-     if model_id:
-         return model_id
-     return list_models()[0]
-
- # ========= Inference (non-stream + stream) =========
  def gradient_request(model_id, prompt, max_tokens=512, temperature=0.7, top_p=0.95):
-     """
-     Non-streaming completion (used by lightweight tasks like intent detection).
-     Self-heals if model_id is not found by retrying with the first available model.
-     """
      url = f"{BASE_URL}/chat/completions"
      payload = {
-         "model": _normalize_model_id(model_id),
          "messages": [{"role": "user", "content": prompt}],
          "max_tokens": max_tokens,
          "temperature": temperature,
@@ -73,42 +49,39 @@ def gradient_request(model_id, prompt, max_tokens=512, temperature=0.7, top_p=0.
      }
      for attempt in range(3):
          try:
-             resp = requests.post(url, headers=_auth_headers(), json=payload, timeout=REQUEST_TIMEOUT)
              if resp.status_code == 404:
-                 # Model not found → pick first available model and retry once
                  ids = list_models()
-                 if ids and payload["model"] not in ids:
                      payload["model"] = ids[0]
                      continue
              resp.raise_for_status()
              j = resp.json()
              return j["choices"][0]["message"]["content"].strip()
          except requests.HTTPError as e:
-             body = getattr(e.response, "text", str(e))
-             raise RuntimeError(f"Inference error ({e.response.status_code}): {body}") from e
          except requests.RequestException as e:
              if attempt == 2:
                  raise
-             time.sleep(0.5)
      raise RuntimeError("Exhausted retries")

  def gradient_stream(model_id, prompt, max_tokens=512, temperature=0.7, top_p=0.95):
-     """
-     Streaming generator yielding content chunks.
-     Emits keepalives if the server is quiet for >3s.
-     """
      url = f"{BASE_URL}/chat/completions"
      payload = {
-         "model": _normalize_model_id(model_id),
          "messages": [{"role": "user", "content": prompt}],
          "max_tokens": max_tokens,
          "temperature": temperature,
          "top_p": top_p,
          "stream": True,
      }
-
      try:
-         with requests.post(url, headers=_auth_headers(), json=payload, stream=True, timeout=STREAM_TIMEOUT) as r:
              if r.status_code != 200:
                  try:
                      err_txt = r.text
@@ -116,40 +89,39 @@ def gradient_stream(model_id, prompt, max_tokens=512, temperature=0.7, top_p=0.9
                      err_txt = "<no body>"
                  raise RuntimeError(f"HTTP {r.status_code}: {err_txt}")

-             last_token_ts = time.time()
-             for raw in r.iter_lines(decode_unicode=True):
-                 if raw is None or raw == b"" or raw == "":
-                     if time.time() - last_token_ts > 3:
-                         last_token_ts = time.time()
-                         yield "" # visual keepalive (no-op for UI)
-                     continue
-                 if not raw.startswith("data: "):
-                     continue
-                 data = raw[6:].strip()
-                 if data == "[DONE]":
-                     break
-                 try:
-                     chunk = json.loads(data)
-                     delta = chunk["choices"][0]["delta"]
-                     content = delta.get("content", "")
-                     if content:
-                         last_token_ts = time.time()
-                         yield content
-                 except Exception:
-                     continue
      except Exception as e:
-         raise

  def gradient_complete(model_id, prompt, max_tokens=512, temperature=0.7, top_p=0.95):
      url = f"{BASE_URL}/chat/completions"
      payload = {
-         "model": _normalize_model_id(model_id),
          "messages": [{"role": "user", "content": prompt}],
          "max_tokens": max_tokens,
          "temperature": temperature,
          "top_p": top_p,
      }
-     r = requests.post(url, headers=_auth_headers(), json=payload, timeout=REQUEST_TIMEOUT)
      if r.status_code != 200:
          raise RuntimeError(f"HTTP {r.status_code}: {r.text}")
      j = r.json()
@@ -157,10 +129,6 @@ def gradient_complete(model_id, prompt, max_tokens=512, temperature=0.7, top_p=0

  # ========= Lightweight Intent Detection =========
  def detect_intent(model_id, message: str) -> str:
-     """
-     Classify as 'small_talk' or 'info_query'.
-     Fail-open to 'info_query' on any issue.
-     """
      try:
          out = gradient_request(
              model_id,
@@ -174,26 +142,28 @@ def detect_intent(model_id, message: str) -> str:
          print(f"⚠️ detect_intent failed: {e}")
          return "info_query"

- # ========= Gradio App =========
  with gr.Blocks(title="Gradient AI Chat") as demo:
      turn_counter = gr.State(0)

      gr.Markdown("## Gradient AI Chat")
      gr.Markdown("Select a model and ask your question.")

      with gr.Row():
          model_drop = gr.Dropdown(choices=[], label="Select Model")
          system_msg = gr.Textbox(
-             value="You are a faithful assistant. Prefer provided context, but answer helpfully if none is available.",
              label="System message"
          )

      with gr.Row():
          max_tokens_slider = gr.Slider(minimum=1, maximum=4096, value=512, step=1, label="Max new tokens")
          temperature_slider = gr.Slider(minimum=0.0, maximum=2.0, value=0.7, step=0.1, label="Temperature")
-         top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Topp")

-     # IMPORTANT: tuples mode we must pass and replace tuples, not mutate them
      chatbot = gr.Chatbot(height=500, type="tuples")
      msg = gr.Textbox(label="Your message")

@@ -213,8 +183,8 @@ with gr.Blocks(title="Gradient AI Chat") as demo:
      # --- Load models into dropdown at startup
      def load_models():
          ids = list_models()
-         # value must be in choices; guarantee both
-         return gr.Dropdown.update(choices=ids, value=ids[0])

      demo.load(load_models, outputs=[model_drop])

@@ -229,87 +199,70 @@ with gr.Blocks(title="Gradient AI Chat") as demo:

      # --- Event handlers
      def user(user_message, chat_history):
-         chat_history = chat_history or []
-         # Append a tuple and return
-         chat_history = list(chat_history) + [(user_message, "")]
-         return "", chat_history

      def bot(chat_history, current_turn_count, model_id, system_message, max_tokens, temperature, top_p):
-         """
-         Single, clean streaming pass. Replace tuples; never mutate in place.
-         """
-         if not chat_history:
-             # Shouldn't happen, but stay defensive
-             yield chat_history, (current_turn_count or 0)
-             return
-
          user_message = chat_history[-1][0]

-         # Intent (optional; keeps your original flow)
          intent = detect_intent(model_id, user_message)
-
-         # Build prompt with a safe fallback when RAG returns nothing
-         context = ""
-         if intent != "small_talk":
              try:
-                 context = retrieve_context(user_message, p=5, threshold=0.5) or ""
              except Exception as e:
                  print(f"⚠️ retrieve_context failed: {e}")
                  context = ""

-         if intent == "small_talk":
-             full_prompt = f"[System]: Friendly chat.\n[User]: {user_message}\n[Assistant]: "
-         else:
-             if context.strip():
-                 full_prompt = (
-                     f"[System]: {system_message}\n"
-                     "Use the provided context verbatim; if context is insufficient, answer directly.\n\n"
-                     f"Context:\n{context}\n\nQuestion: {user_message}\n"
-                 )
-             else:
-                 # No context → do not block the model
-                 full_prompt = f"[System]: {system_message}\nQuestion: {user_message}\n"
-
-         # Seed assistant bubble (replace tuple, don’t mutate)
-         chat_history = list(chat_history)
-         chat_history[-1] = (chat_history[-1][0], "")
-         yield chat_history, (current_turn_count or 0)

-         # Stream with fallback
          try:
              received_any = False
-             buffer = ""
-
              for token in gradient_stream(model_id, full_prompt, max_tokens, temperature, top_p):
-                 if token:
                      received_any = True
-                     buffer += token
-                     chat_history[-1] = (chat_history[-1][0], buffer)
-                     yield chat_history, (current_turn_count or 0)
-
              if not received_any:
-                 text = gradient_complete(model_id, full_prompt, max_tokens, temperature, top_p)
-                 chat_history[-1] = (chat_history[-1][0], text)
-                 yield chat_history, (current_turn_count or 0)
-
          except Exception as e:
-             chat_history[-1] = (chat_history[-1][0], f"⚠️ Inference failed: {e}")
-             yield chat_history, (current_turn_count or 0)
-             return
-
-         # Logging & periodic upload (once per turn)
          try:
              log_interaction_hf(user_message, chat_history[-1][1])
          except Exception as e:
              print(f"⚠️ log_interaction_hf failed: {e}")

          new_turn_count = (current_turn_count or 0) + 1
          if new_turn_count % UPLOAD_INTERVAL == 0:
              try:
-                 upload_log_to_hf(HF_TOKEN) # IMPORTANT: HF token, not DO
              except Exception as e:
                  print(f"❌ Log upload failed: {e}")

          yield chat_history, new_turn_count

      # Wiring (streaming generators supported)
@@ -339,4 +292,4 @@ with gr.Blocks(title="Gradient AI Chat") as demo:

  if __name__ == "__main__":
      # On HF Spaces, don't use share=True. Also disable API page to avoid schema churn.
-     demo.launch(show_api=False)

@@ -1,13 +1,10 @@
  import os
  import uuid
  import time
  import json
  import requests
  import gradio as gr
+ import time
  import utils.helpers as helpers
  from utils.helpers import retrieve_context, log_interaction_hf, upload_log_to_hf

@@ -15,57 +12,36 @@ from utils.helpers import retrieve_context, log_interaction_hf, upload_log_to_hf
  with open("config.json") as f:
      config = json.load(f)

+ DO_API_KEY = config["do_token"]
+ token_ = config['token']
+ HF_TOKEN = 'hf_' + token_
  session_id = f"{int(time.time())}-{uuid.uuid4().hex[:8]}"
+ helpers.session_id = session_id
  BASE_URL = "https://inference.do-ai.run/v1"
+ UPLOAD_INTERVAL = 5

+ # ========= Inference Utilities =========
  def _auth_headers():
+     return {"Authorization": f"Bearer {DO_API_KEY}", "Content-Type": "application/json"}

  def list_models():
      try:
+         r = requests.get(f"{BASE_URL}/models", headers=_auth_headers(), timeout=15)
+         r.raise_for_status()
+         data = r.json().get("data", [])
+         ids = [m["id"] for m in data]
          if ids:
              return ids
      except Exception as e:
          print(f"⚠️ list_models failed: {e}")
      return ["llama3.3-70b-instruct"]

  def gradient_request(model_id, prompt, max_tokens=512, temperature=0.7, top_p=0.95):
      url = f"{BASE_URL}/chat/completions"
+     if not model_id:
+         model_id = list_models()[0]
      payload = {
+         "model": model_id,
          "messages": [{"role": "user", "content": prompt}],
          "max_tokens": max_tokens,
          "temperature": temperature,
@@ -73,42 +49,39 @@ def gradient_request(model_id, prompt, max_tokens=512, temperature=0.7, top_p=0.
      }
      for attempt in range(3):
          try:
+             resp = requests.post(url, headers=_auth_headers(), json=payload, timeout=30)
              if resp.status_code == 404:
                  ids = list_models()
+                 if model_id not in ids and ids:
                      payload["model"] = ids[0]
                      continue
              resp.raise_for_status()
              j = resp.json()
              return j["choices"][0]["message"]["content"].strip()
          except requests.HTTPError as e:
+             msg = getattr(e.response, "text", str(e))
+             raise RuntimeError(f"Inference error ({e.response.status_code}): {msg}") from e
          except requests.RequestException as e:
              if attempt == 2:
                  raise
      raise RuntimeError("Exhausted retries")

  def gradient_stream(model_id, prompt, max_tokens=512, temperature=0.7, top_p=0.95):
      url = f"{BASE_URL}/chat/completions"
+     if not model_id:
+         model_id = list_models()[0]
      payload = {
+         "model": model_id,
          "messages": [{"role": "user", "content": prompt}],
          "max_tokens": max_tokens,
          "temperature": temperature,
          "top_p": top_p,
          "stream": True,
      }
+
+     # Create a generator that yields tokens
      try:
+         with requests.post(url, headers=_auth_headers(), json=payload, stream=True, timeout=120) as r:
              if r.status_code != 200:
                  try:
                      err_txt = r.text
@@ -116,40 +89,39 @@ def gradient_stream(model_id, prompt, max_tokens=512, temperature=0.7, top_p=0.9
                      err_txt = "<no body>"
                  raise RuntimeError(f"HTTP {r.status_code}: {err_txt}")

+             buffer = ""
+             for line in r.iter_lines():
+                 if line:
+                     decoded_line = line.decode('utf-8')
+                     if decoded_line.startswith('data:'):
+                         data = decoded_line[5:].strip()
+                         if data == '[DONE]':
+                             break
+                         try:
+                             json_data = json.loads(data)
+                             if 'choices' in json_data:
+                                 for choice in json_data['choices']:
+                                     if 'delta' in choice and 'content' in choice['delta']:
+                                         content = choice['delta']['content']
+                                         buffer += content
+                                         yield content
+                         except json.JSONDecodeError:
+                             continue
+             if not buffer:
+                 yield "No response received from the model."
      except Exception as e:
+         raise RuntimeError(f"Streaming error: {str(e)}")

  def gradient_complete(model_id, prompt, max_tokens=512, temperature=0.7, top_p=0.95):
      url = f"{BASE_URL}/chat/completions"
      payload = {
+         "model": model_id,
          "messages": [{"role": "user", "content": prompt}],
          "max_tokens": max_tokens,
          "temperature": temperature,
          "top_p": top_p,
      }
+     r = requests.post(url, headers=_auth_headers(), json=payload, timeout=60)
      if r.status_code != 200:
          raise RuntimeError(f"HTTP {r.status_code}: {r.text}")
      j = r.json()
@@ -157,10 +129,6 @@ def gradient_complete(model_id, prompt, max_tokens=512, temperature=0.7, top_p=0

  # ========= Lightweight Intent Detection =========
  def detect_intent(model_id, message: str) -> str:
      try:
          out = gradient_request(
              model_id,
@@ -174,26 +142,28 @@ def detect_intent(model_id, message: str) -> str:
          print(f"⚠️ detect_intent failed: {e}")
          return "info_query"

+ # ========= App Logic (Gradio Blocks) =========
  with gr.Blocks(title="Gradient AI Chat") as demo:
+     # Keep a reactive turn counter in session state
      turn_counter = gr.State(0)

      gr.Markdown("## Gradient AI Chat")
      gr.Markdown("Select a model and ask your question.")

+     # Model dropdown will be populated at runtime with live IDs
      with gr.Row():
          model_drop = gr.Dropdown(choices=[], label="Select Model")
          system_msg = gr.Textbox(
+             value="You are a faithful assistant. Use only the provided context.",
              label="System message"
          )

      with gr.Row():
          max_tokens_slider = gr.Slider(minimum=1, maximum=4096, value=512, step=1, label="Max new tokens")
          temperature_slider = gr.Slider(minimum=0.0, maximum=2.0, value=0.7, step=0.1, label="Temperature")
+         top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p")

+     # Use tuples to silence deprecation warning in current Gradio
      chatbot = gr.Chatbot(height=500, type="tuples")
      msg = gr.Textbox(label="Your message")

@@ -213,8 +183,8 @@ with gr.Blocks(title="Gradient AI Chat") as demo:
      # --- Load models into dropdown at startup
      def load_models():
          ids = list_models()
+         default = ids[0] if ids else None
+         return gr.Dropdown.update(choices=ids, value=default)

      demo.load(load_models, outputs=[model_drop])

@@ -229,87 +199,70 @@ with gr.Blocks(title="Gradient AI Chat") as demo:

      # --- Event handlers
      def user(user_message, chat_history):
+         # Seed a new assistant message for streaming
+         return "", (chat_history + [[user_message, ""]])

      def bot(chat_history, current_turn_count, model_id, system_message, max_tokens, temperature, top_p):
          user_message = chat_history[-1][0]

+         # Build prompt
          intent = detect_intent(model_id, user_message)
+         if intent == "small_talk":
+             full_prompt = f"[System]: Friendly chat.\n[User]: {user_message}\n[Assistant]: "
+         else:
              try:
+                 context = retrieve_context(user_message, p=5, threshold=0.5)
              except Exception as e:
                  print(f"⚠️ retrieve_context failed: {e}")
                  context = ""
+             full_prompt = (
+                 f"[System]: {system_message}\n"
+                 "Use only the provided context. Quote verbatim; no inference.\n\n"
+                 f"Context:\n{context}\n\nQuestion: {user_message}\n"
+             )

+         # Initialize assistant message to empty string and update chat history
+         chat_history[-1][1] = ""
+         yield chat_history, current_turn_count

+         # Attempt to stream the response
          try:
              received_any = False
              for token in gradient_stream(model_id, full_prompt, max_tokens, temperature, top_p):
+                 if token: # Skip empty tokens
                      received_any = True
+                     chat_history[-1][1] += token
+                     yield chat_history, current_turn_count
+             # If we didn't receive any tokens, fall back to non-streaming
              if not received_any:
+                 raise RuntimeError("Streaming returned no tokens; falling back.")
          except Exception as e:
+             print(f"⚠️ Streaming failed: {e}")
+             try:
+                 # Fall back to non-streaming
+                 response = gradient_complete(model_id, full_prompt, max_tokens, temperature, top_p)
+                 chat_history[-1][1] = response
+                 yield chat_history, current_turn_count
+             except Exception as e2:
+                 chat_history[-1][1] = f"⚠️ Inference failed: {e2}"
+                 yield chat_history, current_turn_count
+                 return
+
+         # After successful response, log and update turn counter
          try:
              log_interaction_hf(user_message, chat_history[-1][1])
          except Exception as e:
              print(f"⚠️ log_interaction_hf failed: {e}")

          new_turn_count = (current_turn_count or 0) + 1
+         # Periodically upload logs
          if new_turn_count % UPLOAD_INTERVAL == 0:
              try:
+                 upload_log_to_hf(HF_TOKEN)
              except Exception as e:
                  print(f"❌ Log upload failed: {e}")

+         # Update the state with the new turn count
          yield chat_history, new_turn_count

      # Wiring (streaming generators supported)
@@ -339,4 +292,4 @@ with gr.Blocks(title="Gradient AI Chat") as demo:

  if __name__ == "__main__":
      # On HF Spaces, don't use share=True. Also disable API page to avoid schema churn.
+     demo.launch(show_api=False)
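For anyone reproducing this Space locally: the revised app.py reads its two secrets from a config.json file next to it, using config["do_token"] for the DigitalOcean inference key and config["token"] for the Hugging Face token body (the code prepends "hf_" itself). A minimal, hypothetical setup sketch with placeholder values only (not part of the commit) might look like:

# Hypothetical local setup sketch: writes a placeholder config.json with the two
# keys app.py reads. Replace the placeholder strings with real credentials.
import json

placeholder_config = {
    "do_token": "YOUR-DIGITALOCEAN-MODEL-ACCESS-KEY",   # used for Bearer auth against inference.do-ai.run
    "token": "YOUR-HF-TOKEN-WITHOUT-THE-hf_-PREFIX",    # app.py builds HF_TOKEN = "hf_" + token
}

with open("config.json", "w") as f:
    json.dump(placeholder_config, f, indent=2)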