|
|
from __future__ import annotations |
|
|
|
|
|
import os |
|
|
from dataclasses import dataclass |
|
|
from typing import Any, Iterable, Iterator, Sequence |
|
|
|
|
|
|
|
|
# Request settings for each supported transcription model.
#
#   api_model          model name sent to the transcription endpoint
#   response_format    response format requested from the endpoint
#   use_prompt         whether a text prompt accompanies each request
#   chunking_strategy  forwarded as ``chunking_strategy`` when truthy
#   supports_stream    value passed as ``stream=`` when making the request
_WHISPER_SETTINGS: dict[str, Any] = {
    "api_model": "whisper-1",
    "response_format": "text",
    "use_prompt": True,
    "chunking_strategy": None,
    "supports_stream": False,
}

_DIARIZE_SETTINGS: dict[str, Any] = {
    "api_model": "gpt-4o-transcribe-diarize",
    "response_format": "diarized_json",
    "use_prompt": False,
    "chunking_strategy": "auto",
    "supports_stream": True,
}

# Keyed by the ``model_key`` that callers pass in.
MODEL_CONFIG = {
    "whisper": _WHISPER_SETTINGS,
    "gpt-4o-transcribe-diarize": _DIARIZE_SETTINGS,
}
|
|
|
|
|
|
|
|
@dataclass
class TranscriptionUpdate:
    """A single progress report emitted while one file is being transcribed."""

    # Formatted transcript text accumulated so far for the file.
    text: str

    # True when this update is the last one for the file.
    is_final: bool

    # Text the consumer may append to its running prompt, or None when
    # there is nothing to append.
    prompt_append: str | None
|
|
|
|
|
|
|
|
def stream_transcriptions(
    client: Any,
    model_key: str,
    message: Any,
    history: Iterable[Any],
    system_prompt: str,
) -> Iterator[str]:
    """Transcribe the files attached to ``message``, yielding the transcript as it grows.

    Every yielded string is the complete transcript so far: one fenced
    ``transcript`` block per file, in file order.  A payload is yielded only
    when it differs from the one yielded just before it.

    Args:
        client: Object exposing ``audio.transcriptions.create``; handed to
            ``_stream_single_file`` unchanged.
        model_key: Key into ``MODEL_CONFIG``.
        message: Dict or object with ``text`` and ``files`` fields.  Each
            file is either a path string or a dict/object with a ``path``
            field; entries that provide no path are skipped.
        history: Earlier messages, folded into the prompt by ``_build_prompt``.
        system_prompt: Text the prompt starts with.

    Yields:
        The cumulative transcript after each change.

    Raises:
        ValueError: If ``model_key`` is not in ``MODEL_CONFIG``.  Because
            this is a generator, the error surfaces on first iteration, not
            when the function is called.
    """
    if model_key not in MODEL_CONFIG:
        raise ValueError(f"Unsupported transcription model: {model_key}")

    config = MODEL_CONFIG[model_key]
    use_prompt = config["use_prompt"]
    prompt = _build_prompt(history, system_prompt)
    message_text, files = _message_fields(message)
    if use_prompt and message_text:
        # Separate the message from the preceding context with a newline,
        # like every other addition to the prompt; previously the text was
        # glued directly onto the last history entry.
        prompt = f"{prompt}\n{message_text}" if prompt else message_text

    if not files:
        return

    completed: list[tuple[str, str]] = []
    last_payload: str | None = None

    for file in files:
        audio_path = _field(file, "path")
        if not audio_path and isinstance(file, str):
            # A bare string entry is itself the path.
            audio_path = file
        if not audio_path:
            continue

        filename = os.path.basename(audio_path)
        builder = _TranscriptBuilder(model_key)

        for update in _stream_single_file(client, config, prompt, audio_path, builder):
            if not update.text:
                continue
            # Show the in-progress file after the files already finished.
            payload = _assemble_transcript([*completed, (filename, update.text)])
            if payload != last_payload:
                last_payload = payload
                yield payload
            if update.is_final:
                completed.append((filename, update.text))
                if use_prompt and update.prompt_append:
                    # Later files are transcribed with earlier results as context.
                    prompt = f"{prompt}\n{update.prompt_append}"
                break
        else:
            # The stream ended without a final update that carried text;
            # fall back to whatever the builder accumulated.
            final_text = builder.formatted_text()
            if final_text:
                completed.append((filename, final_text))
                payload = _assemble_transcript(completed)
                if payload != last_payload:
                    # Record the payload so the de-duplication state stays
                    # accurate for the files that follow.
                    last_payload = payload
                    yield payload
                if use_prompt:
                    prompt = f"{prompt}\n{final_text}"
