oai_chat / transcription.py
ndurner's picture
gpt-4o-transcribe-diarize support
9000568
from __future__ import annotations
import os
from dataclasses import dataclass
from typing import Any, Iterable, Iterator, Sequence
# Request settings for each supported transcription model, keyed by the
# model key that callers pass to ``stream_transcriptions``.
#
#   api_model         - sent to the transcription endpoint as ``model``
#   response_format   - sent as ``response_format``
#   use_prompt        - when true, a text prompt is included in the request
#   chunking_strategy - sent as ``chunking_strategy``; left out when falsy
#   supports_stream   - passed as ``stream`` and decides whether the
#                       response is consumed as an event iterator or as a
#                       single snapshot
MODEL_CONFIG = {
    "whisper": {
        "api_model": "whisper-1",
        "response_format": "text",
        "use_prompt": True,
        "chunking_strategy": None,
        "supports_stream": False,
    },
    "gpt-4o-transcribe-diarize": {
        "api_model": "gpt-4o-transcribe-diarize",
        "response_format": "diarized_json",
        "use_prompt": False,
        "chunking_strategy": "auto",
        "supports_stream": True,
    },
}
@dataclass
class TranscriptionUpdate:
    """A formatted transcript state emitted while a response is consumed."""

    # Formatted transcript text as of this update.
    text: str
    # True when this update closes out the current audio file.
    is_final: bool
    # Text to append to the running prompt for later files, or None.
    prompt_append: str | None
def stream_transcriptions(
    client: Any,
    model_key: str,
    message: Any,
    history: Iterable[Any],
    system_prompt: str,
) -> Iterator[str]:
    """Transcribe every audio file attached to *message*, yielding progress.

    Each yielded string is the cumulative transcript so far: one fenced
    block per file already completed, followed by the block for the file
    currently being processed.

    Args:
        client: Object exposing ``audio.transcriptions.create``.
        model_key: Key into ``MODEL_CONFIG``.
        message: Dict or object with ``text`` and ``files`` fields. Each
            file is either a path string or something with a ``path`` field.
        history: Earlier chat messages used to seed the prompt.
        system_prompt: Text placed at the start of the prompt.

    Yields:
        The assembled transcript payload after each meaningful update.

    Raises:
        ValueError: If *model_key* is not in ``MODEL_CONFIG``. Because this
            is a generator, the error surfaces on first iteration.
    """
    config = MODEL_CONFIG.get(model_key)
    if config is None:
        raise ValueError(f"Unsupported transcription model: {model_key}")

    wants_prompt = config["use_prompt"]
    running_prompt = _build_prompt(history, system_prompt)
    message_text, attachments = _message_fields(message)
    if wants_prompt and message_text:
        running_prompt += message_text
    if not attachments:
        return

    finished: list[tuple[str, str]] = []
    previous_payload: str | None = None

    for attachment in attachments:
        # An attachment is either something with a ``path`` or a bare path.
        audio_path = _field(attachment, "path")
        if not audio_path and isinstance(attachment, str):
            audio_path = attachment
        if not audio_path:
            continue

        filename = os.path.basename(audio_path)
        builder = _TranscriptBuilder(model_key)
        got_final = False

        updates = _stream_single_file(
            client, config, running_prompt, audio_path, builder
        )
        for update in updates:
            if not update.text:
                continue
            payload = _assemble_transcript(finished + [(filename, update.text)])
            if payload != previous_payload:
                previous_payload = payload
                yield payload
            if update.is_final:
                finished.append((filename, update.text))
                if wants_prompt and update.prompt_append:
                    # Later files see what earlier files said.
                    running_prompt = f"{running_prompt}\n{update.prompt_append}"
                got_final = True
                break

        if got_final:
            continue

        # No non-empty final update was seen: fall back to whatever the
        # builder has accumulated for this file.
        leftover = builder.formatted_text()
        if not leftover:
            continue
        finished.append((filename, leftover))
        payload = _assemble_transcript(finished)
        if payload != previous_payload:
            yield payload
        if wants_prompt:
            running_prompt = f"{running_prompt}\n{leftover}"
