littlebird13 commited on
Commit
d641c8e
·
verified ·
1 Parent(s): 7448529

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +279 -0
app.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Simultaneous speech translation over FastRTC using DashScope Qwen3 LiveTranslate.
2
+ - Streams mic audio (16k PCM16) to DashScope Realtime
3
+ - Receives translated text deltas and 24k PCM16 TTS audio
4
+ - Plays audio via FastRTC and shows text in a Gradio Chatbot
5
+ Set DASHSCOPE_API_KEY in the environment before running.
6
+ """
7
+ import os
8
+ import time
9
+ import base64
10
+ import asyncio
11
+ import json
12
+ import secrets
13
+ import signal
14
+ from pathlib import Path
15
+
16
+ import gradio as gr
17
+ import numpy as np
18
+ from dotenv import load_dotenv
19
+ from fastapi import FastAPI
20
+ from fastapi.responses import HTMLResponse, StreamingResponse
21
+ from fastrtc import (
22
+ AdditionalOutputs,
23
+ AsyncStreamHandler,
24
+ Stream,
25
+ get_cloudflare_turn_credentials_async,
26
+ wait_for_item,
27
+ )
28
+ from gradio.utils import get_space
29
+ from websockets.asyncio.client import connect
30
+
31
load_dotenv()

# Directory containing this file; used to locate index.html at request time.
cur_dir = Path(__file__).parent

# Prefer the documented DASHSCOPE_API_KEY; fall back to the legacy API_KEY name
# so existing deployments keep working. Using os.getenv (not os.environ[...])
# means a missing key reaches the explicit RuntimeError below instead of an
# opaque KeyError at import time.
API_KEY = os.getenv("DASHSCOPE_API_KEY") or os.getenv("API_KEY")
API_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime?model=qwen3-livetranslate-flash-realtime"
VOICES = ["Cherry", "Nofish", "Jada", "Dylan", "Sunny", "Peter", "Kiki", "Eric"]

if not API_KEY:
    raise RuntimeError("Missing DASHSCOPE_API_KEY environment variable.")
headers = {"Authorization": "Bearer " + API_KEY}

# ISO-ish language code -> human-readable name shown in the dropdowns.
LANG_MAP = {
    "en": "English",
    "zh": "Chinese",
    "ru": "Russian",
    "fr": "French",
    "de": "German",
    "pt": "Portuguese",
    "es": "Spanish",
    "it": "Italian",
    "ko": "Korean",
    "ja": "Japanese",
    "yue": "Cantonese",
    "id": "Indonesian",
    "vi": "Vietnamese",
    "th": "Thai",
    "ar": "Arabic",
    "hi": "Hindi",
    "el": "Greek",
    "tr": "Turkish"
}
LANG_MAP_REVERSE = {v: k for k, v in LANG_MAP.items()}

# Dropdowns display full names; codes are recovered via LANG_MAP_REVERSE.
# Note: the target set is a subset of the source set (no hi/el/tr targets).
SRC_LANGUAGES = [LANG_MAP[code] for code in ["en", "zh", "ru", "fr", "de", "pt", "es", "it", "ko", "ja", "yue", "id", "vi", "th", "ar", "hi", "el", "tr"]]
TARGET_LANGUAGES = [LANG_MAP[code] for code in ["en", "zh", "ru", "fr", "de", "pt", "es", "it", "ko", "ja", "yue", "id", "vi", "th", "ar"]]
68
+
69
+
70
class LiveTranslateHandler(AsyncStreamHandler):
    """FastRTC stream handler bridging browser audio to DashScope LiveTranslate.

    Sends 16 kHz mono PCM16 microphone frames to the realtime websocket and
    queues back translated transcript text (as ``AdditionalOutputs``) plus
    24 kHz PCM16 TTS audio tuples for FastRTC playback.
    """

    def __init__(self) -> None:
        super().__init__(
            expected_layout="mono",
            output_sample_rate=24_000,  # DashScope TTS emits 24 kHz PCM16
            input_sample_rate=16_000,   # mic audio is resampled to 16 kHz by FastRTC
        )
        self.connection = None              # active websocket; set in start_up()
        self.output_queue = asyncio.Queue() # holds (rate, ndarray) audio and AdditionalOutputs text

    def copy(self):
        # FastRTC clones one handler instance per peer connection.
        return LiveTranslateHandler()

    @staticmethod
    def msg_id() -> str:
        """Return a fresh random event id for the realtime protocol."""
        return f"event_{secrets.token_hex(10)}"

    async def start_up(self):
        """Open the DashScope session and pump events until the socket closes."""
        try:
            await self.wait_for_args()
            args = self.latest_args
            # NOTE(review): indices 2/3/4 assume a particular layout of
            # latest_args relative to additional_inputs=[src_language,
            # language, voice, chatbot] — confirm against fastrtc's
            # AsyncStreamHandler.latest_args ordering.
            src_language_name = args[2] if len(args) > 2 else "Chinese"  # dropdowns return full names
            target_language_name = args[3] if len(args) > 3 else "English"
            src_language_code = LANG_MAP_REVERSE[src_language_name]
            target_language_code = LANG_MAP_REVERSE[target_language_name]

            voice_id = args[4] if len(args) > 4 else "Cherry"

            if src_language_code == target_language_code:
                # Same source and target: the service will effectively repeat the input.
                print(f"⚠️ 源语言和目标语言相同({target_language_name}),将以复述模式运行")

            async with connect(API_URL, additional_headers=headers) as conn:
                self.client = conn
                # Configure the session: text+audio output, PCM16 both ways,
                # with explicit source and target translation languages.
                await conn.send(
                    json.dumps(
                        {
                            "event_id": self.msg_id(),
                            "type": "session.update",
                            "session": {
                                "modalities": ["text", "audio"],
                                "voice": voice_id,
                                "input_audio_format": "pcm16",
                                "output_audio_format": "pcm16",
                                "translation": {
                                    "source_language": src_language_code,
                                    "language": target_language_code
                                }
                            },
                        }
                    )
                )
                self.connection = conn

                # Event pump: translate server events into queue items.
                async for data in self.connection:
                    event = json.loads(data)
                    if "type" not in event:
                        continue
                    event_type = event["type"]

                    if event_type == "response.audio_transcript.delta":
                        # Incremental caption text.
                        text = event.get("transcript", "")
                        if text:
                            await self.output_queue.put(
                                AdditionalOutputs({"role": "assistant", "content": text})
                            )

                    # elif event_type in ("response.text.text", "response.audio_transcript.text"):
                    #     # Intermediate result + stash (stash is usually the full-sentence cache)
                    #     stash_text = event.get("stash", "")
                    #     text_field = event.get("text", "")
                    #     if stash_text or text_field:
                    #         await self.output_queue.put(
                    #             AdditionalOutputs({"role": "assistant", "content": stash_text or text_field})
                    #         )

                    elif event_type == "response.audio_transcript.done":
                        # Final complete sentence.
                        transcript = event.get("transcript", "")
                        if transcript:
                            await self.output_queue.put(
                                AdditionalOutputs({"role": "assistant", "content": transcript})
                            )

                    elif event_type == "response.audio.delta":
                        # Base64 PCM16 TTS chunk -> (rate, 1xN int16 array) for playback.
                        audio_b64 = event.get("delta", "")
                        if audio_b64:
                            audio_data = base64.b64decode(audio_b64)
                            audio_array = np.frombuffer(audio_data, dtype=np.int16).reshape(1, -1)
                            await self.output_queue.put(
                                (self.output_sample_rate, audio_array)
                            )

        except Exception as e:
            # Best-effort teardown on any connection/protocol failure.
            print(f"Connection error: {e}")
            await self.shutdown()

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Forward one mic frame to DashScope as a base64 input_audio_buffer.append."""
        if not self.connection:
            return  # session not established yet; drop the frame
        _, array = frame
        array = array.squeeze()
        audio_message = base64.b64encode(array.tobytes()).decode("utf-8")
        await self.connection.send(
            json.dumps(
                {
                    "event_id": self.msg_id(),
                    "type": "input_audio_buffer.append",
                    "audio": audio_message,
                }
            )
        )

    async def emit(self) -> tuple[int, np.ndarray] | AdditionalOutputs | None:
        """Hand FastRTC the next queued audio chunk or text output."""
        return await wait_for_item(self.output_queue)

    async def shutdown(self) -> None:
        """Close the websocket (if open) and drain any pending outputs."""
        if self.connection:
            await self.connection.close()
            self.connection = None

        # Drain the queue so stale items don't leak into a new session.
        while not self.output_queue.empty():
            self.output_queue.get_nowait()
