# Refactor pyannote token handling (commit 5964a77)
import time
import uuid
import warnings
from collections import OrderedDict
from typing import Any
import gradio as gr
import spaces
from src.constants import PARAKEET_V3
from src.diarization_service import run_chunked_diarization
from src.merge_service import merge_parakeet_pyannote_outputs
from src.models.parakeet_model import preload_parakeet_model, run_parakeet
from src.models.pyannote_community_model import preload_pyannote_pipeline, run_pyannote_community_chunk
from src.utils import get_audio_duration_seconds
# Suppress a known deprecation warning emitted by a transitive dependency in spaces.
warnings.filterwarnings(
    "ignore",
    message=r"`torch\.distributed\.reduce_op` is deprecated, please use `torch\.distributed\.ReduceOp` instead",
    category=FutureWarning,
)
# Model preload failures keyed by model label; surfaced lazily per request
# via _raise_preload_error_if_any rather than crashing at import time.
_PRELOAD_ERRORS: dict[str, str] = {}
# Debug payloads keyed by run_id, in insertion order so eviction drops the oldest run.
_DEBUG_RUNS: "OrderedDict[str, dict[str, Any]]" = OrderedDict()
# Maximum number of debug runs kept in memory at once.
_MAX_DEBUG_RUNS = 10
# run_id of the most recent pipeline run; default target for /get_debug_output.
_LAST_DEBUG_RUN_ID: str | None = None
# JSON-schema-style description of what run_complete_pipeline returns.
# Served to API clients via /get_run_complete_pipeline_schema; keep in sync
# with the output of merge_parakeet_pyannote_outputs.
RUN_COMPLETE_PIPELINE_OUTPUT_SCHEMA: dict[str, Any] = {
    "type": "object",
    "description": "Merged transcript output from Parakeet (word timestamps) and Pyannote diarization.",
    "properties": {
        "summary": {
            "type": "object",
            "properties": {
                "diarization_key_used": {"type": "string", "example": "exclusive_speaker_diarization"},
                "parakeet_word_count": {"type": "integer"},
                "pyannote_segment_count": {"type": "integer"},
                "turn_count": {"type": "integer"},
                "assigned_word_count": {"type": "integer"},
                "unassigned_word_count": {"type": "integer"},
            },
            "required": [
                "diarization_key_used",
                "parakeet_word_count",
                "pyannote_segment_count",
                "turn_count",
                "assigned_word_count",
                "unassigned_word_count",
            ],
        },
        "turns": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "speaker": {"type": "string", "example": "SPEAKER_02"},
                    "start": {"type": "number", "example": 40.72},
                    "end": {"type": "number", "example": 514.0},
                    "text": {"type": "string"},
                },
                "required": ["speaker", "start", "end", "text"],
            },
        },
        "transcript_text": {"type": "string"},
    },
    "required": ["summary", "turns", "transcript_text"],
}
# Illustrative example payload matching RUN_COMPLETE_PIPELINE_OUTPUT_SCHEMA,
# returned alongside the schema for API consumers.
RUN_COMPLETE_PIPELINE_OUTPUT_EXAMPLE: dict[str, Any] = {
    "summary": {
        "diarization_key_used": "exclusive_speaker_diarization",
        "parakeet_word_count": 1234,
        "pyannote_segment_count": 42,
        "turn_count": 39,
        "assigned_word_count": 1219,
        "unassigned_word_count": 15,
    },
    "turns": [
        {
            "speaker": "SPEAKER_00",
            "start": 0.0,
            "end": 12.34,
            "text": "Good morning and welcome to the earnings call.",
        },
        {
            "speaker": "SPEAKER_01",
            "start": 12.34,
            "end": 19.02,
            "text": "Thank you. Let us begin with quarterly highlights.",
        },
    ],
    "transcript_text": "[0.00 - 12.34] SPEAKER_00: Good morning ...",
}
# Input contract for /run_complete_pipeline: an audio upload plus the HF token
# required to load the gated pyannote community pipeline.
RUN_COMPLETE_PIPELINE_INPUT_SCHEMA: dict[str, Any] = {
    "type": "object",
    "properties": {
        "audio_file": {"type": "file", "description": "Audio file upload"},
        "huggingface_token": {"type": "string", "description": "HF access token for pyannote model"},
    },
    "required": ["audio_file", "huggingface_token"],
}
def _preload_model(model_label: str, preload_fn) -> None:
try:
preload_fn()
except Exception as exc:
_PRELOAD_ERRORS[model_label] = str(exc)
def _raise_preload_error_if_any(model_label: str) -> None:
    """Surface a recorded startup preload failure for *model_label* as a Gradio error."""
    message = _PRELOAD_ERRORS.get(model_label)
    if not message:
        return
    raise gr.Error(
        f"Model preload failed for {model_label}. "
        "Check startup logs and dependencies. "
        f"Details: {message}"
    )
