ndurner committed
Commit 9000568 · 1 Parent(s): 18ccfa1

gpt-4o-transcribe-diarize support

Files changed (2):
  1. app.py +12 -31
  2. transcription.py +309 -0
app.py CHANGED
@@ -13,6 +13,7 @@ from mcp_registry import load_registry, get_tools_for_server, call_local_mcp_too
 from gradio.components.base import Component
 from types import SimpleNamespace
 from dotenv import load_dotenv
+from transcription import stream_transcriptions
 
 from doc2json import process_docx
 from code_exec import eval_script
@@ -215,36 +216,16 @@ async def bot(message, history, history_openai_format, oai_key, system_prompt, t
         api_key=oai_key
     )
 
-    if model == "whisper":
-        result = ""
-        whisper_prompt = system_prompt
-        for msg in history:
-            role = msg.role if hasattr(msg, "role") else msg["role"]
-            content = msg.content if hasattr(msg, "content") else msg["content"]
-            if role == "user":
-                if type(content) is tuple:
-                    pass
-                else:
-                    whisper_prompt += f"\n{content}"
-            if role == "assistant":
-                whisper_prompt += f"\n{content}"
-
-        if message["text"]:
-            whisper_prompt += message["text"]
-        if message.files:
-            for file in message.files:
-                audio_fn = os.path.basename(file.path)
-                with open(file.path, "rb") as f:
-                    transcription = client.audio.transcriptions.create(
-                        model="whisper-1",
-                        prompt=whisper_prompt,
-                        file=f,
-                        response_format="text"
-                    )
-                whisper_prompt += f"\n{transcription}"
-                result += f"\n``` transcript {audio_fn}\n {transcription}\n```"
-
-        yield gr.ChatMessage(role="assistant", content=result)
+    if model in ("whisper", "gpt-4o-transcribe-diarize"):
+        assistant_msg = gr.ChatMessage(role="assistant", content="")
+        streamed = False
+        for content in stream_transcriptions(client, model, message, history, system_prompt):
+            streamed = True
+            assistant_msg.content = content
+            yield assistant_msg, history_openai_format
+        if not streamed:
+            yield assistant_msg, history_openai_format
+        return
 
     elif model == "gpt-image-1":
         if message.get("files"):
@@ -722,7 +703,7 @@ with gr.Blocks(delete_cache=(86400, 86400)) as demo:
 
     oai_key = gr.Textbox(label="OpenAI API Key", elem_id="oai_key", value=os.environ.get("OPENAI_API_KEY"))
     model = gr.Dropdown(label="Model", value="gpt-5-mini", allow_custom_value=True, elem_id="model",
-                        choices=["gpt-5", "gpt-5-mini", "gpt-5-chat-latest", "gpt-5-pro", "gpt-4o", "gpt-4.1", "o3", "o3-pro", "o4-mini", "chatgpt-4o-latest", "gpt-4o-mini", "gpt-4-turbo", "whisper", "gpt-image-1"])
+                        choices=["gpt-5", "gpt-5-mini", "gpt-5-chat-latest", "gpt-5-pro", "gpt-4o", "gpt-4.1", "o3", "o3-pro", "o4-mini", "chatgpt-4o-latest", "gpt-4o-mini", "gpt-4-turbo", "whisper", "gpt-4o-transcribe-diarize", "gpt-image-1"])
     reasoning_effort = gr.Dropdown(label="Reasoning Effort", value="medium", choices=["low", "medium", "high"], elem_id="reasoning_effort")
     verbosity = gr.Dropdown(label="Verbosity (GPT-5)", value="medium", choices=["low", "medium", "high"], elem_id="verbosity")
     system_prompt = gr.TextArea("You are a helpful yet diligent AI assistant. Answer faithfully and factually correct. Respond with 'I do not know' if uncertain.", label="System/Developer Prompt", lines=3, max_lines=250, elem_id="system_prompt")
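
Note: the refactored bot() branch above treats stream_transcriptions() as a plain generator of markdown payloads, so it can be exercised without Gradio or a network call. A minimal sketch, assuming a hypothetical StubClient in place of openai.OpenAI (the dummy file and the stubbed response are illustrative, not real API output):

# Sketch only: StubClient mimics the client.audio.transcriptions.create()
# call surface that transcription.py uses; "meeting.wav" is a throwaway
# dummy file the stub never actually reads.
import pathlib
from transcription import stream_transcriptions

class _StubTranscriptions:
    def create(self, stream=False, **kwargs):
        # Non-streaming whisper path: return a plain-text snapshot.
        return "hello world"

class _StubAudio:
    transcriptions = _StubTranscriptions()

class StubClient:
    audio = _StubAudio()

pathlib.Path("meeting.wav").write_bytes(b"\x00")  # dummy audio file
message = {"text": "", "files": [{"path": "meeting.wav"}]}
for content in stream_transcriptions(StubClient(), "whisper", message, [], "Transcribe:"):
    print(content)  # a "``` transcript meeting.wav ... ```" block
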
transcription.py ADDED
@@ -0,0 +1,309 @@
+from __future__ import annotations
+
+import os
+from dataclasses import dataclass
+from typing import Any, Iterable, Iterator, Sequence
+
+
+MODEL_CONFIG = {
+    "whisper": {
+        "api_model": "whisper-1",
+        "response_format": "text",
+        "use_prompt": True,
+        "chunking_strategy": None,
+        "supports_stream": False,
+    },
+    "gpt-4o-transcribe-diarize": {
+        "api_model": "gpt-4o-transcribe-diarize",
+        "response_format": "diarized_json",
+        "use_prompt": False,
+        "chunking_strategy": "auto",
+        "supports_stream": True,
+    },
+}
+
+
+@dataclass
+class TranscriptionUpdate:
+    text: str
+    is_final: bool
+    prompt_append: str | None
+
+
+def stream_transcriptions(
+    client: Any,
+    model_key: str,
+    message: Any,
+    history: Iterable[Any],
+    system_prompt: str,
+) -> Iterator[str]:
+    if model_key not in MODEL_CONFIG:
+        raise ValueError(f"Unsupported transcription model: {model_key}")
+
+    config = MODEL_CONFIG[model_key]
+    prompt = _build_prompt(history, system_prompt)
+    message_text, files = _message_fields(message)
+    if config["use_prompt"] and message_text:
+        prompt += message_text
+
+    if not files:
+        return
+
+    completed: list[tuple[str, str]] = []
+    last_payload: str | None = None
+
+    for file in files:
+        audio_path = _field(file, "path")
+        if not audio_path:
+            if isinstance(file, str):
+                audio_path = file
+        if not audio_path:
+            continue
+
+        filename = os.path.basename(audio_path)
+        builder = _TranscriptBuilder(model_key)
+
+        for update in _stream_single_file(client, config, prompt, audio_path, builder):
+            if not update.text:
+                continue
+            current_blocks = completed + [(filename, update.text)]
+            payload = _assemble_transcript(current_blocks)
+            if payload != last_payload:
+                last_payload = payload
+                yield payload
+            if update.is_final:
+                completed.append((filename, update.text))
+                if config["use_prompt"] and update.prompt_append:
+                    prompt = prompt + f"\n{update.prompt_append}"
+                break
+        else:
+            final_text = builder.formatted_text()
+            if final_text:
+                completed.append((filename, final_text))
+                payload = _assemble_transcript(completed)
+                if payload != last_payload:
+                    yield payload
+                if config["use_prompt"]:
+                    prompt = prompt + f"\n{final_text}"
+
+
+def _stream_single_file(
+    client: Any,
+    config: dict[str, Any],
+    prompt: str,
+    audio_path: str,
+    builder: "_TranscriptBuilder",
+) -> Iterator[TranscriptionUpdate]:
+    request_kwargs = {
+        "model": config["api_model"],
+        "response_format": config["response_format"],
+    }
+    if config["use_prompt"]:
+        request_kwargs["prompt"] = prompt
+    if config["chunking_strategy"]:
+        request_kwargs["chunking_strategy"] = config["chunking_strategy"]
+
+    with open(audio_path, "rb") as fh:
+        request_kwargs["file"] = fh
+        response = client.audio.transcriptions.create(
+            stream=config["supports_stream"], **request_kwargs
+        )
+
+        if config["supports_stream"]:
+            yield from builder.consume_iter(response)
+        else:
+            yield builder.consume_snapshot(response)
+
+
+def _assemble_transcript(blocks: Sequence[tuple[str, str]]) -> str:
+    parts = []
+    for filename, text in blocks:
+        body = text.rstrip()
+        parts.append(f"``` transcript {filename}\n{body}\n```")
+    return "\n".join(parts)
+
+
+def _build_prompt(history: Iterable[Any], system_prompt: str) -> str:
+    prompt = system_prompt or ""
+    for msg in history or []:
+        role = _field(msg, "role")
+        if role not in ("user", "assistant"):
+            continue
+        content = _field(msg, "content")
+        if isinstance(content, tuple) or content is None:
+            continue
+        prompt += f"\n{content}"
+    return prompt
+
+
+def _format_transcription_text(
+    model_key: str, text: str | None, segments: Sequence[dict[str, Any]] | None
+) -> str:
+    if model_key == "whisper":
+        return text or ""
+
+    if model_key == "gpt-4o-transcribe-diarize":
+        if segments:
+            turns: list[str] = []
+            prev_speaker: str | None = None
+            for seg in segments:
+                speaker = (seg.get("speaker") or "Speaker").strip()
+                seg_text = (seg.get("text") or "").strip()
+                if seg_text:
+                    if speaker != prev_speaker or not turns:
+                        turns.append(f"{speaker}: {seg_text}")
+                    else:
+                        turns[-1] = f"{turns[-1]} {seg_text}".strip()
+                prev_speaker = speaker
+            if turns:
+                return "\n\n".join(turns)
+        return text or ""
+
+    raise ValueError(f"Unhandled transcription model formatting: {model_key}")
+
+
+class _TranscriptBuilder:
+    def __init__(self, model_key: str) -> None:
+        self.model_key = model_key
+        self._text_chunks: list[str] = []
+        self._segments: list[dict[str, Any]] = []
+        self._segment_ids: set[Any] = set()
+        self._final_text: str | None = None
+        self._final_segments: list[dict[str, Any]] | None = None
+        self._finalized = False
+
+    def consume_iter(self, stream: Any) -> Iterator[TranscriptionUpdate]:
+        if not hasattr(stream, "__iter__"):
+            yield self.consume_snapshot(stream)
+            return
+
+        for event in stream:
+            update = self._ingest(event, assume_final=False)
+            if update is None:
+                continue
+            yield update
+
+        if not self._finalized:
+            yield self._finalize_update()
+
+    def consume_snapshot(self, snapshot: Any) -> TranscriptionUpdate:
+        update = self._ingest(snapshot, assume_final=True)
+        if update is not None:
+            return update
+        return self._finalize_update()
+
+    def formatted_text(self) -> str:
+        text = self._final_text
+        if text is None and self._text_chunks:
+            text = "".join(self._text_chunks)
+        segments = self._final_segments or self._segments
+        return _format_transcription_text(self.model_key, text, segments)
+
+    def _ingest(self, obj: Any, assume_final: bool) -> TranscriptionUpdate | None:
+        data = _to_dict(obj)
+        changed, is_final = self._apply_event(data, assume_final=assume_final)
+        if not changed:
+            return None
+        formatted = self.formatted_text()
+        append = formatted if (is_final or assume_final) and formatted else None
+        return TranscriptionUpdate(text=formatted, is_final=is_final or assume_final, prompt_append=append)
+
+    def _apply_event(self, data: dict[str, Any], assume_final: bool) -> tuple[bool, bool]:
+        event_type = data.get("type")
+        changed = False
+        is_final = False
+
+        if event_type == "transcript.text.delta":
+            delta = data.get("delta")
+            if isinstance(delta, str) and delta:
+                self._text_chunks.append(delta)
+                changed = True
+
+        if event_type == "transcript.text.segment":
+            segment_payload = data.get("segment") or data
+            segment = _normalize_segment(segment_payload)
+            seg_id = segment.get("id")
+            if seg_id is None or seg_id not in self._segment_ids:
+                if seg_id is not None:
+                    self._segment_ids.add(seg_id)
+                self._segments.append(segment)
+                changed = True
+
+        if event_type == "transcript.text.done":
+            self._capture_final(data)
+            changed = True
+            is_final = True
+
+        if not changed:
+            text_value = data.get("text")
+            segments_value = data.get("segments")
+            if isinstance(text_value, str) and text_value:
+                self._final_text = text_value
+                changed = True
+            if isinstance(segments_value, list) and segments_value:
+                self._final_segments = [_normalize_segment(seg) for seg in segments_value]
+                changed = True
+            if changed:
+                is_final = True
+
+        if assume_final:
+            is_final = True
+        if is_final:
+            self._finalized = True
+
+        return changed, is_final
+
+    def _capture_final(self, data: dict[str, Any]) -> None:
+        text_value = data.get("text")
+        if isinstance(text_value, str) and text_value:
+            self._final_text = text_value
+        segments_value = data.get("segments")
+        if isinstance(segments_value, list) and segments_value:
+            self._final_segments = [_normalize_segment(seg) for seg in segments_value]
+
+    def _finalize_update(self) -> TranscriptionUpdate:
+        formatted = self.formatted_text()
+        self._finalized = True
+        append = formatted if formatted else None
+        return TranscriptionUpdate(text=formatted, is_final=True, prompt_append=append)
+
+
+def _normalize_segment(segment: Any) -> dict[str, Any]:
+    data = _to_dict(segment)
+    speaker = data.get("speaker")
+    if isinstance(speaker, str):
+        data["speaker"] = speaker.strip()
+    text = data.get("text")
+    if isinstance(text, str):
+        data["text"] = text.strip()
+    return data
+
+
+def _message_fields(message: Any) -> tuple[str | None, Any]:
+    return _field(message, "text"), _field(message, "files")
+
+
+def _field(obj: Any, key: str) -> Any:
+    if isinstance(obj, dict):
+        return obj.get(key)
+    return getattr(obj, key, None)
+
+
+def _to_dict(obj: Any) -> dict[str, Any]:
+    if isinstance(obj, dict):
+        return obj
+    if isinstance(obj, str):
+        return {"text": obj}
+    if hasattr(obj, "model_dump"):
+        try:
+            return obj.model_dump()
+        except Exception:
+            pass
+    if hasattr(obj, "to_dict"):
+        try:
+            return obj.to_dict()
+        except Exception:
+            pass
+    if hasattr(obj, "__dict__"):
+        return {k: v for k, v in obj.__dict__.items() if not k.startswith("_")}
+    return {}
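
Note: the diarized path can be exercised the same way. A hedged sketch follows: the stub returns an iterator of event dicts shaped like the transcript.text.segment / transcript.text.done events that _TranscriptBuilder._apply_event() handles above; the speaker labels and payload fields are illustrative assumptions, not captured gpt-4o-transcribe-diarize output.

# Sketch only: synthetic streaming events standing in for the real API.
import pathlib
from transcription import stream_transcriptions

class _StubTranscriptions:
    def create(self, stream=False, **kwargs):
        # Streaming diarized path: segment events, then a final "done".
        return iter([
            {"type": "transcript.text.segment",
             "segment": {"id": "s1", "speaker": "A", "text": "Hi there."}},
            {"type": "transcript.text.segment",
             "segment": {"id": "s2", "speaker": "B", "text": "Hello!"}},
            {"type": "transcript.text.done", "text": "Hi there. Hello!"},
        ])

class _StubAudio:
    transcriptions = _StubTranscriptions()

class StubClient:
    audio = _StubAudio()

pathlib.Path("call.wav").write_bytes(b"\x00")  # dummy audio file
message = {"text": "", "files": [{"path": "call.wav"}]}
for payload in stream_transcriptions(StubClient(), "gpt-4o-transcribe-diarize", message, [], ""):
    print(payload)
    print("---")

Each payload is a "``` transcript call.wav ... ```" block whose body grows speaker turn by speaker turn ("A: Hi there.", then "A: Hi there." followed by "B: Hello!"), which is what the bot() branch streams into the chat message.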