196
+
197
+
198
def update_chatbot(chatbot: list[dict], response: dict):
    """Append the newest assistant message to the chat history in place.

    Returns the same (mutated) list so Gradio can re-render the Chatbot.
    """
    chatbot += [response]  # in-place extend; identical effect to append()
    return chatbot
201
+
202
+
203
# Module-level Gradio components wired into the FastRTC Stream below.
chatbot = gr.Chatbot(type="messages")
src_language = gr.Dropdown(
    choices=SRC_LANGUAGES,
    value="Chinese",  # full language name, mapped back to a code in the handler
    type="value",
    label="Source Language"
)
language = gr.Dropdown(
    choices=TARGET_LANGUAGES,
    value="English",  # full language name, mapped back to a code in the handler
    type="value",
    label="Target Language"
)
voice = gr.Dropdown(choices=VOICES, value=VOICES[0], type="value", label="Voice")
latest_message = gr.Textbox(type="text", visible=False)

# Optional: TURN credentials only when running on a HF Space; set to None
# locally (direct connection) or to disable TURN for testing.
rtc_config = get_cloudflare_turn_credentials_async if get_space() else None
# rtc_config = None  # uncomment to disable TURN for testing

stream = Stream(
    LiveTranslateHandler(),
    mode="send-receive",   # bidirectional audio
    modality="audio",
    additional_inputs=[src_language, language, voice, chatbot],  # consumed via latest_args in start_up()
    additional_outputs=[chatbot],
    additional_outputs_handler=update_chatbot,
    rtc_configuration=rtc_config,
    concurrency_limit=5 if get_space() else None,  # resource caps only on Spaces
    time_limit=90 if get_space() else None,
)
234
+
235
+
236
# FastAPI app hosting the custom frontend; FastRTC mounts its signaling routes.
app = FastAPI()