def _store_debug_payload(payload: dict[str, Any]) -> str:
    """Register *payload* under a fresh run id and return that id.

    The newest run becomes the default for /get_debug_output; once the
    store exceeds ``_MAX_DEBUG_RUNS`` entries, the oldest are evicted.
    """
    global _LAST_DEBUG_RUN_ID
    new_id = str(uuid.uuid4())
    _DEBUG_RUNS[new_id] = payload
    _LAST_DEBUG_RUN_ID = new_id
    excess = len(_DEBUG_RUNS) - _MAX_DEBUG_RUNS
    for _ in range(excess):
        _DEBUG_RUNS.popitem(last=False)
    return new_id
def _parse_main_request(
audio_file: str | None,
huggingface_token: str | None,
) -> str:
if audio_file is None:
raise gr.Error("No audio file submitted. Upload an audio file first.")
if not huggingface_token or not huggingface_token.strip():
raise gr.Error("huggingface_token is required for pyannote/speaker-diarization-community-1.")
return huggingface_token.strip()
# Global setup (outside @spaces.GPU) so setup cost is not charged to ZeroGPU inference window.
# Any failure is captured in _PRELOAD_ERRORS and surfaced on the first request.
_preload_model(PARAKEET_V3, preload_parakeet_model)
@spaces.GPU(duration=120)
def _gpu_infer_parakeet(audio_file: str, duration_seconds: float | None):
    """Run Parakeet transcription inside a ZeroGPU window.

    Returns the model's raw output together with timing info, including the
    wall-clock length of the GPU window itself.
    """
    window_start = time.perf_counter()
    result = run_parakeet(
        audio_file=audio_file,
        language=None,
        model_options={},
        duration_seconds=duration_seconds,
    )
    elapsed = time.perf_counter() - window_start
    # Model-reported timing entries are merged in after (and may override)
    # the measured window duration, matching the original dict-merge order.
    timing = {"gpu_window_seconds": round(elapsed, 4)}
    timing.update(result.get("timing", {}))
    return {"raw_output": result["raw_output"], "zerogpu_timing": timing}
@spaces.GPU(duration=120)
def _gpu_infer_pyannote_chunk(
    audio_file: str,
    huggingface_token: str,
):
    """Diarize one audio chunk with Pyannote inside a ZeroGPU window.

    Returns the raw diarization output plus timing info, including the
    measured wall-clock length of the GPU window.
    """
    window_start = time.perf_counter()
    result = run_pyannote_community_chunk(
        audio_file=audio_file,
        huggingface_token=huggingface_token,
    )
    elapsed = time.perf_counter() - window_start
    # Merge model-reported timing after the measured window duration so the
    # model values win on key collision, as in the original dict merge.
    timing = {"gpu_window_seconds": round(elapsed, 4)}
    timing.update(result.get("timing", {}))
    return {"raw_output": result["raw_output"], "zerogpu_timing": timing}
