|
"""Real-time speech-to-text: FastRTC streams microphone audio into a Whisper
pipeline and pushes transcripts back to the browser via server-sent events."""

import os
import logging
import json
from functools import partial

import torch
import gradio as gr
import numpy as np
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.responses import StreamingResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastrtc import (
    AdditionalOutputs,
    ReplyOnPause,
    Stream,
    AlgoOptions,
    SileroVadOptions,
    audio_to_bytes,
    get_cloudflare_turn_credentials_async,
)
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    pipeline,
)
from transformers.utils import is_flash_attn_2_available

from utils.logger_config import setup_logging
from utils.device import get_device, get_torch_and_np_dtypes

load_dotenv()
setup_logging()
logger = logging.getLogger(__name__)

# Runtime configuration, overridable via environment variables (.env is loaded above).
UI_MODE = os.getenv("UI_MODE", "fastapi").lower()    # "fastapi" (custom HTML) or "gradio"
UI_TYPE = os.getenv("UI_TYPE", "base").lower()       # "base" or "screen" static frontend
APP_MODE = os.getenv("APP_MODE", "local").lower()    # "deployed" enables Cloudflare TURN
MODEL_ID = os.getenv("MODEL_ID", "openai/whisper-large-v3-turbo")
LANGUAGE = os.getenv("LANGUAGE", "english").lower()

device = get_device(force_cpu=False)
torch_dtype, np_dtype = get_torch_and_np_dtypes(device, use_bfloat16=False)
logger.info(f"Using device: {device}, torch_dtype: {torch_dtype}, np_dtype: {np_dtype}")

# Use Flash Attention 2 when the package is installed; otherwise fall back to PyTorch SDPA.
attention = "flash_attention_2" if is_flash_attn_2_available() else "sdpa"
logger.info(f"Using attention: {attention}")

logger.info(f"Loading Whisper model: {MODEL_ID}")
logger.info(f"Using language: {LANGUAGE}")

try:
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        MODEL_ID,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
        attn_implementation=attention,
    )
    model.to(device)
except Exception as e:
    logger.error(f"Error loading ASR model: {e}")
    logger.error(f"Are you providing a valid model ID? {MODEL_ID}")
    raise

processor = AutoProcessor.from_pretrained(MODEL_ID)
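
# The processor bundles the tokenizer and feature extractor that the pipeline below reuses.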
|
|
|
transcribe_pipeline = pipeline(
    task="automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch_dtype,
    device=device,
)

# Compile the model on GPU backends for faster steady-state inference.
if device in ["cuda", "mps"]:
    transcribe_pipeline.model = torch.compile(transcribe_pipeline.model, mode="max-autotune")

logger.info("Warming up Whisper model with dummy input")
warmup_audio = np.zeros((16000,), dtype=np_dtype)  # one second of silence at 16 kHz
transcribe_pipeline(warmup_audio)
logger.info("Model warmup complete")
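
# The warmup should also absorb most of torch.compile's one-time compilation cost,
# though variable-length inputs may still trigger some recompilation later.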
|
|
|
async def transcribe(audio: tuple[int, np.ndarray]):
    """Transcribe one finished utterance and emit the text as an additional output."""
    sample_rate, audio_array = audio
    logger.info(f"Sample rate: {sample_rate}Hz, Shape: {audio_array.shape}")

    outputs = transcribe_pipeline(
        audio_to_bytes(audio),
        chunk_length_s=3,
        batch_size=2,
        generate_kwargs={
            'task': 'transcribe',
            'language': LANGUAGE,
        },
    )
    yield AdditionalOutputs(outputs["text"].strip())
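
# ReplyOnPause buffers microphone audio and invokes `transcribe` whenever the Silero
# VAD decides the speaker has paused; the options below tune that pause detection.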
|
|
|
logger.info("Initializing FastRTC stream")
stream = Stream(
    handler=ReplyOnPause(
        transcribe,
        algo_options=AlgoOptions(
            # Duration in seconds of the audio chunks passed to the VAD model.
            audio_chunk_duration=0.6,
            # Seconds of detected speech in a chunk before the user counts as talking.
            started_talking_threshold=0.1,
            # Once talking, a chunk with fewer seconds of speech than this counts as a pause.
            speech_threshold=0.1,
        ),
        model_options=SileroVadOptions(
            # Probability above which a frame is classified as speech.
            threshold=0.5,
            # Final speech segments shorter than this are discarded.
            min_speech_duration_ms=250,
            # Segments longer than this are split, forcing a transcription pass.
            max_speech_duration_s=3,
            # Silence of at least this length ends the current speech segment.
            min_silence_duration_ms=100,
            # Samples per VAD window; 512 is the recommended size for 16 kHz audio.
            window_size_samples=512,
            # Padding added to both sides of each detected speech segment.
            speech_pad_ms=200,
        ),
    ),
    modality="audio",
    mode="send",
    additional_outputs=[
        gr.Textbox(label="Transcript"),
    ],
    additional_outputs_handler=lambda current, new: current + " " + new,
    # Pass the credential fetcher as a callable (via partial) instead of calling it here,
    # so fresh Cloudflare TURN credentials can be fetched for each connection.
    rtc_configuration=partial(get_cloudflare_turn_credentials_async, hf_token=os.getenv("HF_TOKEN")) if APP_MODE == "deployed" else None,
    concurrency_limit=6,
)

app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")
stream.mount(app)
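
# stream.mount(app) registers fastrtc's WebRTC signalling endpoints on the FastAPI app,
# alongside the custom routes defined below.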
|
|
|
@app.get("/")
async def index():
    # Serve the requested frontend, falling back to the base page for unknown UI_TYPE values.
    if UI_TYPE == "screen":
        html_content = open("static/index-screen.html").read()
    else:
        html_content = open("static/index.html").read()

    rtc_configuration = await get_cloudflare_turn_credentials_async(hf_token=os.getenv("HF_TOKEN")) if APP_MODE == "deployed" else None
    logger.info(f"RTC configuration: {rtc_configuration}")
    html_content = html_content.replace("__INJECTED_RTC_CONFIG__", json.dumps(rtc_configuration))
    return HTMLResponse(content=html_content)
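
# The static pages contain a literal __INJECTED_RTC_CONFIG__ placeholder that index()
# replaces with the JSON RTC configuration (null when running locally).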
|
|
|
@app.get("/transcript")
def _(webrtc_id: str):
    logger.debug(f"New transcript stream request for webrtc_id: {webrtc_id}")

    async def output_stream():
        try:
            async for output in stream.output_stream(webrtc_id):
                transcript = output.args[0]
                logger.debug(f"Sending transcript for {webrtc_id}: {transcript[:50]}...")
                yield f"event: output\ndata: {transcript}\n\n"
        except Exception as e:
            logger.error(f"Error in transcript stream for {webrtc_id}: {str(e)}")
            raise

    return StreamingResponse(output_stream(), media_type="text/event-stream")
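
# Each yielded "event: output" block is a server-sent event; a browser can subscribe
# with EventSource("/transcript?webrtc_id=...") to append transcript fragments as they arrive.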
|
|
|
if __name__ == "__main__":
    if UI_MODE == "gradio":
        stream.ui.launch(server_port=7860)
    else:
        import uvicorn

        uvicorn.run(app, host="0.0.0.0", port=7860)
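
# Example invocations (the file name app.py is assumed):
#   python app.py                     # FastAPI server with the static HTML UI on :7860
#   UI_MODE=gradio python app.py      # Gradio UI on :7860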