Spaces:
Sleeping
Sleeping
Commit
·
4ce42e8
1
Parent(s):
0e2f6c4
Add GPU estimator, DDG search, and cancel support
Browse files
app.py
CHANGED
|
@@ -1,14 +1,23 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import json
|
|
|
|
| 4 |
import os
|
| 5 |
import re
|
|
|
|
|
|
|
| 6 |
from typing import Any, Dict, List, Tuple, Optional
|
| 7 |
|
| 8 |
import gradio as gr
|
| 9 |
import spaces
|
| 10 |
import torch
|
| 11 |
-
from transformers import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
from threading import Thread
|
| 13 |
from concurrent.futures import ThreadPoolExecutor
|
| 14 |
|
|
@@ -18,6 +27,13 @@ try:
|
|
| 18 |
except ImportError: # pragma: no cover
|
| 19 |
HF_HUB_AVAILABLE = False
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
# Enable optimizations
|
| 22 |
torch.backends.cuda.matmul.allow_tf32 = True
|
| 23 |
|
|
@@ -47,6 +63,8 @@ except ImportError:
|
|
| 47 |
SamplingParams = None
|
| 48 |
print("Warning: vLLM not available, falling back to Transformers")
|
| 49 |
|
|
|
|
|
|
|
| 50 |
# Optional flag to disable vLLM (defaults to true on MIG due to device detection instability)
|
| 51 |
DISABLE_VLLM = os.environ.get("DISABLE_VLLM", "1" if MIG_VISIBLE else "0") == "1"
|
| 52 |
|
|
@@ -95,6 +113,88 @@ def _ensure_local_repo(repo_id: str) -> Optional[str]:
|
|
| 95 |
return None
|
| 96 |
|
| 97 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
def _start_prefetch_workers(model_names: list[str]):
|
| 99 |
global PREFETCH_EXECUTOR
|
| 100 |
if PREFETCH_DISABLED or not HF_HUB_AVAILABLE:
|
|
@@ -793,6 +893,10 @@ def _generate_router_plan_streaming_internal(
|
|
| 793 |
temperature: float,
|
| 794 |
top_p: float,
|
| 795 |
gpu_duration: int,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 796 |
):
|
| 797 |
"""Internal generator function for streaming token output."""
|
| 798 |
if not user_task.strip():
|
|
@@ -803,10 +907,49 @@ def _generate_router_plan_streaming_internal(
|
|
| 803 |
yield "", {}, f"❌ Invalid model choice: {model_choice}. Available: {list(MODELS.keys())}", ""
|
| 804 |
return
|
| 805 |
|
|
|
|
|
|
|
| 806 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 807 |
prompt = build_router_prompt(
|
| 808 |
user_task=user_task,
|
| 809 |
-
context=
|
| 810 |
acceptance=acceptance,
|
| 811 |
extra_guidance=extra_guidance,
|
| 812 |
difficulty=difficulty,
|
|
@@ -844,6 +987,14 @@ def _generate_router_plan_streaming_internal(
|
|
| 844 |
|
| 845 |
prev_text_len = 0
|
| 846 |
for request_output in stream:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 847 |
if not request_output.outputs:
|
| 848 |
continue
|
| 849 |
|
|
@@ -909,6 +1060,7 @@ def _generate_router_plan_streaming_internal(
|
|
| 909 |
"streamer": streamer,
|
| 910 |
"eos_token_id": tokenizer.eos_token_id,
|
| 911 |
"pad_token_id": tokenizer.pad_token_id or tokenizer.eos_token_id,
|
|
|
|
| 912 |
}
|
| 913 |
|
| 914 |
generation_error = None
|
|
@@ -937,6 +1089,9 @@ def _generate_router_plan_streaming_internal(
|
|
| 937 |
|
| 938 |
try:
|
| 939 |
for new_text in streamer:
|
|
|
|
|
|
|
|
|
|
| 940 |
if generation_error:
|
| 941 |
raise generation_error
|
| 942 |
|
|
@@ -988,7 +1143,9 @@ def _generate_router_plan_streaming_internal(
|
|
| 988 |
completion = trim_at_stop_sequences(completion.strip())[0]
|
| 989 |
print(f"[DEBUG] Final completion length: {len(completion)}")
|
| 990 |
|
| 991 |
-
if
|
|
|
|
|
|
|
| 992 |
print("[DEBUG] WARNING: Completion is empty - model may not have generated output")
|
| 993 |
validation_msg = "⚠️ Model generated empty output. Check GPU allocation and model loading."
|
| 994 |
elif parsed_plan is None:
|
|
@@ -1033,11 +1190,19 @@ def _make_gpu_wrapper(duration: int):
|
|
| 1033 |
temperature: float,
|
| 1034 |
top_p: float,
|
| 1035 |
gpu_duration: int,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1036 |
):
|
| 1037 |
yield from _generate_router_plan_streaming_internal(
|
| 1038 |
user_task, context, acceptance, extra_guidance,
|
| 1039 |
difficulty, tags, model_choice, max_new_tokens,
|
| 1040 |
-
temperature, top_p, duration
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1041 |
)
|
| 1042 |
return wrapper
|
| 1043 |
|
|
@@ -1058,6 +1223,10 @@ def generate_router_plan_streaming(
|
|
| 1058 |
temperature: float,
|
| 1059 |
top_p: float,
|
| 1060 |
gpu_duration: int = 600,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1061 |
):
|
| 1062 |
"""
|
| 1063 |
Generate router plan with streaming output.
|
|
@@ -1073,7 +1242,11 @@ def generate_router_plan_streaming(
|
|
| 1073 |
yield from wrapper(
|
| 1074 |
user_task, context, acceptance, extra_guidance,
|
| 1075 |
difficulty, tags, model_choice, max_new_tokens,
|
| 1076 |
-
temperature, top_p, rounded_duration
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1077 |
)
|
| 1078 |
|
| 1079 |
|
|
@@ -1081,8 +1254,18 @@ def clear_outputs():
|
|
| 1081 |
return "", {}, "Awaiting generation.", ""
|
| 1082 |
|
| 1083 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1084 |
def build_ui():
|
| 1085 |
description = "Use the CourseGPT-Pro router checkpoints (Gemma3/Qwen3) hosted on ZeroGPU to generate structured routing plans."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1086 |
with gr.Blocks(theme=gr.themes.Soft(), css="""
|
| 1087 |
textarea { font-family: 'JetBrains Mono', 'Fira Code', monospace; }
|
| 1088 |
.status-ok { color: #0d9488; font-weight: 600; }
|
|
@@ -1136,11 +1319,54 @@ def build_ui():
|
|
| 1136 |
max_new_tokens = gr.Slider(256, 20000, value=16000, step=32, label="Max New Tokens")
|
| 1137 |
temperature = gr.Slider(0.0, 1.5, value=0.2, step=0.05, label="Temperature")
|
| 1138 |
top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
|
| 1139 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1140 |
|
| 1141 |
with gr.Row():
|
| 1142 |
generate_btn = gr.Button("Generate Router Plan", variant="primary", scale=1)
|
| 1143 |
clear_btn = gr.Button("Clear", variant="secondary", scale=1)
|
|
|
|
| 1144 |
|
| 1145 |
with gr.Row():
|
| 1146 |
raw_output = gr.Textbox(label="Raw Model Output", lines=12)
|
|
@@ -1162,6 +1388,10 @@ def build_ui():
|
|
| 1162 |
temperature,
|
| 1163 |
top_p,
|
| 1164 |
gpu_duration,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1165 |
],
|
| 1166 |
outputs=[raw_output, plan_json, validation_msg, prompt_view],
|
| 1167 |
show_progress="full",
|
|
@@ -1174,6 +1404,27 @@ def build_ui():
|
|
| 1174 |
api_name="/clear_outputs",
|
| 1175 |
)
|
| 1176 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1177 |
return demo
|
| 1178 |
|
| 1179 |
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import json
|
| 4 |
+
import math
|
| 5 |
import os
|
| 6 |
import re
|
| 7 |
+
import threading
|
| 8 |
+
from itertools import islice
|
| 9 |
from typing import Any, Dict, List, Tuple, Optional
|
| 10 |
|
| 11 |
import gradio as gr
|
| 12 |
import spaces
|
| 13 |
import torch
|
| 14 |
+
from transformers import (
|
| 15 |
+
AutoTokenizer,
|
| 16 |
+
TextIteratorStreamer,
|
| 17 |
+
pipeline,
|
| 18 |
+
StoppingCriteria,
|
| 19 |
+
StoppingCriteriaList,
|
| 20 |
+
)
|
| 21 |
from threading import Thread
|
| 22 |
from concurrent.futures import ThreadPoolExecutor
|
| 23 |
|
|
|
|
| 27 |
except ImportError: # pragma: no cover
|
| 28 |
HF_HUB_AVAILABLE = False
|
| 29 |
|
| 30 |
+
try:
|
| 31 |
+
from ddgs import DDGS
|
| 32 |
+
|
| 33 |
+
DDGS_AVAILABLE = True
|
| 34 |
+
except ImportError:
|
| 35 |
+
DDGS_AVAILABLE = False
|
| 36 |
+
|
| 37 |
# Enable optimizations
|
| 38 |
torch.backends.cuda.matmul.allow_tf32 = True
|
| 39 |
|
|
|
|
| 63 |
SamplingParams = None
|
| 64 |
print("Warning: vLLM not available, falling back to Transformers")
|
| 65 |
|
| 66 |
+
cancel_event = threading.Event()
|
| 67 |
+
|
| 68 |
# Optional flag to disable vLLM (defaults to true on MIG due to device detection instability)
|
| 69 |
DISABLE_VLLM = os.environ.get("DISABLE_VLLM", "1" if MIG_VISIBLE else "0") == "1"
|
| 70 |
|
|
|
|
| 113 |
return None
|
| 114 |
|
| 115 |
|
| 116 |
+
def _retrieve_search_results(query: str, max_results: int, max_chars: int) -> List[str]:
|
| 117 |
+
if not DDGS_AVAILABLE:
|
| 118 |
+
return []
|
| 119 |
+
results: List[str] = []
|
| 120 |
+
try:
|
| 121 |
+
with DDGS() as ddgs:
|
| 122 |
+
for idx, item in enumerate(
|
| 123 |
+
islice(
|
| 124 |
+
ddgs.text(
|
| 125 |
+
query,
|
| 126 |
+
region="wt-wt",
|
| 127 |
+
safesearch="moderate",
|
| 128 |
+
timelimit="y",
|
| 129 |
+
),
|
| 130 |
+
max_results,
|
| 131 |
+
)
|
| 132 |
+
):
|
| 133 |
+
title = (item.get("title") or "Untitled").strip()
|
| 134 |
+
body = (item.get("body") or "").strip()
|
| 135 |
+
url = (item.get("href") or "").strip()
|
| 136 |
+
snippet = body[: max_chars].replace("\n", " ")
|
| 137 |
+
formatted = f"[{idx+1}] {title} — {snippet}"
|
| 138 |
+
if url:
|
| 139 |
+
formatted += f" ({url})"
|
| 140 |
+
results.append(formatted)
|
| 141 |
+
except Exception as exc: # pragma: no cover
|
| 142 |
+
print(f"[DEBUG] DDG search failed: {exc}")
|
| 143 |
+
return results
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
class CancelStoppingCriteria(StoppingCriteria):
|
| 147 |
+
def __call__(self, input_ids, scores, **kwargs) -> bool:
|
| 148 |
+
return cancel_event.is_set()
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def estimate_gpu_seconds(
|
| 152 |
+
model_name: str,
|
| 153 |
+
max_new_tokens: int,
|
| 154 |
+
enable_search: bool,
|
| 155 |
+
) -> float:
|
| 156 |
+
params_b = MODELS.get(model_name, {}).get("params_b", 4.0)
|
| 157 |
+
base = 12.0 + params_b * 3.0
|
| 158 |
+
tokens_per_sec = max(40.0, 320.0 / (1.0 + params_b / 6.0))
|
| 159 |
+
generation_time = max_new_tokens / tokens_per_sec
|
| 160 |
+
search_time = 8.0 if enable_search else 0.0
|
| 161 |
+
return base + generation_time + search_time
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def format_gpu_estimate_message(
|
| 165 |
+
model_name: str,
|
| 166 |
+
max_new_tokens: int,
|
| 167 |
+
enable_search: bool,
|
| 168 |
+
) -> Tuple[str, int]:
|
| 169 |
+
est_seconds = estimate_gpu_seconds(model_name, max_new_tokens, enable_search)
|
| 170 |
+
rounded = int(math.ceil(est_seconds))
|
| 171 |
+
recommended = int(math.ceil(max(60, rounded) / 60.0) * 60)
|
| 172 |
+
recommended = max(60, min(1800, recommended))
|
| 173 |
+
model_size = MODELS.get(model_name, {}).get("params_b", 4.0)
|
| 174 |
+
message = (
|
| 175 |
+
f"⏱️ **Estimated GPU Time:** ~{rounded} seconds\n\n"
|
| 176 |
+
f"📊 **Model Size:** {model_size:.1f}B parameters\n"
|
| 177 |
+
f"🔍 **Web Search:** {'Enabled' if enable_search else 'Disabled'}\n"
|
| 178 |
+
f"✅ **Suggested GPU Duration slider:** {recommended} seconds"
|
| 179 |
+
)
|
| 180 |
+
return message, recommended
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def update_gpu_controls(
|
| 184 |
+
model_name: str,
|
| 185 |
+
max_new_tokens: int,
|
| 186 |
+
enable_search: bool,
|
| 187 |
+
current_duration: int,
|
| 188 |
+
):
|
| 189 |
+
message, recommended = format_gpu_estimate_message(
|
| 190 |
+
model_name,
|
| 191 |
+
max_new_tokens,
|
| 192 |
+
enable_search,
|
| 193 |
+
)
|
| 194 |
+
updated_value = current_duration if current_duration >= recommended else recommended
|
| 195 |
+
return message, gr.update(value=updated_value)
|
| 196 |
+
|
| 197 |
+
|
| 198 |
def _start_prefetch_workers(model_names: list[str]):
|
| 199 |
global PREFETCH_EXECUTOR
|
| 200 |
if PREFETCH_DISABLED or not HF_HUB_AVAILABLE:
|
|
|
|
| 893 |
temperature: float,
|
| 894 |
top_p: float,
|
| 895 |
gpu_duration: int,
|
| 896 |
+
enable_search: bool,
|
| 897 |
+
search_max_results: int,
|
| 898 |
+
search_max_chars: int,
|
| 899 |
+
search_timeout: float,
|
| 900 |
):
|
| 901 |
"""Internal generator function for streaming token output."""
|
| 902 |
if not user_task.strip():
|
|
|
|
| 907 |
yield "", {}, f"❌ Invalid model choice: {model_choice}. Available: {list(MODELS.keys())}", ""
|
| 908 |
return
|
| 909 |
|
| 910 |
+
cancel_event.clear()
|
| 911 |
+
cancelled = False
|
| 912 |
try:
|
| 913 |
+
search_snippets: List[str] = []
|
| 914 |
+
if enable_search and DDGS_AVAILABLE and user_task.strip():
|
| 915 |
+
search_snippets_holder: List[str] = []
|
| 916 |
+
search_error: Optional[Exception] = None
|
| 917 |
+
|
| 918 |
+
def _fetch_search():
|
| 919 |
+
nonlocal search_error
|
| 920 |
+
try:
|
| 921 |
+
results = _retrieve_search_results(
|
| 922 |
+
user_task,
|
| 923 |
+
max(1, int(search_max_results)),
|
| 924 |
+
max(30, int(search_max_chars)),
|
| 925 |
+
)
|
| 926 |
+
search_snippets_holder.extend(results)
|
| 927 |
+
except Exception as exc: # pragma: no cover
|
| 928 |
+
search_error = exc
|
| 929 |
+
|
| 930 |
+
search_thread = Thread(target=_fetch_search, daemon=True)
|
| 931 |
+
search_thread.start()
|
| 932 |
+
search_thread.join(timeout=float(max(0.5, search_timeout)))
|
| 933 |
+
if search_thread.is_alive():
|
| 934 |
+
print("[DEBUG] Search thread timed out; continuing without results.")
|
| 935 |
+
if search_error:
|
| 936 |
+
print(f"[DEBUG] Search error: {search_error}")
|
| 937 |
+
search_snippets = search_snippets_holder
|
| 938 |
+
|
| 939 |
+
context_for_prompt = context
|
| 940 |
+
if search_snippets:
|
| 941 |
+
search_block = "\n".join(f"- {snippet}" for snippet in search_snippets)
|
| 942 |
+
addendum = (
|
| 943 |
+
"\n\n# Web Search Findings\n"
|
| 944 |
+
"Use the following snippets as supplementary evidence. "
|
| 945 |
+
"Cite them as needed in the generated plan.\n"
|
| 946 |
+
f"{search_block}"
|
| 947 |
+
)
|
| 948 |
+
context_for_prompt = (context_for_prompt or "").rstrip() + addendum
|
| 949 |
+
|
| 950 |
prompt = build_router_prompt(
|
| 951 |
user_task=user_task,
|
| 952 |
+
context=context_for_prompt,
|
| 953 |
acceptance=acceptance,
|
| 954 |
extra_guidance=extra_guidance,
|
| 955 |
difficulty=difficulty,
|
|
|
|
| 987 |
|
| 988 |
prev_text_len = 0
|
| 989 |
for request_output in stream:
|
| 990 |
+
if cancel_event.is_set():
|
| 991 |
+
cancelled = True
|
| 992 |
+
try:
|
| 993 |
+
if hasattr(generator, "abort_request"):
|
| 994 |
+
generator.abort_request(request_output.request_id)
|
| 995 |
+
except Exception:
|
| 996 |
+
pass
|
| 997 |
+
break
|
| 998 |
if not request_output.outputs:
|
| 999 |
continue
|
| 1000 |
|
|
|
|
| 1060 |
"streamer": streamer,
|
| 1061 |
"eos_token_id": tokenizer.eos_token_id,
|
| 1062 |
"pad_token_id": tokenizer.pad_token_id or tokenizer.eos_token_id,
|
| 1063 |
+
"stopping_criteria": StoppingCriteriaList([CancelStoppingCriteria()]),
|
| 1064 |
}
|
| 1065 |
|
| 1066 |
generation_error = None
|
|
|
|
| 1089 |
|
| 1090 |
try:
|
| 1091 |
for new_text in streamer:
|
| 1092 |
+
if cancel_event.is_set():
|
| 1093 |
+
cancelled = True
|
| 1094 |
+
break
|
| 1095 |
if generation_error:
|
| 1096 |
raise generation_error
|
| 1097 |
|
|
|
|
| 1143 |
completion = trim_at_stop_sequences(completion.strip())[0]
|
| 1144 |
print(f"[DEBUG] Final completion length: {len(completion)}")
|
| 1145 |
|
| 1146 |
+
if cancelled:
|
| 1147 |
+
validation_msg = "⏹️ Generation cancelled by user."
|
| 1148 |
+
elif not completion:
|
| 1149 |
print("[DEBUG] WARNING: Completion is empty - model may not have generated output")
|
| 1150 |
validation_msg = "⚠️ Model generated empty output. Check GPU allocation and model loading."
|
| 1151 |
elif parsed_plan is None:
|
|
|
|
| 1190 |
temperature: float,
|
| 1191 |
top_p: float,
|
| 1192 |
gpu_duration: int,
|
| 1193 |
+
enable_search: bool,
|
| 1194 |
+
search_max_results: int,
|
| 1195 |
+
search_max_chars: int,
|
| 1196 |
+
search_timeout: float,
|
| 1197 |
):
|
| 1198 |
yield from _generate_router_plan_streaming_internal(
|
| 1199 |
user_task, context, acceptance, extra_guidance,
|
| 1200 |
difficulty, tags, model_choice, max_new_tokens,
|
| 1201 |
+
temperature, top_p, duration,
|
| 1202 |
+
enable_search,
|
| 1203 |
+
search_max_results,
|
| 1204 |
+
search_max_chars,
|
| 1205 |
+
search_timeout,
|
| 1206 |
)
|
| 1207 |
return wrapper
|
| 1208 |
|
|
|
|
| 1223 |
temperature: float,
|
| 1224 |
top_p: float,
|
| 1225 |
gpu_duration: int = 600,
|
| 1226 |
+
enable_search: bool = False,
|
| 1227 |
+
search_max_results: int = 4,
|
| 1228 |
+
search_max_chars: int = 120,
|
| 1229 |
+
search_timeout: float = 5.0,
|
| 1230 |
):
|
| 1231 |
"""
|
| 1232 |
Generate router plan with streaming output.
|
|
|
|
| 1242 |
yield from wrapper(
|
| 1243 |
user_task, context, acceptance, extra_guidance,
|
| 1244 |
difficulty, tags, model_choice, max_new_tokens,
|
| 1245 |
+
temperature, top_p, rounded_duration,
|
| 1246 |
+
enable_search,
|
| 1247 |
+
int(search_max_results),
|
| 1248 |
+
int(search_max_chars),
|
| 1249 |
+
float(search_timeout),
|
| 1250 |
)
|
| 1251 |
|
| 1252 |
|
|
|
|
| 1254 |
return "", {}, "Awaiting generation.", ""
|
| 1255 |
|
| 1256 |
|
| 1257 |
+
def cancel_generation():
|
| 1258 |
+
cancel_event.set()
|
| 1259 |
+
return "⏹️ Cancel request sent. Finishing current step..."
|
| 1260 |
+
|
| 1261 |
+
|
| 1262 |
def build_ui():
|
| 1263 |
description = "Use the CourseGPT-Pro router checkpoints (Gemma3/Qwen3) hosted on ZeroGPU to generate structured routing plans."
|
| 1264 |
+
initial_estimate_text, initial_recommended_duration = format_gpu_estimate_message(
|
| 1265 |
+
DEFAULT_MODEL,
|
| 1266 |
+
16000,
|
| 1267 |
+
False,
|
| 1268 |
+
)
|
| 1269 |
with gr.Blocks(theme=gr.themes.Soft(), css="""
|
| 1270 |
textarea { font-family: 'JetBrains Mono', 'Fira Code', monospace; }
|
| 1271 |
.status-ok { color: #0d9488; font-weight: 600; }
|
|
|
|
| 1319 |
max_new_tokens = gr.Slider(256, 20000, value=16000, step=32, label="Max New Tokens")
|
| 1320 |
temperature = gr.Slider(0.0, 1.5, value=0.2, step=0.05, label="Temperature")
|
| 1321 |
top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
|
| 1322 |
+
enable_search = gr.Checkbox(
|
| 1323 |
+
label="Enable DuckDuckGo Web Search",
|
| 1324 |
+
value=False,
|
| 1325 |
+
interactive=DDGS_AVAILABLE,
|
| 1326 |
+
info="Augment context with live snippets." if DDGS_AVAILABLE else "Install 'ddgs' package to enable search.",
|
| 1327 |
+
)
|
| 1328 |
+
with gr.Accordion("Web Search Settings", open=False, visible=DDGS_AVAILABLE) as search_settings:
|
| 1329 |
+
search_max_results = gr.Slider(
|
| 1330 |
+
minimum=1,
|
| 1331 |
+
maximum=10,
|
| 1332 |
+
value=4,
|
| 1333 |
+
step=1,
|
| 1334 |
+
label="Search Results",
|
| 1335 |
+
interactive=DDGS_AVAILABLE,
|
| 1336 |
+
)
|
| 1337 |
+
search_max_chars = gr.Slider(
|
| 1338 |
+
minimum=50,
|
| 1339 |
+
maximum=400,
|
| 1340 |
+
value=160,
|
| 1341 |
+
step=10,
|
| 1342 |
+
label="Max Characters per Result",
|
| 1343 |
+
interactive=DDGS_AVAILABLE,
|
| 1344 |
+
)
|
| 1345 |
+
search_timeout = gr.Slider(
|
| 1346 |
+
minimum=1.0,
|
| 1347 |
+
maximum=20.0,
|
| 1348 |
+
value=5.0,
|
| 1349 |
+
step=0.5,
|
| 1350 |
+
label="Search Timeout (seconds)",
|
| 1351 |
+
interactive=DDGS_AVAILABLE,
|
| 1352 |
+
)
|
| 1353 |
+
gpu_estimate_display = gr.Markdown(
|
| 1354 |
+
value=initial_estimate_text,
|
| 1355 |
+
elem_classes="status-ok",
|
| 1356 |
+
)
|
| 1357 |
+
gpu_duration = gr.Slider(
|
| 1358 |
+
60,
|
| 1359 |
+
1800,
|
| 1360 |
+
value=initial_recommended_duration,
|
| 1361 |
+
step=60,
|
| 1362 |
+
label="GPU Duration (seconds)",
|
| 1363 |
+
info="Maximum GPU time allocation for this request",
|
| 1364 |
+
)
|
| 1365 |
|
| 1366 |
with gr.Row():
|
| 1367 |
generate_btn = gr.Button("Generate Router Plan", variant="primary", scale=1)
|
| 1368 |
clear_btn = gr.Button("Clear", variant="secondary", scale=1)
|
| 1369 |
+
cancel_btn = gr.Button("Cancel", variant="stop", scale=1)
|
| 1370 |
|
| 1371 |
with gr.Row():
|
| 1372 |
raw_output = gr.Textbox(label="Raw Model Output", lines=12)
|
|
|
|
| 1388 |
temperature,
|
| 1389 |
top_p,
|
| 1390 |
gpu_duration,
|
| 1391 |
+
enable_search,
|
| 1392 |
+
search_max_results,
|
| 1393 |
+
search_max_chars,
|
| 1394 |
+
search_timeout,
|
| 1395 |
],
|
| 1396 |
outputs=[raw_output, plan_json, validation_msg, prompt_view],
|
| 1397 |
show_progress="full",
|
|
|
|
| 1404 |
api_name="/clear_outputs",
|
| 1405 |
)
|
| 1406 |
|
| 1407 |
+
cancel_btn.click(
|
| 1408 |
+
fn=cancel_generation,
|
| 1409 |
+
outputs=[validation_msg],
|
| 1410 |
+
)
|
| 1411 |
+
|
| 1412 |
+
model_choice.change(
|
| 1413 |
+
fn=update_gpu_controls,
|
| 1414 |
+
inputs=[model_choice, max_new_tokens, enable_search, gpu_duration],
|
| 1415 |
+
outputs=[gpu_estimate_display, gpu_duration],
|
| 1416 |
+
)
|
| 1417 |
+
max_new_tokens.change(
|
| 1418 |
+
fn=update_gpu_controls,
|
| 1419 |
+
inputs=[model_choice, max_new_tokens, enable_search, gpu_duration],
|
| 1420 |
+
outputs=[gpu_estimate_display, gpu_duration],
|
| 1421 |
+
)
|
| 1422 |
+
enable_search.change(
|
| 1423 |
+
fn=update_gpu_controls,
|
| 1424 |
+
inputs=[model_choice, max_new_tokens, enable_search, gpu_duration],
|
| 1425 |
+
outputs=[gpu_estimate_display, gpu_duration],
|
| 1426 |
+
)
|
| 1427 |
+
|
| 1428 |
return demo
|
| 1429 |
|
| 1430 |
|