def run_complete_pipeline(
    audio_file: str,
    huggingface_token: str,
):
    """Transcribe, diarize, and merge a single audio file.

    Orchestrates the full pipeline: validates inputs, runs Parakeet
    transcription and Pyannote diarization on ZeroGPU, merges both outputs
    on CPU, stores a debug payload for /get_debug_output, and returns the
    merged transcript JSON (shape: RUN_COMPLETE_PIPELINE_OUTPUT_SCHEMA).

    Raises:
        gr.Error: if inputs are missing/blank or Parakeet preload failed.
    """
    # Validation returns the stripped token; keep it for the pyannote calls below.
    huggingface_token = _parse_main_request(audio_file, huggingface_token)
    _raise_preload_error_if_any(PARAKEET_V3)
    started_at = time.perf_counter()
    duration_seconds = get_audio_duration_seconds(audio_file)
    # 1) Parakeet transcription on ZeroGPU.
    parakeet_gpu_result = _gpu_infer_parakeet(
        audio_file=audio_file,
        duration_seconds=duration_seconds,
    )
    # Response envelope consumed by merge_parakeet_pyannote_outputs and the debug payload.
    parakeet_response = {
        "model": PARAKEET_V3,
        "task": "transcribe",
        "audio_file": str(audio_file),
        "postprocess_prompt": None,
        "model_options": {},
        "zerogpu_timing": parakeet_gpu_result["zerogpu_timing"],
        "raw_output": parakeet_gpu_result["raw_output"],
        "timestamp_granularity": "word",
    }
    # 2) Pyannote diarization on ZeroGPU (chunked only when needed).
    # NOTE(review): with a 7200 s threshold and chunk size, audio up to ~2 h
    # presumably runs as a single chunk — exact semantics live in
    # run_chunked_diarization; confirm there.
    pyannote_model_options = {
        "long_audio_chunk_threshold_s": 7200,
        "chunk_duration_s": 7200,
        "chunk_overlap_s": 0,
    }
    # Pipeline download/auth happens here on CPU, outside the GPU billing window.
    preload_pyannote_pipeline(huggingface_token=huggingface_token)

    def gpu_chunk_runner(audio_file: str, model_options: dict[str, Any]) -> dict[str, Any]:
        # model_options is part of the runner interface expected by
        # run_chunked_diarization but unused here; the HF token is captured
        # from the enclosing scope instead.
        del model_options
        return _gpu_infer_pyannote_chunk(
            audio_file=audio_file,
            huggingface_token=huggingface_token,
        )

    pyannote_response = run_chunked_diarization(
        audio_file=audio_file,
        model_options=pyannote_model_options,
        gpu_chunk_runner=gpu_chunk_runner,
    )
    # 3) CPU-side postprocessing outside ZeroGPU.
    merged_transcript = merge_parakeet_pyannote_outputs(
        parakeet_response=parakeet_response,
        pyannote_response=pyannote_response,
        diarization_key="exclusive_speaker_diarization",
    )
    # Aggregate timing across both model calls; missing keys default to 0.0.
    total_gpu_window_seconds = float(parakeet_response["zerogpu_timing"].get("gpu_window_seconds", 0.0)) + float(
        pyannote_response.get("zerogpu_timing", {}).get("gpu_window_seconds", 0.0)
    )
    total_inference_seconds = float(parakeet_response["zerogpu_timing"].get("inference_seconds", 0.0)) + float(
        pyannote_response.get("zerogpu_timing", {}).get("inference_seconds", 0.0)
    )
    finished_at = time.perf_counter()
    debug_payload = {
        "pipeline_timing": {
            "total_wall_clock_seconds": round(finished_at - started_at, 4),
            "zerogpu_gpu_window_seconds_total": round(total_gpu_window_seconds, 4),
            "zerogpu_inference_seconds_total": round(total_inference_seconds, 4),
        },
        "inputs": {
            "audio_file": str(audio_file),
            # Only record whether a token was supplied — never the token itself.
            "huggingface_token_provided": bool(huggingface_token),
        },
        "parakeet_response": parakeet_response,
        "pyannote_response": pyannote_response,
        "merged_transcript": merged_transcript,
    }
    _store_debug_payload(debug_payload)
    # Return merged transcript JSON (OpenAI cleanup is intentionally local/off-space).
    return merged_transcript
