# app.py
import gradio as gr
import whisper
import json
import shutil
import os
import uuid
from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration, pipeline
from opencc import OpenCC
# === Model globals (lazy-loaded) ===
whisper_model = None
m2m_model = None
m2m_tokenizer = None
m2m_model_name = "facebook/m2m100_418M"
cc = OpenCC("s2t")  # Simplified-to-Traditional Chinese converter

# ✅ A stable, readily available Chinese text-refinement model (loaded eagerly at startup)
refiner = pipeline(
    "text2text-generation",
    model="uer/pegasus-base-chinese-cluecorpussmall"
)
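# A "text2text-generation" pipeline returns a list of dicts; a minimal sketch of
# how `refiner` is invoked later (illustrative input, output text depends on the model):
#   refiner("請潤飾這段文字", max_length=512)  # -> [{"generated_text": "..."}]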
# === Language lookup tables (UI label -> ISO 639-1 code; None = auto-detect) ===
lang_map = {
    "Auto-detect": None,
    "Chinese": "zh",
    "English": "en",
    "Japanese": "ja",
    "French": "fr",
    "Spanish": "es",
    "German": "de",
    "Italian": "it",
    "Portuguese": "pt"
}
target_langs = {
    "Traditional Chinese": "zh",
    "English": "en",
    "Japanese": "ja",
    "French": "fr",
    "Spanish": "es",
    "German": "de",
    "Italian": "it",
    "Portuguese": "pt"
}
def lazy_load_models():
    global whisper_model, m2m_model, m2m_tokenizer
    if whisper_model is None:
        whisper_model = whisper.load_model("medium")
    if m2m_model is None:
        m2m_model = M2M100ForConditionalGeneration.from_pretrained(m2m_model_name)
    if m2m_tokenizer is None:
        m2m_tokenizer = M2M100Tokenizer.from_pretrained(m2m_model_name)
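# Design note: lazy loading keeps Space startup fast. Whisper "medium" (~1.4 GB)
# and M2M-100 418M are only downloaded and loaded into memory on the first
# transcription request, not at import time.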
def get_lang_label(code):
    # Map an ISO code back to its dropdown label
    return next((label for label, c in lang_map.items() if c == code), "Unknown")

def format_timestamp(seconds):
    # Format seconds as an SRT timestamp: HH:MM:SS,mmm
    return f"{int(seconds//3600):02}:{int((seconds%3600)//60):02}:{int(seconds%60):02},{int((seconds-int(seconds))*1000):03}"

def break_line(text, max_len=40):
    # Hard-wrap subtitle text every max_len characters
    return '\n'.join([text[i:i+max_len] for i in range(0, len(text), max_len)])
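# Worked examples of the two helpers above:
#   format_timestamp(3661.5)         -> "01:01:01,500"
#   break_line("a" * 90, max_len=40) -> chunks of 40, 40, and 10 chars on three lines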
def export_files(text, translation, lang, segments, uid):
    txt_path = f"transcript_{uid}.txt"
    json_path = f"transcript_{uid}.json"
    srt_path = f"transcript_{uid}.srt"
    with open(txt_path, "w", encoding="utf-8") as f:
        f.write(f"Language: {lang}\n\nOriginal:\n{text}\n\nTranslation:\n{translation}")
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump({
            "language": lang,
            "transcript": text,
            "translation": translation,
            "segments": segments
        }, f, ensure_ascii=False, indent=2)
    with open(srt_path, "w", encoding="utf-8") as f:
        for i, seg in enumerate(segments):
            start = format_timestamp(seg["start"])
            end = format_timestamp(seg["end"])
            f.write(f"{i+1}\n{start} --> {end}\n{break_line(seg['text'])}\n\n")
    return txt_path, json_path, srt_path
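# Each SRT entry written above has the standard three-part shape, e.g.:
#   1
#   00:00:00,000 --> 00:00:04,500
#   Hello world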
def translate_text(text, detected_lang, target_lang_label):
    try:
        # Fall back to English if Whisper detects a language not configured here
        src_lang = detected_lang if detected_lang in target_langs.values() else "en"
        tgt_lang = target_langs.get(target_lang_label, "zh")
        m2m_tokenizer.src_lang = src_lang
        # Truncate to the model's context window so long transcripts don't raise errors
        encoded = m2m_tokenizer(text, return_tensors="pt", truncation=True)
        generated = m2m_model.generate(
            **encoded,
            forced_bos_token_id=m2m_tokenizer.get_lang_id(tgt_lang)
        )
        translated = m2m_tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
        # Convert Simplified output to Traditional when the target is Chinese
        return cc.convert(translated) if tgt_lang == "zh" else translated
    except Exception as e:
        return f"(⚠️ Translation failed: {str(e)})"
# === Session memory ===
last_uid = ""
last_original_text = ""

def refine_translation_from_original():
    global last_original_text
    if not last_original_text.strip():
        return "⚠️ No transcript available to refine yet"
    # The prompt stays in Chinese because the refiner is a Chinese-only model.
    # It reads: "Polish the following into more fluent, natural Chinese without
    # changing its meaning:"
    prompt = f"請將以下內容在不改變原來意思之下,潤飾為更通順自然的中文:\n{last_original_text}"
    try:
        result = refiner(prompt, max_length=512, do_sample=False)
        return result[0]["generated_text"]
    except Exception as e:
        return f"(⚠️ Refinement error: {str(e)})"
def transcribe_and_translate(audio_path, lang_label, target_lang_label):
    global last_uid, last_original_text
    lazy_load_models()
    if not audio_path or not os.path.isfile(audio_path):
        return "⚠️ Please record or upload an audio file first", "", "", None, None, None, None
    ext_allowed = ['.wav', '.mp3', '.m4a']
    if not any(audio_path.lower().endswith(ext) for ext in ext_allowed):
        return "⚠️ Only wav, mp3, and m4a audio files are supported", "", "", None, None, None, None
    uid = uuid.uuid4().hex[:8]
    last_uid = uid
    lang_code = lang_map.get(lang_label)
    result = whisper_model.transcribe(audio_path, language=lang_code)
    text = result["text"]
    last_original_text = text
    detected_lang = result["language"]
    segments = result.get("segments", [])
    translation = translate_text(text, detected_lang, target_lang_label)
    txt, jsonf, srt = export_files(text, translation, detected_lang, segments, uid)
    audio_filename = f"audio_{uid}.wav"
    shutil.copy(audio_path, audio_filename)
    return text, get_lang_label(detected_lang), translation, txt, jsonf, srt, audio_filename
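# The 7-tuple returned above must stay aligned, in order, with the `outputs=`
# list wired to start_btn.click in the UI section below.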
def delete_current_session_files():
    global last_uid
    if not last_uid:
        return "⚠️ No files have been generated yet"
    deleted = []
    for suffix in [".txt", ".json", ".srt"]:
        path = f"transcript_{last_uid}{suffix}"
        if os.path.exists(path):
            os.remove(path)
            deleted.append(path)
    audio_path = f"audio_{last_uid}.wav"
    if os.path.exists(audio_path):
        os.remove(audio_path)
        deleted.append(audio_path)
    return f"✅ Deleted {len(deleted)} file(s)"
# === Gradio UI ===
with gr.Blocks() as demo:
    gr.Markdown("## 🎤 Whisper + multilingual translation + Chinese refinement")
    recording_ready = gr.State(False)
    with gr.Row():
        audio_input = gr.Audio(label="🎙️ Upload or record audio", type="filepath")
    with gr.Row():
        lang_dropdown = gr.Dropdown(label="Speech language (auto-detect available)", choices=list(lang_map.keys()), value="Auto-detect")
        target_lang_dropdown = gr.Dropdown(label="Target translation language", choices=list(target_langs.keys()), value="Traditional Chinese")
    start_btn = gr.Button("🚀 Start transcription & translation", interactive=False)
    original_text = gr.Textbox(label="📝 Transcript", lines=12)
    detected_lang = gr.Textbox(label="🌐 Detected language")
    translated_text = gr.Textbox(label="🌸 Translation", lines=8)
    refined_text = gr.Textbox(label="🌟 Refined text", lines=8)
    file_txt = gr.File(label="📄 TXT")
    file_json = gr.File(label="📄 JSON")
    file_srt = gr.File(label="🎬 SRT subtitles")
    file_audio = gr.File(label="🔊 Original audio download")
    refine_btn = gr.Button("✨ Refine transcript")
    clear_btn = gr.Button("🧹 Delete files from this session")
    clear_result = gr.Textbox(label="🧾 System message")

    def audio_uploaded(path):
        # Only enable the start button when an audio file is actually present;
        # clearing the audio widget fires this callback with None
        ready = path is not None
        return gr.update(interactive=ready), ready

    audio_input.change(fn=audio_uploaded, inputs=[audio_input], outputs=[start_btn, recording_ready])
    start_btn.click(fn=transcribe_and_translate,
                    inputs=[audio_input, lang_dropdown, target_lang_dropdown],
                    outputs=[original_text, detected_lang, translated_text,
                             file_txt, file_json, file_srt, file_audio])
    refine_btn.click(fn=refine_translation_from_original, inputs=[], outputs=[refined_text])
    clear_btn.click(fn=delete_current_session_files, inputs=[], outputs=[clear_result])

demo.launch()
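# For local debugging you could instead launch with a public share link
# (standard Gradio option; not needed on Hugging Face Spaces):
#   demo.launch(share=True)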