stream.mount(app)
239
+
240
+
241
@app.get("/")
async def _():
    """Serve index.html with the RTC configuration injected as JSON."""
    # On Spaces, fetch fresh TURN credentials per request; locally use null.
    rtc_config = await get_cloudflare_turn_credentials_async() if get_space() else None
    html_content = (cur_dir / "index.html").read_text()
    html_content = html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
    return HTMLResponse(content=html_content)
247
+
248
+
249
@app.get("/outputs")
def _(webrtc_id: str):
    """SSE endpoint streaming translated-text outputs for one WebRTC session.

    Each AdditionalOutputs payload queued by the handler is forwarded as a
    single ``output`` server-sent event.
    """
    async def output_stream():
        # json is already imported at module level; the previous nested
        # `import json` was redundant shadowing and has been removed.
        async for output in stream.output_stream(webrtc_id):
            s = json.dumps(output.args[0])
            yield f"event: output\ndata: {s}\n\n"

    return StreamingResponse(output_stream(), media_type="text/event-stream")
259
+
260
+
261
def handle_exit(sig, frame):
    """Signal-handler hook for SIGINT/SIGTERM: currently log-only."""
    print("Shutting down gracefully...")
    # NOTE(review): this handler neither exits nor closes the websocket. If it
    # replaces uvicorn's own SIGINT handling, Ctrl+C will no longer stop the
    # process — confirm intended. Extend here for real cleanup logic.


signal.signal(signal.SIGINT, handle_exit)
signal.signal(signal.SIGTERM, handle_exit)
268
+
269
if __name__ == "__main__":
    # MODE selects the serving frontend:
    #   "UI"    -> FastRTC's built-in Gradio UI
    #   "PHONE" -> FastRTC phone bridge
    #   other   -> the full FastAPI app (custom index.html) under uvicorn
    # (os is already imported at module level; the redundant re-import was removed.)
    if (mode := os.getenv("MODE")) == "UI":
        stream.ui.launch(server_port=7860)
    elif mode == "PHONE":
        stream.fastphone(host="0.0.0.0", port=7860)
    else:
        import uvicorn

        uvicorn.run(app, host="0.0.0.0", port=7860)