WhisperX-V2 / app.py
StevenChen16's picture
Update app.py
9a9ac31 verified
raw
history blame
3.37 kB
import spaces
import gradio as gr
import yt_dlp as youtube_dl
import whisperx
import tempfile
import os
import torch
import gc
# --- WhisperX configuration ---------------------------------------------
MODEL_NAME = "large-v3"       # checkpoint name handed to whisperx.load_model
compute_type = "float32"      # precision handed to whisperx.load_model
batch_size = 4                # batch size used by model.transcribe
device = "cuda"               # fixed to CUDA; no CPU fallback is attempted
YT_LENGTH_LIMIT_S = 60 * 60   # longest accepted YouTube video in seconds (1 hour)
# Build the WhisperX ASR model once at import time; `transcribe` reuses this
# single global instance for every request.
# NOTE(review): this loader is GPU-decorated yet invoked at import time;
# confirm the returned model is usable outside the decorated call on ZeroGPU.
@spaces.GPU
def load_whisperx_model():
    """Return a WhisperX model configured by MODEL_NAME, device and compute_type."""
    return whisperx.load_model(
        MODEL_NAME,
        compute_type=compute_type,
        device=device,
    )


model = load_whisperx_model()
def _free_gpu_memory():
    """Run the garbage collector and release PyTorch's unused cached CUDA memory."""
    gc.collect()
    torch.cuda.empty_cache()


@spaces.GPU
def transcribe(inputs, task):
    """Transcribe, word-align and speaker-diarize an audio file with WhisperX.

    Args:
        inputs: Path of the audio file to process (passed to
            ``whisperx.load_audio``).
        task: Accepted for interface compatibility; currently unused.

    Returns:
        str: One line per segment, formatted ``"<speaker>: <text>\\n"``.
        Segments without a speaker label are attributed to ``"Unknown"``.

    Raises:
        gr.Error: If no audio file was submitted.
    """
    if inputs is None:
        raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")

    # 1. Speech recognition with the module-level model. That model is shared
    #    by every request, so it must not be deleted here. (A `del model`
    #    anywhere in this function would also turn `model` into a local name
    #    and make the call below raise UnboundLocalError.)
    audio = whisperx.load_audio(inputs)
    result = model.transcribe(audio, batch_size=batch_size)
    print(result["segments"])  # unaligned segments, for debugging

    # 2. Word-level alignment with a language-specific model. It is loaded
    #    on every call, so release it as soon as alignment is done.
    model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
    result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
    del model_a
    _free_gpu_memory()

    # 3. Speaker diarization. The pipeline must be *run* on the audio; its
    #    output (not the pipeline object) is what assign_word_speakers expects.
    #    The Hub token comes from the HF_TOKEN environment variable (e.g. a
    #    Space secret) instead of being hard-coded in the source.
    diarize_model = whisperx.DiarizationPipeline(use_auth_token=os.environ.get("HF_TOKEN"), device=device)
    diarize_segments = diarize_model(audio)
    result = whisperx.assign_word_speakers(diarize_segments, result)
    del diarize_model
    _free_gpu_memory()

    # 4. Format as "speaker: text" lines.
    return "".join(
        f"{segment.get('speaker', 'Unknown')}: {segment['text']}\n"
        for segment in result["segments"]
    )
def _return_yt_html_embed(yt_url):
video_id = yt_url.split("?v=")[-1]
HTML_str = (
f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe>'
" </center>"
)
return HTML_str
def download_yt_audio(yt_url, filename):
    """Download the m4a audio track of a YouTube video to ``filename``.

    The video's metadata is fetched first so that videos longer than
    ``YT_LENGTH_LIMIT_S`` are rejected before anything is downloaded.

    Raises:
        gr.Error: If metadata extraction or the download fails, if the video
            length cannot be determined, or if it exceeds the limit.
    """
    # Use a context manager so the probing YoutubeDL instance is closed.
    with youtube_dl.YoutubeDL() as info_loader:
        try:
            info = info_loader.extract_info(yt_url, download=False)
        except youtube_dl.utils.DownloadError as err:
            raise gr.Error(str(err)) from err

    # "duration" may be missing or None, which would make the comparison
    # below raise TypeError; reject such videos explicitly instead.
    file_length = info.get("duration")
    if file_length is None:
        raise gr.Error("Could not determine the YouTube video length.")
    if file_length > YT_LENGTH_LIMIT_S:
        raise gr.Error("YouTube video length exceeds the 1-hour limit.")

    ydl_opts = {"outtmpl": filename, "format": "bestaudio[ext=m4a]"}
    with youtube_dl.YoutubeDL(ydl_opts) as ydl:
        try:
            ydl.download([yt_url])
        # Download failures surface as DownloadError; ExtractorError is kept
        # from the original handler.
        except (youtube_dl.utils.DownloadError, youtube_dl.utils.ExtractorError) as err:
            raise gr.Error(str(err)) from err
def yt_transcribe(yt_url, task="transcribe"):
    """Download a YouTube video's audio and return (embed HTML, transcript).

    Args:
        yt_url: URL of the YouTube video.
        task: Forwarded to ``transcribe``. It has a default because the Gradio
            interface supplies only the URL input; without a default, every
            call from the UI raised TypeError.

    Returns:
        tuple[str, str]: The HTML embed snippet and the diarized transcript.

    Raises:
        gr.Error: If no URL was given, or if downloading/transcribing fails.
    """
    if not yt_url or not yt_url.strip():
        raise gr.Error("No YouTube URL submitted! Please paste a YouTube URL before submitting your request.")

    html_embed_str = _return_yt_html_embed(yt_url)
    # The audio only needs to live while it is being transcribed.
    with tempfile.TemporaryDirectory() as tmpdirname:
        filepath = os.path.join(tmpdirname, "video.m4a")
        download_yt_audio(yt_url, filepath)
        transcript = transcribe(filepath, task)
    return html_embed_str, transcript
# Gradio UI setup.
# A single "YouTube" tab: a URL textbox in, the embedded player (HTML) and the
# transcript (text) out, both produced by yt_transcribe.
demo = gr.Blocks()
yt_transcribe_interface = gr.Interface(
    fn=yt_transcribe,
    # Only one input component, so yt_transcribe receives just the URL.
    # NOTE(review): yt_transcribe must therefore be callable with a single
    # argument (its `task` parameter needs a default value).
    inputs=[gr.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL")],
    # Matches the (html_embed_str, transcript) tuple returned by yt_transcribe.
    outputs=["html", "text"],
    title="WhisperX: Transcribe YouTube with Speaker Diarization",
    description="Transcribe and diarize YouTube videos with WhisperX."
)
with demo:
    gr.TabbedInterface([yt_transcribe_interface], ["YouTube"])
# Runs unconditionally at module level (no __main__ guard): executing this
# file starts the Gradio server.
demo.launch()