Crystalcareai committed on
Commit 8e2b241 · verified · 1 Parent(s): 737b2b3

Update app.py

Files changed (1)
  1. app.py +340 -52
app.py CHANGED
@@ -1,64 +1,352 @@
  import gradio as gr
  from huggingface_hub import InferenceClient

- """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
-
- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
- ):
-     messages = [{"role": "system", "content": system_message}]
-
-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         token = message.choices[0].delta.content
-
-         response += token
-         yield response
-
-
- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- demo = gr.ChatInterface(
-     respond,
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
- )
-

  if __name__ == "__main__":
-     demo.launch()
 
  import gradio as gr
  from huggingface_hub import InferenceClient
+ from typing import Dict, List, Optional, Generator, AsyncGenerator
+ from dataclasses import dataclass
+ import httpx
+ import json
+ import asyncio
+ import openai
+ import os

+ arcee_api_key = os.environ.get("arcee_api_key")
+ openrouter_api_key = os.environ.get("openrouter_api_key")
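+ # Both keys are expected in the environment (e.g. as Space secrets);
+ # os.environ.get returns None when a key is unset, so a missing key
+ # surfaces later as an authentication error rather than at import time.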
 
 
+ @dataclass
+ class ModelConfig:
+     name: str
+     base_url: str
+     api_key: str
+
+ MODEL_CONFIGS = {
+     1: ModelConfig(
+         name="virtuoso-small",
+         base_url="https://models.arcee.ai/v1/chat/completions",
+         api_key=arcee_api_key
+     ),
+     2: ModelConfig(
+         name="virtuoso-medium",
+         base_url="https://models.arcee.ai/v1/chat/completions",
+         api_key=arcee_api_key
+     ),
+     3: ModelConfig(
+         name="virtuoso-large",
+         base_url="https://models.arcee.ai/v1/chat/completions",
+         api_key=arcee_api_key
+     ),
+     4: ModelConfig(
+         name="anthropic/claude-3.5-sonnet",
+         base_url="https://openrouter.ai/api/v1/chat/completions",
+         api_key=openrouter_api_key
+     )
+ }
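+ # The complexity score (1-4) indexes straight into this dict: 1-3 route to
+ # progressively larger Arcee-hosted virtuoso models, and 4 escalates to
+ # Claude 3.5 Sonnet via OpenRouter.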
 
+ class ModelUsageStats:
+     def __init__(self):
+         self.usage_counts = {i: 0 for i in range(1, 5)}
+         self.total_queries = 0
+
+     def update(self, complexity: int):
+         self.usage_counts[complexity] += 1
+         self.total_queries += 1
+
+     def get_stats(self) -> str:
+         if self.total_queries == 0:
+             return "No queries processed yet."
+
+         model_names = {
+             1: "virtuoso-small",
+             2: "virtuoso-medium",
+             3: "virtuoso-large",
+             4: "claude-3.5-sonnet"
+         }
+
+         stats = []
+         for complexity, count in self.usage_counts.items():
+             percentage = (count / self.total_queries) * 100
+             stats.append(f"{model_names[complexity]}: {count} uses ({percentage:.1f}%)")
+         return "\n".join(stats)
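+ # Example: after update(2) on a fresh instance, get_stats() returns four
+ # lines, with "virtuoso-medium: 1 uses (100.0%)" and 0.0% for the rest.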

+ stats = ModelUsageStats()

+ async def get_complexity(prompt: str) -> int:
+     try:
+         async with httpx.AsyncClient(http2=True) as client:
+             response = await client.post(
+                 "http://185.216.20.86:8000/complexity",
+                 headers={"Content-Type": "application/json"},
+                 json={"prompt": prompt},
+                 timeout=10
+             )
+             response.raise_for_status()
+             return response.json()["complexity"]
+     except Exception as e:
+         print(f"Error getting complexity: {e}")
+         return 3  # Fall back to complexity 3 (virtuoso-large) on any error
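+ # The classifier endpoint is expected to reply with JSON of the form
+ # {"complexity": <1-4>}; any network or parsing failure routes the query
+ # to virtuoso-large.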
 
+ async def get_model_response(message: str, history: List[Dict[str, str]], complexity: int) -> AsyncGenerator[str, None]:
+     model_config = MODEL_CONFIGS[complexity]
+
+     headers = {
+         "Content-Type": "application/json"
+     }
+
+     if "openrouter.ai" in model_config.base_url:
+         headers.update({
+             "HTTP-Referer": "https://github.com/lucataco/gradio-router",
+             "X-Title": "Gradio Router",
+             "Authorization": f"Bearer {model_config.api_key}"
+         })
+     elif "arcee.ai" in model_config.base_url:
+         headers.update({
+             "Authorization": f"Bearer {model_config.api_key}"
+         })
+
+     try:
+         collected_chunks = []
+         # For Arcee.ai models, use direct API call with HTTP/2
+         if "arcee.ai" in model_config.base_url:
+             # Start with system message
+             messages = [{"role": "system", "content": "You are a helpful AI assistant."}]
+
+             # Process history to ensure it's in the correct format
+             for msg in history:
+                 # Remove any model info or stats from previous responses
+                 content = msg["content"]
+                 if isinstance(content, str):
+                     content = content.split("\n\n<div")[0]
+                 messages.append({
+                     "role": msg["role"],
+                     "content": content
+                 })
+
+             # Add the current message
+             messages.append({"role": "user", "content": message})
+
+             async with httpx.AsyncClient(http2=True) as client:
+                 async with client.stream(
+                     "POST",
+                     model_config.base_url,
+                     headers=headers,
+                     json={
+                         "model": model_config.name,
+                         "messages": messages,
+                         "temperature": 0.7,
+                         "stream": True
+                     },
+                     timeout=30.0
+                 ) as response:
+                     response.raise_for_status()
+                     buffer = []
+                     async for line in response.aiter_lines():
+                         if line.startswith("data: "):
+                             try:
+                                 json_response = json.loads(line.replace("data: ", ""))
+                                 if json_response.get('choices') and json_response['choices'][0].get('delta', {}).get('content'):
+                                     buffer.append(json_response['choices'][0]['delta']['content'])
+                                     if len(buffer) >= 10 or any(c in '.,!?\n' for c in buffer[-1]):
+                                         collected_chunks.extend(buffer)
+                                         yield "".join(collected_chunks)
+                                         buffer = []
+                             except json.JSONDecodeError:
+                                 continue
+                     if buffer:  # Yield any remaining content
+                         collected_chunks.extend(buffer)
+                         yield "".join(collected_chunks)
+
+         # For OpenRouter models, use OpenAI client
+         else:
+             # The openai client appends "/chat/completions" to base_url itself,
+             # so strip it here to avoid doubling the path
+             client = openai.AsyncOpenAI(
+                 base_url=model_config.base_url.removesuffix("/chat/completions"),
+                 api_key=model_config.api_key,
+                 default_headers=headers,
+                 http_client=httpx.AsyncClient(http2=True)
+             )
+
+             # Process history similarly for OpenRouter
+             messages = [{"role": "system", "content": "You are a helpful AI assistant."}]
+             for msg in history:
+                 content = msg["content"]
+                 if isinstance(content, str):
+                     content = content.split("\n\n<div")[0]
+                 messages.append({
+                     "role": msg["role"],
+                     "content": content
+                 })
+             messages.append({"role": "user", "content": message})
+
+             response = await client.chat.completions.create(
+                 model=model_config.name,
+                 messages=messages,
+                 temperature=0.7,
+                 stream=True
+             )
+
+             buffer = []
+             async for chunk in response:
+                 if chunk.choices[0].delta.content is not None:
+                     buffer.append(chunk.choices[0].delta.content)
+                     if len(buffer) >= 10 or any(c in '.,!?\n' for c in buffer[-1]):
+                         collected_chunks.extend(buffer)
+                         yield "".join(collected_chunks)
+                         buffer = []
+
+             if buffer:  # Yield any remaining content
+                 collected_chunks.extend(buffer)
+                 yield "".join(collected_chunks)
+
+     except Exception as e:
+         error_msg = str(e)
+         print(f"Error getting model response: {error_msg}")
+         if "464" in error_msg:
+             yield "Error: Authentication failed. Please check your API key and try again."
+         elif "Internal Server Error" in error_msg:
+             yield "Error: The server encountered an internal error. Please try again later."
+         else:
+             yield f"Error: Unable to get response from {model_config.name}. {error_msg}"
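+ # Streaming is batched: tokens accumulate in `buffer` and are flushed once
+ # ~10 chunks collect or the latest chunk contains punctuation, and each
+ # yield re-emits the full text collected so far.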
 
+ async def chat_wrapper(
+     message: str,
+     history: List[Dict[str, str]],
+     system_message: str,
+     max_tokens: int,
+     temperature: float,
+     top_p: float,
+     model_usage_stats: str,
+ ):
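+     # Note: system_message, max_tokens, temperature and top_p come from the
+     # Advanced Settings controls; get_model_response currently applies its
+     # own fixed system prompt and temperature (0.7) instead.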
+     complexity = await get_complexity(message)
+     stats.update(complexity)
+     model_name = MODEL_CONFIGS[complexity].name
+
+     # Convert history for model
+     model_history = []
+     for msg in history:
+         if isinstance(msg, dict) and "role" in msg and "content" in msg:
+             # Clean content
+             content = msg["content"]
+             if isinstance(content, str):
+                 content = content.split("\n\n<div")[0]
+             model_history.append({"role": msg["role"], "content": content})
+
+     # Stream the response
+     full_response = ""
+     async for partial_response in get_model_response(message, model_history, complexity):
+         full_response = partial_response
+         response_with_info = f"{full_response}\n\n<div class='model-info'>Model: {model_name}</div>"
+
+         # Update stats display
+         stats_text = stats.get_stats()
+
+         yield [
+             *history,
+             {"role": "user", "content": message},
+             {"role": "assistant", "content": response_with_info}
+         ], stats_text
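+ # Each yield replaces the chat history with the prior turns plus the new
+ # user message and the in-progress assistant reply, alongside a refreshed
+ # stats string for the stats_display output.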
 
+ with gr.Blocks(
+     theme=gr.themes.Soft(
+         primary_hue="blue",
+         secondary_hue="indigo",
+         neutral_hue="slate",
+         font=("Inter", "system-ui", "sans-serif")
+     ),
+     css="""
+     .container {
+         max-width: 1000px;
+         margin: auto;
+         padding: 2rem;
+     }
+     .title {
+         text-align: center;
+         font-size: 2.5rem;
+         font-weight: 600;
+         margin: 1rem 0;
+         background: linear-gradient(to right, var(--primary-500), var(--secondary-500));
+         -webkit-background-clip: text;
+         -webkit-text-fill-color: transparent;
+     }
+     .subtitle {
+         text-align: center;
+         font-size: 1.1rem;
+         color: var(--neutral-700);
+         margin-bottom: 2rem;
+         font-weight: 400;
+     }
+     .model-info {
+         font-style: italic;
+         color: var(--neutral-500);
+         font-size: 0.85em;
+         margin-top: 1em;
+         padding-top: 0.5em;
+         border-top: 1px solid var(--neutral-200);
+         opacity: 0.8;
+     }
+     .stats-box {
+         margin-top: 1rem;
+         padding: 1rem;
+         border-radius: 0.75rem;
+         background: color-mix(in srgb, var(--background-fill) 80%, transparent);
+         border: 1px solid var(--neutral-200);
+         font-family: monospace;
+         white-space: pre-line;
+     }
+     .message.assistant {
+         padding-bottom: 1.5em !important;
+     }
+     """
+ ) as demo:
+     with gr.Column(elem_classes="container"):
+         gr.Markdown("# AI Model Router", elem_classes="title")
+         gr.Markdown(
+             "Your message will be routed to the appropriate AI model based on complexity.",
+             elem_classes="subtitle"
+         )
+
+         chatbot = gr.Chatbot(
+             value=[],
+             bubble_full_width=False,
+             show_label=False,
+             height=450,
+             container=True,
+             type="messages"
+         )
+
+         with gr.Row():
+             txt = gr.Textbox(
+                 show_label=False,
+                 placeholder="Enter your message here...",
+                 container=False,
+                 scale=7
+             )
+             clear = gr.ClearButton(
+                 [txt, chatbot],
+                 scale=1,
+                 variant="secondary",
+                 size="sm"
+             )
+
+         with gr.Accordion("Advanced Settings", open=False):
+             system_message = gr.Textbox(value="You are a helpful AI assistant.", label="System message")
+             max_tokens = gr.Slider(minimum=16, maximum=4096, value=2048, step=1, label="Max Tokens")
+             temperature = gr.Slider(minimum=0, maximum=2, value=0.7, step=0.1, label="Temperature")
+             top_p = gr.Slider(minimum=0, maximum=1, value=0.9, step=0.1, label="Top P")
+
+         stats_display = gr.Textbox(
+             value=stats.get_stats(),
+             label="Model Usage Statistics",
+             interactive=False,
+             elem_classes="stats-box"
+         )
+
+
+         # Set up event handler for streaming
+         txt.submit(
+             chat_wrapper,
+             [txt, chatbot, system_message, max_tokens, temperature, top_p, stats_display],
+             [chatbot, stats_display],
+         ).then(
+             lambda: "",
+             None,
+             [txt],
+         )
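+         # The .then() step clears the textbox only after streaming finishes;
+         # demo.queue() below provides the request queue that Gradio's
+         # streaming generator outputs rely on.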
 
  if __name__ == "__main__":
+     demo.queue().launch()