def _stream_single_file(
    client: Any,
    config: dict[str, Any],
    prompt: str,
    audio_path: str,
    builder: "_TranscriptBuilder",
) -> Iterator[TranscriptionUpdate]:
    """Send one audio file for transcription and yield the builder's updates.

    Args:
        client: Object exposing ``audio.transcriptions.create``.
        config: One entry of ``MODEL_CONFIG``.
        prompt: Prompt text; only sent when ``config["use_prompt"]`` is true.
        audio_path: Path of the audio file, opened in binary mode.
        builder: Accumulator that turns the response into updates.

    Yields:
        Every update from ``builder.consume_iter`` when the model streams,
        otherwise the single update from ``builder.consume_snapshot``.
    """
    request_kwargs: dict[str, Any] = {
        "model": config["api_model"],
        "response_format": config["response_format"],
    }
    if config["use_prompt"]:
        request_kwargs["prompt"] = prompt
    chunking = config["chunking_strategy"]
    if chunking:
        request_kwargs["chunking_strategy"] = chunking

    with open(audio_path, "rb") as audio_file:
        streaming = config["supports_stream"]
        response = client.audio.transcriptions.create(
            stream=streaming, file=audio_file, **request_kwargs
        )
        if streaming:
            yield from builder.consume_iter(response)
        else:
            yield builder.consume_snapshot(response)
def _assemble_transcript(blocks: Sequence[tuple[str, str]]) -> str:
    """Render ``(filename, text)`` pairs as newline-joined fenced blocks.

    Trailing whitespace is removed from each text before it is fenced.
    """
    return "\n".join(
        f"``` transcript {filename}\n{text.rstrip()}\n```"
        for filename, text in blocks
    )
def _build_prompt(history: Iterable[Any], system_prompt: str) -> str:
    """Build the prompt from the system prompt plus prior chat content.

    Only ``user`` and ``assistant`` messages contribute. Messages whose
    content is ``None`` or a tuple are skipped. Each kept content is
    appended on its own line.
    """
    pieces = [system_prompt or ""]
    for entry in history or []:
        if _field(entry, "role") not in ("user", "assistant"):
            continue
        content = _field(entry, "content")
        if content is None or isinstance(content, tuple):
            continue
        pieces.append(f"\n{content}")
    return "".join(pieces)
def _format_transcription_text(
    model_key: str, text: str | None, segments: Sequence[dict[str, Any]] | None
) -> str:
    """Format a transcription result for display.

    For ``whisper`` the plain text is returned. For
    ``gpt-4o-transcribe-diarize`` the segments are rendered as
    ``"<speaker>: <text>"`` turns separated by blank lines, merging
    consecutive segments from the same speaker; when no segment carries
    text, the plain text is returned instead.

    Raises:
        ValueError: If *model_key* is neither of the two handled models.
    """
    fallback = text or ""
    if model_key == "whisper":
        return fallback
    if model_key != "gpt-4o-transcribe-diarize":
        raise ValueError(f"Unhandled transcription model formatting: {model_key}")

    turns: list[str] = []
    last_speaker: str | None = None
    for seg in segments or ():
        speaker = (seg.get("speaker") or "Speaker").strip()
        seg_text = (seg.get("text") or "").strip()
        if not seg_text:
            continue
        if turns and speaker == last_speaker:
            # Same speaker keeps talking: extend the current turn.
            turns[-1] = f"{turns[-1]} {seg_text}".strip()
        else:
            turns.append(f"{speaker}: {seg_text}")
        last_speaker = speaker

    return "\n\n".join(turns) if turns else fallback
