Spaces: Running on CPU Upgrade

Commit · 0d38c81
Parent(s): 7c27dbf
Refactor app.py and update requirements.txt

- Removed unused imports and refactored environment variable handling in app.py.
- Updated the gradio version in requirements.txt for compatibility.

Files changed:
- app.py: +71 -90
- requirements.txt: +3 -3
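For reviewers who want to exercise the new environment-variable-based configuration locally, here is a minimal sketch. The variable names (SERVER_LIST, TURN_KEY_ID, TURN_KEY_API_TOKEN, CONCURRENCY_LIMIT, HF_TOKEN) come from the app.py diff below; the example values and the direct `import app` / `app.main()` invocation are assumptions, not part of this commit.

```python
import os

# app.py reads these at import time (see the diff below), so set them first.
os.environ["SERVER_LIST"] = (
    "https://backend-1.example.hf.space,https://backend-2.example.hf.space"  # hypothetical backends
)
os.environ["TURN_KEY_ID"] = "your-cloudflare-turn-key-id"            # placeholder
os.environ["TURN_KEY_API_TOKEN"] = "your-cloudflare-turn-api-token"  # placeholder
os.environ["CONCURRENCY_LIMIT"] = "8"  # optional; app.py falls back to 32 if unset or not an int
# HF_TOKEN is optional: auth_headers() simply returns {} when it is missing.

import app  # assumed module name for app.py

app.main()
```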
app.py CHANGED

```diff
@@ -1,18 +1,14 @@
-import argparse
+import os
 import queue
+import random
 import time
 from threading import Thread
-from typing import Callable, Literal, override
-import os
+from typing import Any, Callable, Literal, override
 
 import fastrtc
-from fastrtc import get_cloudflare_turn_credentials_async
 import gradio as gr
 import httpx
 import numpy as np
-from pydantic import BaseModel
-import random
-
 
 from api_schema import (
     AbortController,
@@ -28,61 +24,66 @@ from api_schema import (
 )
 
 HF_TOKEN = os.getenv("HF_TOKEN")
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-deployment_server = []
-for i in range(1, server_number+1):
-    url = url_prefix + str(i) + ".hf.space"
-    deployment_server.append(url)
-
-
-class Args(BaseModel):
-    host: str
-    port: int
-    concurrency_limit: int
-    share: bool
-    debug: bool
-    chat_server: str
-    tag: str | None = None
-
-    @classmethod
-    def parse_args(cls):
-        parser = argparse.ArgumentParser(description="Xiaomi MiMo-Audio Chat")
-        parser.add_argument("--host", default="0.0.0.0")
-        parser.add_argument("--port", type=int, default=8087)
-        parser.add_argument("--concurrency-limit", type=int, default=32)
-        parser.add_argument("--share", action="store_true")
-        parser.add_argument("--debug", action="store_true")
-        parser.add_argument(
-            "-S",
-            "--chat-server",
-            dest="chat_server",
-            type=str,
-            default="deployment_docker_1",
-        )
-        parser.add_argument("--tag", type=str)
+SERVER_LIST = os.getenv("SERVER_LIST")
+TURN_KEY_ID = os.getenv("TURN_KEY_ID")
+TURN_KEY_API_TOKEN = os.getenv("TURN_KEY_API_TOKEN")
+CONCURRENCY_LIMIT = os.getenv("CONCURRENCY_LIMIT")
+
+
+assert SERVER_LIST is not None, "SERVER_LIST environment variable is required."
+assert TURN_KEY_ID is not None and TURN_KEY_API_TOKEN is not None, (
+    "TURN_KEY_ID and TURN_KEY_API_TOKEN environment variables are required "
+)
+
+deployment_server = [
+    server_url.strip() for server_url in SERVER_LIST.split(",") if server_url.strip()
+]
 
-
-        return cls.model_validate(vars(args))
+assert len(deployment_server) > 0, "SERVER_LIST must contain at least one server URL."
 
-
-
-
-
+default_concurrency_limit = 32
+try:
+    concurrency_limit = (
+        int(CONCURRENCY_LIMIT)
+        if CONCURRENCY_LIMIT is not None
+        else default_concurrency_limit
+    )
+except ValueError:
+    concurrency_limit = default_concurrency_limit
+
+
+def chat_server_url(pathname: str = "/") -> httpx.URL:
+    n = len(deployment_server)
+    server_idx = random.randint(0, n - 1)
+    host = deployment_server[server_idx]
+    return httpx.URL(host).join(pathname)
+
+
+def auth_headers() -> dict[str, str]:
+    if HF_TOKEN is None:
+        return {}
+    return {"Authorization": f"Bearer {HF_TOKEN}"}
+
+
+def get_cloudflare_turn_credentials(
+    ttl: int = 1200,  # 20 minutes
+) -> dict[str, Any]:
+    with httpx.Client() as client:
+        response = client.post(
+            f"https://rtc.live.cloudflare.com/v1/turn/keys/{TURN_KEY_ID}/credentials/generate-ice-servers",
+            headers={
+                "Authorization": f"Bearer {TURN_KEY_API_TOKEN}",
+                "Content-Type": "application/json",
+            },
+            json={"ttl": ttl},
+        )
+        if response.is_success:
+            return response.json()
+        else:
+            raise Exception(
+                f"Failed to get TURN credentials: {response.status_code} {response.text}"
+            )
 
-    # return self.chat_server
 
 class NeverVAD(fastrtc.PauseDetectionModel):
     def vad(self, *_args, **_kwargs):
@@ -152,7 +153,6 @@ class ReplyOnMuted(fastrtc.ReplyOnPause):
         return False
 
 
-
 class ConversationManager:
     def __init__(self, assistant_style: AssistantStyle | None = None):
         self.conversation = TokenizedConversation(messages=[])
@@ -269,6 +269,7 @@ class ConversationManager:
         except queue.Empty:
             yield None
 
+
 def get_microphone_svg(muted: bool | None = None):
     muted_svg = '<line x1="1" y1="1" x2="23" y2="23"></line>' if muted else ""
     return f"""
@@ -309,8 +310,6 @@ def new_chat_id():
 
 
 def main():
-    args = Args.parse_args()
-
     print("Starting WebRTC server")
 
     conversations: dict[str, ConversationManager] = {}
@@ -330,23 +329,17 @@ def main():
     Thread(target=cleanup_idle_conversations, daemon=True).start()
 
     def get_preset_list(category: Literal["character", "voice"]) -> list[str]:
-        url =
-        headers = {
-            "Authorization": f"Bearer {HF_TOKEN}"  # <-- add the token
-        }
+        url = chat_server_url(f"/preset/{category}")
         with httpx.Client() as client:
-            response = client.get(url, headers=headers)
+            response = client.get(url, headers=auth_headers())
             if response.status_code == 200:
                 return PresetOptions.model_validate_json(response.text).options
         return ["[default]"]
 
     def get_model_name() -> str:
-        url =
-        headers = {
-            "Authorization": f"Bearer {HF_TOKEN}"  # <-- add the token
-        }
+        url = chat_server_url("/model-name")
         with httpx.Client() as client:
-            response = client.get(url, headers=headers)
+            response = client.get(url, headers=auth_headers())
             if response.status_code == 200:
                 return ModelNameResponse.model_validate_json(response.text).model_name
         return "unknown"
@@ -354,8 +347,6 @@ def main():
     def load_initial_data():
         model_name = get_model_name()
         title = f"Xiaomi MiMo-Audio WebRTC (model: {model_name})"
-        if args.tag is not None:
-            title = f"{args.tag} - {title}"
         character_choices = get_preset_list("character")
         voice_choices = get_preset_list("voice")
         return (
@@ -371,12 +362,6 @@ def main():
         preset_voice: str | None,
         custom_character_prompt: str | None,
     ):
-        headers = {
-            "Authorization": f"Bearer {HF_TOKEN}"  # <-- add the token
-        }
-        # deprecate gc
-        # with httpx.Client() as client:
-        #     client.get(httpx.URL(args.chat_server_url()).join("/gc"), headers=headers)
         nonlocal conversations
 
         if webrtc_id not in conversations:
@@ -416,7 +401,7 @@ def main():
             yield additional_outputs()
 
             try:
-                url =
+                url = chat_server_url("/audio-chat")
                 for chunk in manager.chat(
                     url,
                     chat_id,
@@ -463,8 +448,6 @@ def main():
             yield additional_outputs()
 
     title = "Xiaomi MiMo-Audio WebRTC"
-    if args.tag is not None:
-        title = f"{args.tag} - {title}"
 
     with gr.Blocks(title=title) as demo:
         title_markdown = gr.Markdown(f"# {title}")
@@ -482,9 +465,7 @@ def main():
             modality="audio",
             mode="send-receive",
             full_screen=False,
-            rtc_configuration=
-            # server_rtc_configuration=get_hf_turn_credentials(ttl=600 * 1000),
-            # rtc_configuration=get_hf_turn_credentials,
+            rtc_configuration=get_cloudflare_turn_credentials,
         )
         output_text = gr.Textbox(label="Output", lines=3, interactive=False)
         status_text = gr.Textbox(label="Status", lines=1, interactive=False)
@@ -529,13 +510,13 @@ def main():
                 preset_voice_dropdown,
                 custom_character_prompt,
             ],
-            concurrency_limit=
+            concurrency_limit=concurrency_limit,
            outputs=[chat],
        )
        chat.on_additional_outputs(
            lambda *args: args,
            outputs=[output_text, status_text, collected_audio],
-            concurrency_limit=
+            concurrency_limit=concurrency_limit,
            show_progress="hidden",
        )
 
@@ -545,9 +526,9 @@ def main():
         outputs=[title_markdown, preset_character_dropdown, preset_voice_dropdown],
     )
     demo.queue(
-        default_concurrency_limit=
+        default_concurrency_limit=concurrency_limit,
     )
-
+
     demo.launch()
 
 
```
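As a quick illustration of how the new helpers introduced above compose, here is a short sketch. It assumes the module-level names from the diff (`chat_server_url`, `auth_headers`) are importable from app.py and that the backends expose the `/model-name` route used by `get_model_name`:

```python
import httpx

from app import auth_headers, chat_server_url  # assumed import path for app.py

# chat_server_url() picks a random entry of deployment_server on every call,
# so repeated requests are spread across the servers listed in SERVER_LIST.
url = chat_server_url("/model-name")

with httpx.Client() as client:
    response = client.get(url, headers=auth_headers())

print(f"{url} -> {response.status_code}")
```

Note that the diff passes the function `get_cloudflare_turn_credentials` itself (not its return value) to `rtc_configuration`, mirroring the previously commented-out `rtc_configuration=get_hf_turn_credentials` pattern, presumably so that fresh TURN credentials are generated when a connection needs them rather than once at startup.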
requirements.txt CHANGED

```diff
@@ -1,5 +1,5 @@
 fastapi==0.116.1
 pydantic==2.11.7
-fastrtc
-gradio==5.
-httpx==0.28.1
+fastrtc==0.0.33
+gradio==5.44.1
+httpx==0.28.1
```
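To confirm that the Space environment actually resolves to these pins, a small sanity check can be run at startup (a sketch; the package names and versions come from requirements.txt above):

```python
from importlib.metadata import version

pins = {
    "fastapi": "0.116.1",
    "pydantic": "2.11.7",
    "fastrtc": "0.0.33",
    "gradio": "5.44.1",
    "httpx": "0.28.1",
}

for package, expected in pins.items():
    installed = version(package)
    marker = "OK" if installed == expected else "MISMATCH"
    print(f"{package}: expected {expected}, installed {installed} [{marker}]")
```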