yanyihan-xiaomi committed on
Commit
0d38c81
·
1 Parent(s): 7c27dbf

Refactor app.py and update requirements.txt

Browse files

- Removed unused imports and argparse-based CLI configuration in app.py; settings now come from environment variables (SERVER_LIST, TURN_KEY_ID, TURN_KEY_API_TOKEN, CONCURRENCY_LIMIT), and Cloudflare TURN credentials are fetched directly.
- Updated gradio version in requirements.txt for compatibility.

Files changed (2) hide show
  1. app.py +71 -90
  2. requirements.txt +3 -3
app.py CHANGED
@@ -1,18 +1,14 @@
1
- import argparse
2
  import queue
 
3
  import time
4
  from threading import Thread
5
- from typing import Callable, Literal, override
6
- import os
7
 
8
  import fastrtc
9
- from fastrtc import get_cloudflare_turn_credentials_async
10
  import gradio as gr
11
  import httpx
12
  import numpy as np
13
- from pydantic import BaseModel
14
- import random
15
-
16
 
17
  from api_schema import (
18
  AbortController,
@@ -28,61 +24,66 @@ from api_schema import (
28
  )
29
 
30
  HF_TOKEN = os.getenv("HF_TOKEN")
31
- if HF_TOKEN is None:
32
- print(
33
- "⚠️ [WARNING] HF_TOKEN environment variable not found.\n"
34
- "WebRTC connections may fail on Hugging Face Spaces because TURN service cannot be used.\n"
35
- "💡 Solution: Go to your Hugging Face Space → Settings → Secrets, "
36
- "add a variable named HF_TOKEN or HF_ACCESS_TOKEN with your personal access token (with at least 'read' permission)."
37
- )
38
- else:
39
- print(" [INFO] HF_TOKEN detected. WebRTC will use Hugging Face TURN service for connectivity.")
40
-
41
-
42
- url_prefix = os.getenv("URL_PREFIX")
43
- server_number = int(os.getenv("NUM_SERVER"))
44
-
45
- deployment_server = []
46
- for i in range(1, server_number+1):
47
- url = url_prefix + str(i) + ".hf.space"
48
- deployment_server.append(url)
49
-
50
-
51
- class Args(BaseModel):
52
- host: str
53
- port: int
54
- concurrency_limit: int
55
- share: bool
56
- debug: bool
57
- chat_server: str
58
- tag: str | None = None
59
-
60
- @classmethod
61
- def parse_args(cls):
62
- parser = argparse.ArgumentParser(description="Xiaomi MiMo-Audio Chat")
63
- parser.add_argument("--host", default="0.0.0.0")
64
- parser.add_argument("--port", type=int, default=8087)
65
- parser.add_argument("--concurrency-limit", type=int, default=32)
66
- parser.add_argument("--share", action="store_true")
67
- parser.add_argument("--debug", action="store_true")
68
- parser.add_argument(
69
- "-S",
70
- "--chat-server",
71
- dest="chat_server",
72
- type=str,
73
- default="deployment_docker_1",
74
- )
75
- parser.add_argument("--tag", type=str)
76
 
77
- args = parser.parse_args()
78
- return cls.model_validate(vars(args))
79
 
80
- def chat_server_url(self):
81
- return deployment_server[random.randint(0,server_number-1)]
82
- # if self.chat_server in global_chat_server_map:
83
- # return global_chat_server_map[self.chat_server]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
- # return self.chat_server
86
 
87
  class NeverVAD(fastrtc.PauseDetectionModel):
88
  def vad(self, *_args, **_kwargs):
@@ -152,7 +153,6 @@ class ReplyOnMuted(fastrtc.ReplyOnPause):
152
  return False
153
 
154
 
155
-
156
  class ConversationManager:
157
  def __init__(self, assistant_style: AssistantStyle | None = None):
158
  self.conversation = TokenizedConversation(messages=[])
@@ -269,6 +269,7 @@ class ConversationManager:
269
  except queue.Empty:
270
  yield None
271
 
 
272
  def get_microphone_svg(muted: bool | None = None):
273
  muted_svg = '<line x1="1" y1="1" x2="23" y2="23"></line>' if muted else ""
274
  return f"""
@@ -309,8 +310,6 @@ def new_chat_id():
309
 
310
 
311
  def main():
312
- args = Args.parse_args()
313
-
314
  print("Starting WebRTC server")
315
 
316
  conversations: dict[str, ConversationManager] = {}
@@ -330,23 +329,17 @@ def main():
330
  Thread(target=cleanup_idle_conversations, daemon=True).start()
331
 
332
  def get_preset_list(category: Literal["character", "voice"]) -> list[str]:
333
- url = httpx.URL(args.chat_server_url()).join(f"/preset/{category}")
334
- headers = {
335
- "Authorization": f"Bearer {HF_TOKEN}" # <-- 加上 token
336
- }
337
  with httpx.Client() as client:
338
- response = client.get(url, headers=headers)
339
  if response.status_code == 200:
340
  return PresetOptions.model_validate_json(response.text).options
341
  return ["[default]"]
342
 
343
  def get_model_name() -> str:
344
- url = httpx.URL(args.chat_server_url()).join("/model-name")
345
- headers = {
346
- "Authorization": f"Bearer {HF_TOKEN}" # <-- 加上 token
347
- }
348
  with httpx.Client() as client:
349
- response = client.get(url, headers=headers)
350
  if response.status_code == 200:
351
  return ModelNameResponse.model_validate_json(response.text).model_name
352
  return "unknown"
@@ -354,8 +347,6 @@ def main():
354
  def load_initial_data():
355
  model_name = get_model_name()
356
  title = f"Xiaomi MiMo-Audio WebRTC (model: {model_name})"
357
- if args.tag is not None:
358
- title = f"{args.tag} - {title}"
359
  character_choices = get_preset_list("character")
360
  voice_choices = get_preset_list("voice")
361
  return (
@@ -371,12 +362,6 @@ def main():
371
  preset_voice: str | None,
372
  custom_character_prompt: str | None,
373
  ):
374
- headers = {
375
- "Authorization": f"Bearer {HF_TOKEN}" # <-- 加上 token
376
- }
377
- # deprecate gc
378
- # with httpx.Client() as client:
379
- # client.get(httpx.URL(args.chat_server_url()).join("/gc"), headers=headers)
380
  nonlocal conversations
381
 
382
  if webrtc_id not in conversations:
@@ -416,7 +401,7 @@ def main():
416
  yield additional_outputs()
417
 
418
  try:
419
- url = httpx.URL(args.chat_server_url()).join("/audio-chat")
420
  for chunk in manager.chat(
421
  url,
422
  chat_id,
@@ -463,8 +448,6 @@ def main():
463
  yield additional_outputs()
464
 
465
  title = "Xiaomi MiMo-Audio WebRTC"
466
- if args.tag is not None:
467
- title = f"{args.tag} - {title}"
468
 
469
  with gr.Blocks(title=title) as demo:
470
  title_markdown = gr.Markdown(f"# {title}")
@@ -482,9 +465,7 @@ def main():
482
  modality="audio",
483
  mode="send-receive",
484
  full_screen=False,
485
- rtc_configuration=get_cloudflare_turn_credentials_async
486
- # server_rtc_configuration=get_hf_turn_credentials(ttl=600 * 1000),
487
- # rtc_configuration=get_hf_turn_credentials,
488
  )
489
  output_text = gr.Textbox(label="Output", lines=3, interactive=False)
490
  status_text = gr.Textbox(label="Status", lines=1, interactive=False)
@@ -529,13 +510,13 @@ def main():
529
  preset_voice_dropdown,
530
  custom_character_prompt,
531
  ],
532
- concurrency_limit=args.concurrency_limit,
533
  outputs=[chat],
534
  )
535
  chat.on_additional_outputs(
536
  lambda *args: args,
537
  outputs=[output_text, status_text, collected_audio],
538
- concurrency_limit=args.concurrency_limit,
539
  show_progress="hidden",
540
  )
541
 
@@ -545,9 +526,9 @@ def main():
545
  outputs=[title_markdown, preset_character_dropdown, preset_voice_dropdown],
546
  )
547
  demo.queue(
548
- default_concurrency_limit=args.concurrency_limit,
549
  )
550
-
551
  demo.launch()
552
 
553
 
 
1
+ import os
2
  import queue
3
+ import random
4
  import time
5
  from threading import Thread
6
+ from typing import Any, Callable, Literal, override
 
7
 
8
  import fastrtc
 
9
  import gradio as gr
10
  import httpx
11
  import numpy as np
 
 
 
12
 
13
  from api_schema import (
14
  AbortController,
 
24
  )
25
 
26
  HF_TOKEN = os.getenv("HF_TOKEN")
27
+ SERVER_LIST = os.getenv("SERVER_LIST")
28
+ TURN_KEY_ID = os.getenv("TURN_KEY_ID")
29
+ TURN_KEY_API_TOKEN = os.getenv("TURN_KEY_API_TOKEN")
30
+ CONCURRENCY_LIMIT = os.getenv("CONCURRENCY_LIMIT")
31
+
32
+
33
+ assert SERVER_LIST is not None, "SERVER_LIST environment variable is required."
34
+ assert TURN_KEY_ID is not None and TURN_KEY_API_TOKEN is not None, (
35
+ "TURN_KEY_ID and TURN_KEY_API_TOKEN environment variables are required "
36
+ )
37
+
38
+ deployment_server = [
39
+ server_url.strip() for server_url in SERVER_LIST.split(",") if server_url.strip()
40
+ ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
+ assert len(deployment_server) > 0, "SERVER_LIST must contain at least one server URL."
 
43
 
44
+ default_concurrency_limit = 32
45
+ try:
46
+ concurrency_limit = (
47
+ int(CONCURRENCY_LIMIT)
48
+ if CONCURRENCY_LIMIT is not None
49
+ else default_concurrency_limit
50
+ )
51
+ except ValueError:
52
+ concurrency_limit = default_concurrency_limit
53
+
54
+
55
+ def chat_server_url(pathname: str = "/") -> httpx.URL:
56
+ n = len(deployment_server)
57
+ server_idx = random.randint(0, n - 1)
58
+ host = deployment_server[server_idx]
59
+ return httpx.URL(host).join(pathname)
60
+
61
+
62
+ def auth_headers() -> dict[str, str]:
63
+ if HF_TOKEN is None:
64
+ return {}
65
+ return {"Authorization": f"Bearer {HF_TOKEN}"}
66
+
67
+
68
+ def get_cloudflare_turn_credentials(
69
+ ttl: int = 1200, # 20 minutes
70
+ ) -> dict[str, Any]:
71
+ with httpx.Client() as client:
72
+ response = client.post(
73
+ f"https://rtc.live.cloudflare.com/v1/turn/keys/{TURN_KEY_ID}/credentials/generate-ice-servers",
74
+ headers={
75
+ "Authorization": f"Bearer {TURN_KEY_API_TOKEN}",
76
+ "Content-Type": "application/json",
77
+ },
78
+ json={"ttl": ttl},
79
+ )
80
+ if response.is_success:
81
+ return response.json()
82
+ else:
83
+ raise Exception(
84
+ f"Failed to get TURN credentials: {response.status_code} {response.text}"
85
+ )
86
 
 
87
 
88
  class NeverVAD(fastrtc.PauseDetectionModel):
89
  def vad(self, *_args, **_kwargs):
 
153
  return False
154
 
155
 
 
156
  class ConversationManager:
157
  def __init__(self, assistant_style: AssistantStyle | None = None):
158
  self.conversation = TokenizedConversation(messages=[])
 
269
  except queue.Empty:
270
  yield None
271
 
272
+
273
  def get_microphone_svg(muted: bool | None = None):
274
  muted_svg = '<line x1="1" y1="1" x2="23" y2="23"></line>' if muted else ""
275
  return f"""
 
310
 
311
 
312
  def main():
 
 
313
  print("Starting WebRTC server")
314
 
315
  conversations: dict[str, ConversationManager] = {}
 
329
  Thread(target=cleanup_idle_conversations, daemon=True).start()
330
 
331
  def get_preset_list(category: Literal["character", "voice"]) -> list[str]:
332
+ url = chat_server_url(f"/preset/{category}")
 
 
 
333
  with httpx.Client() as client:
334
+ response = client.get(url, headers=auth_headers())
335
  if response.status_code == 200:
336
  return PresetOptions.model_validate_json(response.text).options
337
  return ["[default]"]
338
 
339
  def get_model_name() -> str:
340
+ url = chat_server_url("/model-name")
 
 
 
341
  with httpx.Client() as client:
342
+ response = client.get(url, headers=auth_headers())
343
  if response.status_code == 200:
344
  return ModelNameResponse.model_validate_json(response.text).model_name
345
  return "unknown"
 
347
  def load_initial_data():
348
  model_name = get_model_name()
349
  title = f"Xiaomi MiMo-Audio WebRTC (model: {model_name})"
 
 
350
  character_choices = get_preset_list("character")
351
  voice_choices = get_preset_list("voice")
352
  return (
 
362
  preset_voice: str | None,
363
  custom_character_prompt: str | None,
364
  ):
 
 
 
 
 
 
365
  nonlocal conversations
366
 
367
  if webrtc_id not in conversations:
 
401
  yield additional_outputs()
402
 
403
  try:
404
+ url = chat_server_url("/audio-chat")
405
  for chunk in manager.chat(
406
  url,
407
  chat_id,
 
448
  yield additional_outputs()
449
 
450
  title = "Xiaomi MiMo-Audio WebRTC"
 
 
451
 
452
  with gr.Blocks(title=title) as demo:
453
  title_markdown = gr.Markdown(f"# {title}")
 
465
  modality="audio",
466
  mode="send-receive",
467
  full_screen=False,
468
+ rtc_configuration=get_cloudflare_turn_credentials,
 
 
469
  )
470
  output_text = gr.Textbox(label="Output", lines=3, interactive=False)
471
  status_text = gr.Textbox(label="Status", lines=1, interactive=False)
 
510
  preset_voice_dropdown,
511
  custom_character_prompt,
512
  ],
513
+ concurrency_limit=concurrency_limit,
514
  outputs=[chat],
515
  )
516
  chat.on_additional_outputs(
517
  lambda *args: args,
518
  outputs=[output_text, status_text, collected_audio],
519
+ concurrency_limit=concurrency_limit,
520
  show_progress="hidden",
521
  )
522
 
 
526
  outputs=[title_markdown, preset_character_dropdown, preset_voice_dropdown],
527
  )
528
  demo.queue(
529
+ default_concurrency_limit=concurrency_limit,
530
  )
531
+
532
  demo.launch()
533
 
534
 
requirements.txt CHANGED
@@ -1,5 +1,5 @@
1
  fastapi==0.116.1
2
  pydantic==2.11.7
3
- fastrtc[vad]==0.0.33
4
- gradio==5.35.0
5
- httpx==0.28.1
 
1
  fastapi==0.116.1
2
  pydantic==2.11.7
3
+ fastrtc==0.0.33
4
+ gradio==5.44.1
5
+ httpx==0.28.1