Alikestocode committed
Commit 4ce42e8 · 1 Parent(s): 0e2f6c4

Add GPU estimator, DDG search, and cancel support
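The cancellation added here is cooperative: a module-level threading.Event is set by the new Cancel button, polled in both the vLLM and Transformers streaming loops, and also passed to transformers generation as a stopping criterion. A minimal standalone sketch of that pattern (using "gpt2" only as an arbitrary placeholder model, not the app's actual checkpoints or wiring):

    import threading
    from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList

    cancel_event = threading.Event()

    class CancelStoppingCriteria(StoppingCriteria):
        # generate() consults this after every decoding step; returning True halts generation.
        def __call__(self, input_ids, scores, **kwargs) -> bool:
            return cancel_event.is_set()

    tok = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    inputs = tok("Hello", return_tensors="pt")

    # A UI callback (the Cancel button in app.py) would call cancel_event.set() from another thread.
    out = model.generate(
        **inputs,
        max_new_tokens=20,
        stopping_criteria=StoppingCriteriaList([CancelStoppingCriteria()]),
    )
    print(tok.decode(out[0], skip_special_tokens=True))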

Files changed (1)
  1. app.py  +257 -6
app.py CHANGED
@@ -1,14 +1,23 @@
  from __future__ import annotations

  import json
+ import math
  import os
  import re
+ import threading
+ from itertools import islice
  from typing import Any, Dict, List, Tuple, Optional

  import gradio as gr
  import spaces
  import torch
- from transformers import AutoTokenizer, TextIteratorStreamer, pipeline
+ from transformers import (
+     AutoTokenizer,
+     TextIteratorStreamer,
+     pipeline,
+     StoppingCriteria,
+     StoppingCriteriaList,
+ )
  from threading import Thread
  from concurrent.futures import ThreadPoolExecutor

@@ -18,6 +27,13 @@ try:
  except ImportError:  # pragma: no cover
      HF_HUB_AVAILABLE = False

+ try:
+     from ddgs import DDGS
+
+     DDGS_AVAILABLE = True
+ except ImportError:
+     DDGS_AVAILABLE = False
+
  # Enable optimizations
  torch.backends.cuda.matmul.allow_tf32 = True

@@ -47,6 +63,8 @@ except ImportError:
      SamplingParams = None
      print("Warning: vLLM not available, falling back to Transformers")

+ cancel_event = threading.Event()
+
  # Optional flag to disable vLLM (defaults to true on MIG due to device detection instability)
  DISABLE_VLLM = os.environ.get("DISABLE_VLLM", "1" if MIG_VISIBLE else "0") == "1"

@@ -95,6 +113,88 @@ def _ensure_local_repo(repo_id: str) -> Optional[str]:
      return None


+ def _retrieve_search_results(query: str, max_results: int, max_chars: int) -> List[str]:
+     if not DDGS_AVAILABLE:
+         return []
+     results: List[str] = []
+     try:
+         with DDGS() as ddgs:
+             for idx, item in enumerate(
+                 islice(
+                     ddgs.text(
+                         query,
+                         region="wt-wt",
+                         safesearch="moderate",
+                         timelimit="y",
+                     ),
+                     max_results,
+                 )
+             ):
+                 title = (item.get("title") or "Untitled").strip()
+                 body = (item.get("body") or "").strip()
+                 url = (item.get("href") or "").strip()
+                 snippet = body[:max_chars].replace("\n", " ")
+                 formatted = f"[{idx+1}] {title} — {snippet}"
+                 if url:
+                     formatted += f" ({url})"
+                 results.append(formatted)
+     except Exception as exc:  # pragma: no cover
+         print(f"[DEBUG] DDG search failed: {exc}")
+     return results
+
+
+ class CancelStoppingCriteria(StoppingCriteria):
+     def __call__(self, input_ids, scores, **kwargs) -> bool:
+         return cancel_event.is_set()
+
+
+ def estimate_gpu_seconds(
+     model_name: str,
+     max_new_tokens: int,
+     enable_search: bool,
+ ) -> float:
+     params_b = MODELS.get(model_name, {}).get("params_b", 4.0)
+     base = 12.0 + params_b * 3.0
+     tokens_per_sec = max(40.0, 320.0 / (1.0 + params_b / 6.0))
+     generation_time = max_new_tokens / tokens_per_sec
+     search_time = 8.0 if enable_search else 0.0
+     return base + generation_time + search_time
+
+
+ def format_gpu_estimate_message(
+     model_name: str,
+     max_new_tokens: int,
+     enable_search: bool,
+ ) -> Tuple[str, int]:
+     est_seconds = estimate_gpu_seconds(model_name, max_new_tokens, enable_search)
+     rounded = int(math.ceil(est_seconds))
+     recommended = int(math.ceil(max(60, rounded) / 60.0) * 60)
+     recommended = max(60, min(1800, recommended))
+     model_size = MODELS.get(model_name, {}).get("params_b", 4.0)
+     message = (
+         f"⏱️ **Estimated GPU Time:** ~{rounded} seconds\n\n"
+         f"📊 **Model Size:** {model_size:.1f}B parameters\n"
+         f"🔍 **Web Search:** {'Enabled' if enable_search else 'Disabled'}\n"
+         f"✅ **Suggested GPU Duration slider:** {recommended} seconds"
+     )
+     return message, recommended
+
+
+ def update_gpu_controls(
+     model_name: str,
+     max_new_tokens: int,
+     enable_search: bool,
+     current_duration: int,
+ ):
+     message, recommended = format_gpu_estimate_message(
+         model_name,
+         max_new_tokens,
+         enable_search,
+     )
+     updated_value = current_duration if current_duration >= recommended else recommended
+     return message, gr.update(value=updated_value)
+
+
  def _start_prefetch_workers(model_names: list[str]):
      global PREFETCH_EXECUTOR
      if PREFETCH_DISABLED or not HF_HUB_AVAILABLE:
@@ -793,6 +893,10 @@ def _generate_router_plan_streaming_internal(
      temperature: float,
      top_p: float,
      gpu_duration: int,
+     enable_search: bool,
+     search_max_results: int,
+     search_max_chars: int,
+     search_timeout: float,
  ):
      """Internal generator function for streaming token output."""
      if not user_task.strip():
@@ -803,10 +907,49 @@ def _generate_router_plan_streaming_internal(
          yield "", {}, f"❌ Invalid model choice: {model_choice}. Available: {list(MODELS.keys())}", ""
          return

+     cancel_event.clear()
+     cancelled = False
      try:
+         search_snippets: List[str] = []
+         if enable_search and DDGS_AVAILABLE and user_task.strip():
+             search_snippets_holder: List[str] = []
+             search_error: Optional[Exception] = None
+
+             def _fetch_search():
+                 nonlocal search_error
+                 try:
+                     results = _retrieve_search_results(
+                         user_task,
+                         max(1, int(search_max_results)),
+                         max(30, int(search_max_chars)),
+                     )
+                     search_snippets_holder.extend(results)
+                 except Exception as exc:  # pragma: no cover
+                     search_error = exc
+
+             search_thread = Thread(target=_fetch_search, daemon=True)
+             search_thread.start()
+             search_thread.join(timeout=float(max(0.5, search_timeout)))
+             if search_thread.is_alive():
+                 print("[DEBUG] Search thread timed out; continuing without results.")
+             if search_error:
+                 print(f"[DEBUG] Search error: {search_error}")
+             search_snippets = search_snippets_holder
+
+         context_for_prompt = context
+         if search_snippets:
+             search_block = "\n".join(f"- {snippet}" for snippet in search_snippets)
+             addendum = (
+                 "\n\n# Web Search Findings\n"
+                 "Use the following snippets as supplementary evidence. "
+                 "Cite them as needed in the generated plan.\n"
+                 f"{search_block}"
+             )
+             context_for_prompt = (context_for_prompt or "").rstrip() + addendum
+
          prompt = build_router_prompt(
              user_task=user_task,
-             context=context,
+             context=context_for_prompt,
              acceptance=acceptance,
              extra_guidance=extra_guidance,
              difficulty=difficulty,
@@ -844,6 +987,14 @@ def _generate_router_plan_streaming_internal(

          prev_text_len = 0
          for request_output in stream:
+             if cancel_event.is_set():
+                 cancelled = True
+                 try:
+                     if hasattr(generator, "abort_request"):
+                         generator.abort_request(request_output.request_id)
+                 except Exception:
+                     pass
+                 break
              if not request_output.outputs:
                  continue

@@ -909,6 +1060,7 @@ def _generate_router_plan_streaming_internal(
              "streamer": streamer,
              "eos_token_id": tokenizer.eos_token_id,
              "pad_token_id": tokenizer.pad_token_id or tokenizer.eos_token_id,
+             "stopping_criteria": StoppingCriteriaList([CancelStoppingCriteria()]),
          }

          generation_error = None
@@ -937,6 +1089,9 @@ def _generate_router_plan_streaming_internal(

          try:
              for new_text in streamer:
+                 if cancel_event.is_set():
+                     cancelled = True
+                     break
                  if generation_error:
                      raise generation_error

@@ -988,7 +1143,9 @@ def _generate_router_plan_streaming_internal(
          completion = trim_at_stop_sequences(completion.strip())[0]
          print(f"[DEBUG] Final completion length: {len(completion)}")

-         if not completion:
+         if cancelled:
+             validation_msg = "⏹️ Generation cancelled by user."
+         elif not completion:
              print("[DEBUG] WARNING: Completion is empty - model may not have generated output")
              validation_msg = "⚠️ Model generated empty output. Check GPU allocation and model loading."
          elif parsed_plan is None:
@@ -1033,11 +1190,19 @@ def _make_gpu_wrapper(duration: int):
          temperature: float,
          top_p: float,
          gpu_duration: int,
+         enable_search: bool,
+         search_max_results: int,
+         search_max_chars: int,
+         search_timeout: float,
      ):
          yield from _generate_router_plan_streaming_internal(
              user_task, context, acceptance, extra_guidance,
              difficulty, tags, model_choice, max_new_tokens,
-             temperature, top_p, duration
+             temperature, top_p, duration,
+             enable_search,
+             search_max_results,
+             search_max_chars,
+             search_timeout,
          )
      return wrapper

@@ -1058,6 +1223,10 @@ def generate_router_plan_streaming(
      temperature: float,
      top_p: float,
      gpu_duration: int = 600,
+     enable_search: bool = False,
+     search_max_results: int = 4,
+     search_max_chars: int = 120,
+     search_timeout: float = 5.0,
  ):
      """
      Generate router plan with streaming output.
@@ -1073,7 +1242,11 @@ def generate_router_plan_streaming(
      yield from wrapper(
          user_task, context, acceptance, extra_guidance,
          difficulty, tags, model_choice, max_new_tokens,
-         temperature, top_p, rounded_duration
+         temperature, top_p, rounded_duration,
+         enable_search,
+         int(search_max_results),
+         int(search_max_chars),
+         float(search_timeout),
      )


@@ -1081,8 +1254,18 @@ def clear_outputs():
      return "", {}, "Awaiting generation.", ""


+ def cancel_generation():
+     cancel_event.set()
+     return "⏹️ Cancel request sent. Finishing current step..."
+
+
  def build_ui():
      description = "Use the CourseGPT-Pro router checkpoints (Gemma3/Qwen3) hosted on ZeroGPU to generate structured routing plans."
+     initial_estimate_text, initial_recommended_duration = format_gpu_estimate_message(
+         DEFAULT_MODEL,
+         16000,
+         False,
+     )
      with gr.Blocks(theme=gr.themes.Soft(), css="""
      textarea { font-family: 'JetBrains Mono', 'Fira Code', monospace; }
      .status-ok { color: #0d9488; font-weight: 600; }
@@ -1136,11 +1319,54 @@ def build_ui():
              max_new_tokens = gr.Slider(256, 20000, value=16000, step=32, label="Max New Tokens")
              temperature = gr.Slider(0.0, 1.5, value=0.2, step=0.05, label="Temperature")
              top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
-             gpu_duration = gr.Slider(60, 1800, value=600, step=60, label="GPU Duration (seconds)", info="Maximum GPU time allocation for this request")
+             enable_search = gr.Checkbox(
+                 label="Enable DuckDuckGo Web Search",
+                 value=False,
+                 interactive=DDGS_AVAILABLE,
+                 info="Augment context with live snippets." if DDGS_AVAILABLE else "Install 'ddgs' package to enable search.",
+             )
+             with gr.Accordion("Web Search Settings", open=False, visible=DDGS_AVAILABLE) as search_settings:
+                 search_max_results = gr.Slider(
+                     minimum=1,
+                     maximum=10,
+                     value=4,
+                     step=1,
+                     label="Search Results",
+                     interactive=DDGS_AVAILABLE,
+                 )
+                 search_max_chars = gr.Slider(
+                     minimum=50,
+                     maximum=400,
+                     value=160,
+                     step=10,
+                     label="Max Characters per Result",
+                     interactive=DDGS_AVAILABLE,
+                 )
+                 search_timeout = gr.Slider(
+                     minimum=1.0,
+                     maximum=20.0,
+                     value=5.0,
+                     step=0.5,
+                     label="Search Timeout (seconds)",
+                     interactive=DDGS_AVAILABLE,
+                 )
+             gpu_estimate_display = gr.Markdown(
+                 value=initial_estimate_text,
+                 elem_classes="status-ok",
+             )
+             gpu_duration = gr.Slider(
+                 60,
+                 1800,
+                 value=initial_recommended_duration,
+                 step=60,
+                 label="GPU Duration (seconds)",
+                 info="Maximum GPU time allocation for this request",
+             )

          with gr.Row():
              generate_btn = gr.Button("Generate Router Plan", variant="primary", scale=1)
              clear_btn = gr.Button("Clear", variant="secondary", scale=1)
+             cancel_btn = gr.Button("Cancel", variant="stop", scale=1)

          with gr.Row():
              raw_output = gr.Textbox(label="Raw Model Output", lines=12)
@@ -1162,6 +1388,10 @@ def build_ui():
                  temperature,
                  top_p,
                  gpu_duration,
+                 enable_search,
+                 search_max_results,
+                 search_max_chars,
+                 search_timeout,
              ],
              outputs=[raw_output, plan_json, validation_msg, prompt_view],
              show_progress="full",
@@ -1174,6 +1404,27 @@ def build_ui():
              api_name="/clear_outputs",
          )

+         cancel_btn.click(
+             fn=cancel_generation,
+             outputs=[validation_msg],
+         )
+
+         model_choice.change(
+             fn=update_gpu_controls,
+             inputs=[model_choice, max_new_tokens, enable_search, gpu_duration],
+             outputs=[gpu_estimate_display, gpu_duration],
+         )
+         max_new_tokens.change(
+             fn=update_gpu_controls,
+             inputs=[model_choice, max_new_tokens, enable_search, gpu_duration],
+             outputs=[gpu_estimate_display, gpu_duration],
+         )
+         enable_search.change(
+             fn=update_gpu_controls,
+             inputs=[model_choice, max_new_tokens, enable_search, gpu_duration],
+             outputs=[gpu_estimate_display, gpu_duration],
+         )
+
      return demo

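For a sense of the GPU-time estimator's scale: with the default UI settings (16,000 max new tokens, search disabled) and the 4.0 B-parameter fallback that the code assumes when a MODELS entry has no params_b, the formula works out roughly as follows (a sanity-check sketch, not code from this commit):

    import math

    # Mirrors estimate_gpu_seconds() / format_gpu_estimate_message() above,
    # assuming params_b falls back to 4.0 and the default slider values.
    params_b = 4.0
    max_new_tokens = 16000
    enable_search = False

    base = 12.0 + params_b * 3.0                                  # 24.0 s startup overhead
    tokens_per_sec = max(40.0, 320.0 / (1.0 + params_b / 6.0))    # 192.0 tokens/s
    generation_time = max_new_tokens / tokens_per_sec             # ~83.3 s
    search_time = 8.0 if enable_search else 0.0
    estimate = base + generation_time + search_time               # ~107.3 s

    rounded = int(math.ceil(estimate))                            # 108
    recommended = max(60, min(1800, int(math.ceil(max(60, rounded) / 60.0) * 60)))
    print(rounded, recommended)                                   # 108 120

Under those assumptions the estimate display reads about 108 seconds and the GPU Duration slider is nudged up to the next 60-second step, 120 seconds. The same helper seeds the slider's initial value at UI build time via format_gpu_estimate_message(DEFAULT_MODEL, 16000, False), so the actual starting number depends on DEFAULT_MODEL's params_b entry.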