|
|
|
|
|
|
|
|
def _stream_single_file( |
|
|
client: Any, |
|
|
config: dict[str, Any], |
|
|
prompt: str, |
|
|
audio_path: str, |
|
|
builder: "_TranscriptBuilder", |
|
|
) -> Iterator[TranscriptionUpdate]: |
|
|
request_kwargs = { |
|
|
"model": config["api_model"], |
|
|
"response_format": config["response_format"], |
|
|
} |
|
|
if config["use_prompt"]: |
|
|
request_kwargs["prompt"] = prompt |
|
|
if config["chunking_strategy"]: |
|
|
request_kwargs["chunking_strategy"] = config["chunking_strategy"] |
|
|
|
|
|
with open(audio_path, "rb") as fh: |
|
|
request_kwargs["file"] = fh |
|
|
response = client.audio.transcriptions.create( |
|
|
stream=config["supports_stream"], **request_kwargs |
|
|
) |
|
|
|
|
|
if config["supports_stream"]: |
|
|
yield from builder.consume_iter(response) |
|
|
else: |
|
|
yield builder.consume_snapshot(response) |
|
|
|
|
|
|
|
|
def _assemble_transcript(blocks: Sequence[tuple[str, str]]) -> str: |
|
|
parts = [] |
|
|
for filename, text in blocks: |
|
|
body = text.rstrip() |
|
|
parts.append(f"``` transcript {filename}\n{body}\n```") |
|
|
return "\n".join(parts) |
|
|
|
|
|
|
|
|
def _build_prompt(history: Iterable[Any], system_prompt: str) -> str:
    """Concatenate the system prompt with the textual content of ``history``.

    Starts from ``system_prompt`` (or an empty string when it is falsy) and
    appends, each on a new line, the ``content`` of every history entry whose
    ``role`` is ``"user"`` or ``"assistant"``.  Entries whose content is
    ``None`` or a tuple are left out.  A falsy ``history`` contributes
    nothing.
    """
    pieces = [system_prompt or ""]
    for entry in history or []:
        if _field(entry, "role") not in ("user", "assistant"):
            continue
        body = _field(entry, "content")
        if body is None or isinstance(body, tuple):
            continue
        pieces.append(f"\n{body}")
    return "".join(pieces)
|
|
|
|
|
|
|
|
def _format_transcription_text( |
|
|
model_key: str, text: str | None, segments: Sequence[dict[str, Any]] | None |
|
|
) -> str: |
|
|
if model_key == "whisper": |
|
|
return text or "" |
|
|
|
|
|
if model_key == "gpt-4o-transcribe-diarize": |
|
|
if segments: |
|
|
turns: list[str] = [] |
|
|
prev_speaker: str | None = None |
|
|
for seg in segments: |
|
|
speaker = (seg.get("speaker") or "Speaker").strip() |
|
|
seg_text = (seg.get("text") or "").strip() |
|
|
if seg_text: |
|
|
if speaker != prev_speaker or not turns: |
|
|
turns.append(f"{speaker}: {seg_text}") |
|
|
else: |
|
|
turns[-1] = f"{turns[-1]} {seg_text}".strip() |
|
|
prev_speaker = speaker |
|
|
if turns: |
|
|
return "\n\n".join(turns) |
|
|
return text or "" |
|
|
|
|
|
raise ValueError(f"Unhandled transcription model formatting: {model_key}") |
|
|
|
|
|
|
|
|
class _TranscriptBuilder:
    """Accumulate transcription events for one file and render them as text.

    The builder accepts either a stream of events (``consume_iter``) or one
    complete response (``consume_snapshot``).  It collects text deltas and
    segments as they arrive, prefers the final text/segments once they are
    known, and formats the result with ``_format_transcription_text``.
    """

    def __init__(self, model_key: str) -> None:
        self.model_key = model_key
        # Text received incrementally via delta events.
        self._text_chunks: list[str] = []
        # Segments received incrementally, plus the ids already seen.
        self._segments: list[dict[str, Any]] = []
        self._segment_ids: set[Any] = set()
        # Authoritative text/segments taken from a final payload.
        self._final_text: str | None = None
        self._final_segments: list[dict[str, Any]] | None = None
        self._finalized = False

    def consume_iter(self, stream: Any) -> Iterator[TranscriptionUpdate]:
        """Yield an update for every event in ``stream`` that changes the transcript.

        A response that is really a single complete result is handled as a
        snapshot instead of being iterated.  If the stream ends without any
        final event, one closing final update is yielded.
        """
        if self._is_snapshot(stream):
            yield self.consume_snapshot(stream)
            return

        for event in stream:
            update = self._ingest(event, assume_final=False)
            if update is not None:
                yield update

        if not self._finalized:
            yield self._finalize_update()

    @staticmethod
    def _is_snapshot(response: Any) -> bool:
        """Return True when ``response`` is one complete result, not an event stream.

        Checking for ``__iter__`` alone is not enough: strings, dicts and
        objects exposing ``model_dump`` are iterable too, and iterating them
        yields characters, keys or field tuples rather than events.
        """
        if isinstance(response, (str, dict)):
            return True
        if hasattr(response, "model_dump"):
            return True
        return not hasattr(response, "__iter__")

    def consume_snapshot(self, snapshot: Any) -> TranscriptionUpdate:
        """Ingest one complete response and return a final update for it."""
        update = self._ingest(snapshot, assume_final=True)
        return update if update is not None else self._finalize_update()

    def formatted_text(self) -> str:
        """Return the transcript formatted for this builder's model.

        Final text and segments take precedence over the incrementally
        collected deltas and segments.
        """
        text = self._final_text
        if text is None and self._text_chunks:
            text = "".join(self._text_chunks)
        segments = self._final_segments or self._segments
        return _format_transcription_text(self.model_key, text, segments)

    def _ingest(self, obj: Any, assume_final: bool) -> TranscriptionUpdate | None:
        """Apply one event/response; return an update, or None if nothing changed."""
        data = _to_dict(obj)
        changed, is_final = self._apply_event(data, assume_final=assume_final)
        if not changed:
            return None
        formatted = self.formatted_text()
        final = is_final or assume_final
        # Only a final, non-empty transcript is offered for the prompt.
        append = formatted if final and formatted else None
        return TranscriptionUpdate(text=formatted, is_final=final, prompt_append=append)

    def _apply_event(self, data: dict[str, Any], assume_final: bool) -> tuple[bool, bool]:
        """Fold ``data`` into the builder state.

        Returns:
            ``(changed, is_final)``: whether any state was updated, and
            whether the transcript should now be considered final.
        """
        event_type = data.get("type")
        changed = False
        is_final = False

        if event_type == "transcript.text.delta":
            delta = data.get("delta")
            if isinstance(delta, str) and delta:
                self._text_chunks.append(delta)
                changed = True
        elif event_type == "transcript.text.segment":
            # The segment may be nested under "segment" or be the event itself.
            segment = _normalize_segment(data.get("segment") or data)
            seg_id = segment.get("id")
            # Segments without an id cannot be de-duplicated and are always kept.
            if seg_id is None or seg_id not in self._segment_ids:
                if seg_id is not None:
                    self._segment_ids.add(seg_id)
                self._segments.append(segment)
                changed = True
        elif event_type == "transcript.text.done":
            self._capture_final(data)
            changed = True
            is_final = True

        if not changed:
            # Anything not handled above is treated as a complete result if
            # it carries text or segments.
            changed = self._capture_final(data)
            is_final = changed

        if assume_final:
            is_final = True
        if is_final:
            self._finalized = True

        return changed, is_final

    def _capture_final(self, data: dict[str, Any]) -> bool:
        """Store final text and/or segments from ``data``.

        Returns:
            True if a non-empty text or a non-empty segment list was stored.
        """
        captured = False
        text_value = data.get("text")
        if isinstance(text_value, str) and text_value:
            self._final_text = text_value
            captured = True
        segments_value = data.get("segments")
        if isinstance(segments_value, list) and segments_value:
            self._final_segments = [_normalize_segment(seg) for seg in segments_value]
            captured = True
        return captured

    def _finalize_update(self) -> TranscriptionUpdate:
        """Mark the builder finalized and return a final update with the current text."""
        formatted = self.formatted_text()
        self._finalized = True
        return TranscriptionUpdate(
            text=formatted, is_final=True, prompt_append=formatted or None
        )