def get_debug_output(run_id: str | None):
    """Return the debug payload for *run_id*, or for the most recent run.

    Raises gr.Error for an unknown run_id, or when no run has completed yet.
    """
    requested = run_id.strip() if run_id else ""
    if requested:
        payload = _DEBUG_RUNS.get(requested)
        if payload is None:
            raise gr.Error(f"Unknown run_id: {run_id}")
        return {"run_id": requested, "debug": payload}
    # No (non-blank) id given: fall back to the latest stored run.
    if _LAST_DEBUG_RUN_ID is None:
        raise gr.Error("No debug payload available yet. Run /run_complete_pipeline first.")
    return {"run_id": _LAST_DEBUG_RUN_ID, "debug": _DEBUG_RUNS[_LAST_DEBUG_RUN_ID]}
def get_run_complete_pipeline_schema() -> dict[str, Any]:
    """Expose the machine-readable contract for /run_complete_pipeline."""
    usage_notes = [
        "Use /get_debug_output to fetch raw model payloads and timing.",
        "The production route returns only merged transcript JSON.",
    ]
    return {
        "api_name": "/run_complete_pipeline",
        "input_schema": RUN_COMPLETE_PIPELINE_INPUT_SCHEMA,
        "output_schema": RUN_COMPLETE_PIPELINE_OUTPUT_SCHEMA,
        "output_example": RUN_COMPLETE_PIPELINE_OUTPUT_EXAMPLE,
        "notes": usage_notes,
    }
# Gradio UI wiring: the main pipeline route plus a debug endpoint and a
# hidden schema endpoint (exposed via api_name for API/MCP clients).
with gr.Blocks(title="Parakeet + Pyannote Pipeline") as demo:
    gr.Markdown(
        "# End-to-end transcript pipeline\n"
        "Runs Parakeet transcription, Pyannote diarization, then merges into a combined transcript JSON."
    )
    audio_file = gr.Audio(
        sources=["upload"],
        type="filepath",
        label="Audio file",
    )
    huggingface_token = gr.Textbox(
        label="HuggingFace token",
        type="password",  # masked in the UI
    )
    run_btn = gr.Button("Run full pipeline")
    output = gr.JSON(label="Combined transcript JSON")
    # Production endpoint; api_description documents the response shape inline
    # for clients that cannot fetch /get_run_complete_pipeline_schema.
    run_btn.click(
        fn=run_complete_pipeline,
        inputs=[audio_file, huggingface_token],
        outputs=output,
        api_name="run_complete_pipeline",
        api_description=(
            "Run Parakeet + Pyannote and return merged transcript JSON.\n"
            "Response shape:\n"
            "{\n"
            ' "summary": {\n'
            ' "diarization_key_used": str,\n'
            ' "parakeet_word_count": int,\n'
            ' "pyannote_segment_count": int,\n'
            ' "turn_count": int,\n'
            ' "assigned_word_count": int,\n'
            ' "unassigned_word_count": int\n'
            " },\n"
            ' "turns": [{"speaker": str, "start": float, "end": float, "text": str}],\n'
            ' "transcript_text": str\n'
            "}\n"
            "For full machine-readable schema + example, call /get_run_complete_pipeline_schema."
        ),
    )
    # Debug endpoint: fetch the raw model payloads and timing for a run.
    with gr.Row():
        debug_run_id = gr.Textbox(label="Debug run_id (optional)")
        debug_btn = gr.Button("Get debug output")
    debug_output = gr.JSON(label="Debug output (raw + benchmark)")
    debug_btn.click(
        fn=get_debug_output,
        inputs=[debug_run_id],
        outputs=debug_output,
        api_name="get_debug_output",
        api_description=(
            "Return latest (or selected) debug payload including raw Parakeet/Pyannote outputs "
            "and aggregated pipeline timing."
        ),
    )
    # Hidden controls: the schema endpoint is API-only, so its widgets are invisible.
    with gr.Row(visible=False):
        schema_btn = gr.Button("get_run_complete_pipeline_schema")
        schema_output = gr.JSON(label="run_complete_pipeline schema", visible=False)
    schema_btn.click(
        fn=get_run_complete_pipeline_schema,
        inputs=None,
        outputs=schema_output,
        api_name="get_run_complete_pipeline_schema",
        api_description="Return input/output schema contract for /run_complete_pipeline.",
    )
# Single-worker queue keeps ZeroGPU usage serialized; SSR disabled for Spaces.
demo.queue(default_concurrency_limit=1).launch(ssr_mode=False, theme=gr.themes.Ocean())