class _TranscriptBuilder:
    """Accumulate transcription events for one audio file and format them.

    Payloads may be anything ``_to_dict`` accepts. Streaming deltas and
    segments are collected as they arrive. A ``transcript.text.done`` event
    or a snapshot payload supplies final text and segments, which take
    precedence over the streamed pieces when formatting.
    """

    def __init__(self, model_key: str) -> None:
        self.model_key = model_key
        # Text deltas received so far, in arrival order.
        self._text_chunks: list[str] = []
        # Normalized streamed segments, de-duplicated by id.
        self._segments: list[dict[str, Any]] = []
        self._segment_ids: set[Any] = set()
        # Final values from a "done" event or a snapshot payload.
        self._final_text: str | None = None
        self._final_segments: list[dict[str, Any]] | None = None
        self._finalized = False

    def consume_iter(self, stream: Any) -> Iterator[TranscriptionUpdate]:
        """Yield an update for every event in *stream* that changes state.

        A non-iterable *stream* is treated as a single snapshot. If the
        stream ends without any final event, one closing final update is
        yielded.
        """
        if not hasattr(stream, "__iter__"):
            yield self.consume_snapshot(stream)
            return
        for event in stream:
            update = self._ingest(event, assume_final=False)
            if update is not None:
                yield update
        if not self._finalized:
            yield self._finalize_update()

    def consume_snapshot(self, snapshot: Any) -> TranscriptionUpdate:
        """Ingest a complete response and return the resulting final update."""
        update = self._ingest(snapshot, assume_final=True)
        if update is not None:
            return update
        return self._finalize_update()

    def formatted_text(self) -> str:
        """Return the transcript formatted for this builder's model.

        Final text and segments are preferred; otherwise the joined deltas
        and the streamed segments are used.
        """
        text = self._final_text
        if text is None and self._text_chunks:
            text = "".join(self._text_chunks)
        segments = self._final_segments or self._segments
        return _format_transcription_text(self.model_key, text, segments)

    def _ingest(self, obj: Any, assume_final: bool) -> TranscriptionUpdate | None:
        """Apply one payload; return an update, or None if nothing changed."""
        data = _to_dict(obj)
        changed, is_final = self._apply_event(data, assume_final=assume_final)
        if not changed:
            return None
        formatted = self.formatted_text()
        final = is_final or assume_final
        append = formatted if final and formatted else None
        return TranscriptionUpdate(text=formatted, is_final=final, prompt_append=append)

    def _apply_event(self, data: dict[str, Any], assume_final: bool) -> tuple[bool, bool]:
        """Fold one payload dict into the builder state.

        Returns:
            ``(changed, is_final)``: whether the state changed, and whether
            the payload marks the end of the transcript.
        """
        event_type = data.get("type")
        changed = False
        is_final = False

        if event_type == "transcript.text.delta":
            delta = data.get("delta")
            if isinstance(delta, str) and delta:
                self._text_chunks.append(delta)
                changed = True
        elif event_type == "transcript.text.segment":
            segment = _normalize_segment(data.get("segment") or data)
            seg_id = segment.get("id")
            if seg_id is None or seg_id not in self._segment_ids:
                if seg_id is not None:
                    self._segment_ids.add(seg_id)
                self._segments.append(segment)
                changed = True
        elif event_type == "transcript.text.done":
            self._capture_final(data)
            changed = True
            is_final = True
        else:
            # Snapshot handling applies only to payloads that are not one of
            # the recognised stream events. A recognised event that changed
            # nothing (empty delta, duplicate segment) must not have its own
            # "text" field taken as the final transcript, which would end
            # the stream early.
            changed = self._capture_final(data)
            is_final = changed

        if assume_final:
            is_final = True
        if is_final:
            self._finalized = True
        return changed, is_final

    def _capture_final(self, data: dict[str, Any]) -> bool:
        """Store non-empty final text/segments from *data*.

        Returns:
            True if either the final text or the final segments were set.
        """
        captured = False
        text_value = data.get("text")
        segments_value = data.get("segments")
        if isinstance(text_value, str) and text_value:
            self._final_text = text_value
            captured = True
        if isinstance(segments_value, list) and segments_value:
            self._final_segments = [_normalize_segment(seg) for seg in segments_value]
            captured = True
        return captured

    def _finalize_update(self) -> TranscriptionUpdate:
        """Mark the builder finalized and return a final update."""
        formatted = self.formatted_text()
        self._finalized = True
        append = formatted if formatted else None
        return TranscriptionUpdate(text=formatted, is_final=True, prompt_append=append)
def _normalize_segment(segment: Any) -> dict[str, Any]:
    """Return *segment* as a dict with ``speaker`` and ``text`` stripped.

    Only string values are stripped; other values and keys pass through
    unchanged. The input is never modified.
    """
    # _to_dict returns the very same object when given a dict, so work on a
    # shallow copy to avoid rewriting the caller's payload in place.
    data = dict(_to_dict(segment))
    for key in ("speaker", "text"):
        value = data.get(key)
        if isinstance(value, str):
            data[key] = value.strip()
    return data
def _message_fields(message: Any) -> tuple[str | None, Any]:
    """Return the ``text`` and ``files`` fields of *message*, in that order."""
    text = _field(message, "text")
    files = _field(message, "files")
    return text, files
def _field(obj: Any, key: str) -> Any:
    """Read *key* from a dict, or the same-named attribute from any other object.

    Returns ``None`` when the key or attribute is absent.
    """
    return obj.get(key) if isinstance(obj, dict) else getattr(obj, key, None)
def _to_dict(obj: Any) -> dict[str, Any]:
    """Best-effort conversion of a response payload to a dict.

    Dicts are returned as-is and strings become ``{"text": obj}``. Other
    objects are converted via ``model_dump()``, then ``to_dict()``, then
    their public instance attributes. Anything else yields ``{}``.
    """
    if isinstance(obj, dict):
        return obj
    if isinstance(obj, str):
        return {"text": obj}

    for converter_name in ("model_dump", "to_dict"):
        if not hasattr(obj, converter_name):
            continue
        try:
            return getattr(obj, converter_name)()
        except Exception:
            # Deliberately best-effort: fall through to the next strategy.
            pass

    if hasattr(obj, "__dict__"):
        return {
            name: value
            for name, value in vars(obj).items()
            if not name.startswith("_")
        }
    return {}