|
|
|
|
|
|
|
|
def _normalize_segment(segment: Any) -> dict[str, Any]:
    """Return ``segment`` as a dict with its ``speaker`` and ``text`` stripped.

    The segment is converted with ``_to_dict``.  String values under
    ``"speaker"`` and ``"text"`` have surrounding whitespace removed; all
    other entries are kept as they are.  The input is never modified.
    """
    # ``_to_dict`` hands dict inputs back unchanged, so work on a shallow
    # copy to avoid editing the caller's payload in place.
    data = dict(_to_dict(segment))
    for key in ("speaker", "text"):
        value = data.get(key)
        if isinstance(value, str):
            data[key] = value.strip()
    return data
|
|
|
|
|
|
|
|
def _message_fields(message: Any) -> tuple[str | None, Any]:
    """Return the ``text`` and ``files`` fields of ``message`` as a pair.

    Either element is ``None`` when ``message`` does not provide the field.
    """
    text = _field(message, "text")
    files = _field(message, "files")
    return text, files
|
|
|
|
|
|
|
|
def _field(obj: Any, key: str) -> Any: |
|
|
if isinstance(obj, dict): |
|
|
return obj.get(key) |
|
|
return getattr(obj, key, None) |
|
|
|
|
|
|
|
|
def _to_dict(obj: Any) -> dict[str, Any]: |
|
|
if isinstance(obj, dict): |
|
|
return obj |
|
|
if isinstance(obj, str): |
|
|
return {"text": obj} |
|
|
if hasattr(obj, "model_dump"): |
|
|
try: |
|
|
return obj.model_dump() |
|
|
except Exception: |
|
|
pass |
|
|
if hasattr(obj, "to_dict"): |
|
|
try: |
|
|
return obj.to_dict() |
|
|
except Exception: |
|
|
pass |
|
|
if hasattr(obj, "__dict__"): |
|
|
return {k: v for k, v in obj.__dict__.items() if not k.startswith("_")} |
|
|
return {} |
|
|
|