Fedir Zadniprovskyi committed on
Commit ed63b7d · 1 Parent(s): b8804a6

fix: `timestamp_granularities[]` handling (#28, #58, #81)
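For context on what the fix enables: OpenAI-compatible clients serialize this parameter as a repeated `timestamp_granularities[]` multipart form field. A minimal client-side sketch (assuming the server listens on localhost:8000 and a local `audio.wav` exists; neither is part of this commit):

    # Client-side sketch (assumptions: server on localhost:8000, a local
    # audio.wav). The SDK sends the list as repeated
    # `timestamp_granularities[]` form fields, which this commit handles.
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="sk-placeholder")
    with open("audio.wav", "rb") as audio_file:
        transcription = client.audio.transcriptions.create(
            file=audio_file,
            model="whisper-1",
            response_format="verbose_json",  # word timestamps only appear here
            timestamp_granularities=["word", "segment"],
        )
    print(transcription.text)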
src/faster_whisper_server/routers/stt.py CHANGED
@@ -9,6 +9,7 @@ from fastapi import (
     APIRouter,
     Form,
     Query,
+    Request,
     Response,
     UploadFile,
     WebSocket,
@@ -30,6 +31,8 @@ from faster_whisper_server.config import (
 from faster_whisper_server.core import Segment, segments_to_srt, segments_to_text, segments_to_vtt
 from faster_whisper_server.dependencies import ConfigDependency, ModelManagerDependency, get_config
 from faster_whisper_server.server_models import (
+    DEFAULT_TIMESTAMP_GRANULARITIES,
+    TIMESTAMP_GRANULARITIES_COMBINATIONS,
     TimestampGranularities,
     TranscriptionJsonResponse,
     TranscriptionVerboseJsonResponse,
@@ -150,6 +153,18 @@ def translate_file(
     return segments_to_response(segments, transcription_info, response_format)
 
 
+# HACK: Since Form() doesn't support `alias`, we need to use a workaround.
+async def get_timestamp_granularities(request: Request) -> TimestampGranularities:
+    form = await request.form()
+    if form.get("timestamp_granularities[]") is None:
+        return DEFAULT_TIMESTAMP_GRANULARITIES
+    timestamp_granularities = form.getlist("timestamp_granularities[]")
+    assert (
+        timestamp_granularities in TIMESTAMP_GRANULARITIES_COMBINATIONS
+    ), f"{timestamp_granularities} is not a valid value for `timestamp_granularities[]`."
+    return timestamp_granularities
+
+
 # https://platform.openai.com/docs/api-reference/audio/createTranscription
 # https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L8915
 @router.post(
@@ -159,6 +174,7 @@ def translate_file(
 def transcribe_file(
     config: ConfigDependency,
     model_manager: ModelManagerDependency,
+    request: Request,
     file: Annotated[UploadFile, Form()],
     model: Annotated[ModelName | None, Form()] = None,
     language: Annotated[Language | None, Form()] = None,
@@ -167,6 +183,7 @@ def transcribe_file(
     temperature: Annotated[float, Form()] = 0.0,
     timestamp_granularities: Annotated[
         TimestampGranularities,
+        # WARN: `alias` doesn't actually work.
         Form(alias="timestamp_granularities[]"),
     ] = ["segment"],
     stream: Annotated[bool, Form()] = False,
@@ -178,6 +195,11 @@ def transcribe_file(
         language = config.default_language
     if response_format is None:
         response_format = config.default_response_format
+    timestamp_granularities = asyncio.run(get_timestamp_granularities(request))
+    if timestamp_granularities != DEFAULT_TIMESTAMP_GRANULARITIES and response_format != ResponseFormat.VERBOSE_JSON:
+        logger.warning(
+            "It only makes sense to provide `timestamp_granularities[]` when `response_format` is set to `verbose_json`. See https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-timestamp_granularities."  # noqa: E501
+        )
     whisper = model_manager.load_model(model)
     segments, transcription_info = whisper.transcribe(
         file.file,
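The core of the change is the `get_timestamp_granularities` helper: per the in-code comments, `Form(alias="timestamp_granularities[]")` doesn't actually take effect, so the handler reads the raw multipart form through Starlette, where repeated `key[]` fields map onto `FormData.getlist`. Because `transcribe_file` is a synchronous handler (FastAPI runs it in a thread pool), the async helper is driven with `asyncio.run`. A standalone sketch of the same pattern (hypothetical `/echo` endpoint, not part of the commit):

    # Hypothetical /echo endpoint illustrating the workaround pattern:
    # Starlette's FormData.getlist collects every occurrence of a
    # repeated `key[]` form field into a list.
    from fastapi import FastAPI, Request

    app = FastAPI()

    @app.post("/echo")
    async def echo(request: Request) -> dict[str, list[str]]:
        form = await request.form()
        # getlist returns [] when the key is absent from the form.
        values = form.getlist("timestamp_granularities[]")
        return {"timestamp_granularities": [str(v) for v in values]}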
src/faster_whisper_server/server_models.py CHANGED
@@ -29,7 +29,7 @@ class TranscriptionVerboseJsonResponse(BaseModel):
     language: str
     duration: float
     text: str
-    words: list[Word]
+    words: list[Word] | None
     segments: list[Segment]
 
     @classmethod
@@ -38,7 +38,7 @@ class TranscriptionVerboseJsonResponse(BaseModel):
             language=transcription_info.language,
             duration=segment.end - segment.start,
             text=segment.text,
-            words=(segment.words if isinstance(segment.words, list) else []),
+            words=segment.words if transcription_info.transcription_options.word_timestamps else None,
             segments=[segment],
         )
 
@@ -51,7 +51,7 @@ class TranscriptionVerboseJsonResponse(BaseModel):
             duration=transcription_info.duration,
             text=segments_to_text(segments),
             segments=segments,
-            words=Word.from_segments(segments),
+            words=Word.from_segments(segments) if transcription_info.transcription_options.word_timestamps else None,
         )
 
     @classmethod
@@ -112,6 +112,7 @@ class ModelObject(BaseModel):
 TimestampGranularities = list[Literal["segment", "word"]]
 
 
+DEFAULT_TIMESTAMP_GRANULARITIES: TimestampGranularities = ["segment"]
 TIMESTAMP_GRANULARITIES_COMBINATIONS: list[TimestampGranularities] = [
     [],  # should be treated as ["segment"]. https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-timestamp_granularities
     ["segment"],
tests/api_timestamp_granularities_test.py ADDED
@@ -0,0 +1,43 @@
+"""See `tests/openai_timestamp_granularities_test.py` to understand how OpenAI handles `response_type` and `timestamp_granularities`."""  # noqa: E501
+
+from faster_whisper_server.server_models import TIMESTAMP_GRANULARITIES_COMBINATIONS, TimestampGranularities
+from openai import AsyncOpenAI
+import pytest
+
+
+@pytest.mark.asyncio()
+@pytest.mark.parametrize("timestamp_granularities", TIMESTAMP_GRANULARITIES_COMBINATIONS)
+async def test_api_json_response_format_and_timestamp_granularities_combinations(
+    openai_client: AsyncOpenAI,
+    timestamp_granularities: TimestampGranularities,
+) -> None:
+    audio_file = open("audio.wav", "rb")  # noqa: SIM115, ASYNC230
+
+    await openai_client.audio.transcriptions.create(
+        file=audio_file, model="whisper-1", response_format="json", timestamp_granularities=timestamp_granularities
+    )
+
+
+@pytest.mark.asyncio()
+@pytest.mark.parametrize("timestamp_granularities", TIMESTAMP_GRANULARITIES_COMBINATIONS)
+async def test_api_verbose_json_response_format_and_timestamp_granularities_combinations(
+    openai_client: AsyncOpenAI,
+    timestamp_granularities: TimestampGranularities,
+) -> None:
+    audio_file = open("audio.wav", "rb")  # noqa: SIM115, ASYNC230
+
+    transcription = await openai_client.audio.transcriptions.create(
+        file=audio_file,
+        model="whisper-1",
+        response_format="verbose_json",
+        timestamp_granularities=timestamp_granularities,
+    )
+
+    assert transcription.__pydantic_extra__
+    if "word" in timestamp_granularities:
+        assert transcription.__pydantic_extra__.get("segments") is not None
+        assert transcription.__pydantic_extra__.get("words") is not None
+    else:
+        # Unless explicitly requested, words are not present
+        assert transcription.__pydantic_extra__.get("segments") is not None
+        assert transcription.__pydantic_extra__.get("words") is None
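A note on the assertions: the tests inspect `transcription.__pydantic_extra__` because, at least in the SDK versions this appears to target, the `Transcription` response model only declares `text`, so `segments` and `words` arrive as undeclared extra fields. A minimal illustration of that Pydantic behavior (an assumed simplification, not the SDK's actual model):

    # Assumed simplification of the SDK model: with extra="allow",
    # undeclared response fields land in __pydantic_extra__.
    from pydantic import BaseModel, ConfigDict

    class Transcription(BaseModel):
        model_config = ConfigDict(extra="allow")
        text: str

    t = Transcription.model_validate({"text": "hi", "words": [], "segments": []})
    print(t.__pydantic_extra__)  # {'words': [], 'segments': []}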