# ai-audio-books/src/schemas.py
from __future__ import annotations
import base64
import typing as t
from enum import StrEnum
import pandas as pd
from elevenlabs import VoiceSettings
from pydantic import BaseModel, ConfigDict, Field
from src import utils


class AudioOutputFormat(StrEnum):
    MP3_22050_32 = "mp3_22050_32"
    MP3_44100_32 = "mp3_44100_32"
    MP3_44100_64 = "mp3_44100_64"
    MP3_44100_96 = "mp3_44100_96"
    MP3_44100_128 = "mp3_44100_128"
    MP3_44100_192 = "mp3_44100_192"
    PCM_16000 = "pcm_16000"
    PCM_22050 = "pcm_22050"
    PCM_24000 = "pcm_24000"
    PCM_44100 = "pcm_44100"
    ULAW_8000 = "ulaw_8000"
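    # The values appear to follow the 11labs naming convention
    # <codec>_<sample_rate_hz>[_<bitrate_kbps>]: e.g. "mp3_44100_192" is MP3
    # at 44.1 kHz / 192 kbps, and "pcm_16000" is raw 16-bit PCM at 16 kHz.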


class ExtraForbidModel(BaseModel):
    model_config = ConfigDict(extra="forbid")


# Use Ellipsis to mark an omitted function parameter,
# cast to Any to avoid warnings from type checkers.
# The exact same approach is used in the elevenlabs client.
OMIT = t.cast(t.Any, ...)
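# E.g. a field declared as Field(default_factory=lambda: OMIT) keeps the
# Ellipsis sentinel until explicitly set, which lets us distinguish
# "not provided" (dropped from the request payload, see TTSParams.to_dict)
# from an explicit None (sent as null).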


class TTSParams(ExtraForbidModel):
    # NOTE: pydantic treats Ellipsis as a marker of a required field.
    # To set Ellipsis as an actual field default, we need a workaround:
    # setting it via Field's default_factory.
    voice_id: str
    text: str
    # enable_logging: typing.Optional[bool] = None
    # NOTE: we opt for quality over speed - thus don't use this param
    # optimize_streaming_latency: typing.Optional[OptimizeStreamingLatency] = None
    # NOTE: here we set a default different from the 11labs API default
    # output_format: AudioOutputFormat = AudioOutputFormat.MP3_44100_128
    output_format: AudioOutputFormat = AudioOutputFormat.MP3_44100_192
    # NOTE: pydantic protects the "model_" namespace.
    # Here we use a workaround to pass the "model_id" param to the 11labs client
    # via serialization_alias.
    audio_model_id: t.Optional[str] = Field(
        default_factory=lambda: OMIT, serialization_alias="model_id"
    )
    language_code: t.Optional[str] = Field(default_factory=lambda: OMIT)
    # reference: https://elevenlabs.io/docs/speech-synthesis/voice-settings
    voice_settings: t.Optional[VoiceSettings] = Field(default_factory=lambda: OMIT)
    # pronunciation_dictionary_locators: t.Optional[
    #     t.Sequence[PronunciationDictionaryVersionLocator]
    # ] = Field(default_factory=lambda: OMIT)
    seed: t.Optional[int] = Field(default_factory=lambda: OMIT)
    previous_text: t.Optional[str] = Field(default_factory=lambda: OMIT)
    next_text: t.Optional[str] = Field(default_factory=lambda: OMIT)
    previous_request_ids: t.Optional[t.Sequence[str]] = Field(default_factory=lambda: OMIT)
    next_request_ids: t.Optional[t.Sequence[str]] = Field(default_factory=lambda: OMIT)
    # request_options: t.Optional[RequestOptions] = None

    def to_dict(self):
        """
        Dump the pydantic model in the format expected by the 11labs API.

        NOTE: we need `by_alias=True` to correctly handle the alias
        for the `audio_model_id` field, since "model_id" belongs to
        pydantic's protected namespace.

        NOTE: we also exclude all fields still set to the Ellipsis default,
        so that 11labs applies its own defaults
        and we don't get any warnings in the logs.
        """
        ellipsis_fields = {field for field, value in self if value is ...}
        res = self.model_dump(by_alias=True, exclude=ellipsis_fields)
        return res
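
    # A sketch of the resulting payload (hypothetical values):
    # all OMIT-valued optional fields are dropped, so only the fields
    # with real values survive, roughly:
    #
    #     params = TTSParams(voice_id="voice-123", text="Hello")
    #     params.to_dict()
    #     # -> {"voice_id": "voice-123", "text": "Hello",
    #     #     "output_format": "mp3_44100_192"}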


class TTSTimestampsAlignment(ExtraForbidModel):
    characters: list[str]
    character_start_times_seconds: list[float]
    character_end_times_seconds: list[float]
    _text_joined: str

    def __init__(self, **data):
        super().__init__(**data)
        self._text_joined = "".join(self.characters)

    @property
    def text_joined(self):
        return self._text_joined

    def to_dataframe(self):
        return pd.DataFrame(
            {
                "char": self.characters,
                "start": self.character_start_times_seconds,
                "end": self.character_end_times_seconds,
            }
        )

    @classmethod
    def combine_alignments(
        cls,
        alignments: list[TTSTimestampsAlignment],
        add_placeholders: bool = False,
        pause_bw_chunks_s: float = 0.2,
    ) -> TTSTimestampsAlignment:
        """
        Combine alignments created for different TTS phrases
        into a single alignment for the whole text.

        NOTE: when splitting the original text into phrases,
        we ignore the separators between phrases.
        They may differ: single or multiple spaces, newlines, etc.
        To account for them, we insert a fixed pause and a placeholder character
        between phrases in the final alignment.
        This gives us an approximation of the real timestamp mapping
        for voicing the whole original text.

        NOTE: the quality of this approximation seems appropriate,
        considering the amount of time required to implement a more accurate mapping.
        """
        chars = []
        starts = []
        ends = []
        prev_chunk_end_time = 0.0
        n_alignments = len(alignments)
        for ix, a in enumerate(alignments):
            cur_starts_absolute = [prev_chunk_end_time + s for s in a.character_start_times_seconds]
            cur_ends_absolute = [prev_chunk_end_time + e for e in a.character_end_times_seconds]
            chars.extend(a.characters)
            starts.extend(cur_starts_absolute)
            ends.extend(cur_ends_absolute)
            if ix < n_alignments - 1 and add_placeholders:
                chars.append('#')
                placeholder_start = cur_ends_absolute[-1]
                starts.append(placeholder_start)
                ends.append(placeholder_start + pause_bw_chunks_s)
            prev_chunk_end_time = ends[-1]
        return cls(
            characters=chars,
            character_start_times_seconds=starts,
            character_end_times_seconds=ends,
        )
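
    # A worked example with hypothetical timings: combining two one-character
    # alignments with add_placeholders=True inserts a '#' char plus a 0.2 s
    # pause between the chunks and shifts the second chunk's timestamps:
    #
    #     a = TTSTimestampsAlignment(
    #         characters=["a"],
    #         character_start_times_seconds=[0.0],
    #         character_end_times_seconds=[0.5],
    #     )
    #     b = TTSTimestampsAlignment(
    #         characters=["b"],
    #         character_start_times_seconds=[0.0],
    #         character_end_times_seconds=[0.4],
    #     )
    #     c = TTSTimestampsAlignment.combine_alignments([a, b], add_placeholders=True)
    #     c.characters                     # ['a', '#', 'b']
    #     c.character_start_times_seconds  # [0.0, 0.5, 0.7]
    #     c.character_end_times_seconds    # [0.5, 0.7, 1.1]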

    def filter_chars_without_duration(self):
        """
        Create a new instance with zero-duration characters removed.
        Needed to provide a correct alignment when overlaying sound effects.
        """
        df = self.to_dataframe()
        mask = (df['start'] - df['end']).abs() > 1e-5
        df = df[mask]
        res = TTSTimestampsAlignment(
            characters=df['char'].to_list(),
            character_start_times_seconds=df['start'].to_list(),
            character_end_times_seconds=df['end'].to_list(),
        )
        return res
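
    # Presumably zero-duration entries come from characters the TTS engine
    # doesn't voice (an observation, not a guarantee). E.g. a "\n" reported
    # with start == end == 1.2 would skew effect-overlay offsets, so such
    # entries are dropped here.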

    def get_start_time_by_char_ix(self, char_ix: int, safe=True):
        if safe:
            char_ix = utils.get_collection_safe_index(
                ix=char_ix,
                collection=self.character_start_times_seconds,
            )
        return self.character_start_times_seconds[char_ix]

    def get_end_time_by_char_ix(self, char_ix: int, safe=True):
        if safe:
            char_ix = utils.get_collection_safe_index(
                ix=char_ix,
                collection=self.character_end_times_seconds,
            )
        return self.character_end_times_seconds[char_ix]
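
    # With safe=True the index is first normalized by
    # utils.get_collection_safe_index (assumption based on the name and usage:
    # it clamps out-of-range values into [0, len - 1]; see src/utils.py),
    # so callers may pass e.g. an index one past the last character
    # without raising IndexError.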


class TTSTimestampsResponse(ExtraForbidModel):
    audio_base64: str
    alignment: TTSTimestampsAlignment
    normalized_alignment: TTSTimestampsAlignment

    @property
    def audio_bytes(self):
        return base64.b64decode(self.audio_base64)

    def write_audio_to_file(self, filepath_no_ext: str, audio_format: AudioOutputFormat) -> str:
        if audio_format.startswith("pcm_"):
            sr = int(audio_format.removeprefix("pcm_"))
            fp = f"{filepath_no_ext}.wav"
            utils.write_raw_pcm_to_file(
                data=self.audio_bytes,
                fp=fp,
                n_channels=1,  # seems like it's always 1 channel
                bytes_depth=2,  # seems like it's always 2 bytes
                sampling_rate=sr,
            )
            return fp
        elif audio_format.startswith("mp3_"):
            fp = f"{filepath_no_ext}.mp3"
            # the received mp3 seems to already contain all the required
            # metadata, like sampling rate and sample width
            utils.write_bytes(data=self.audio_bytes, fp=fp)
            return fp
        else:
            raise ValueError(f"don't know how to write audio format: {audio_format}")
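
    # A usage sketch (hypothetical path; `resp` is a parsed 11labs response):
    #
    #     resp.write_audio_to_file("out/chapter_01", AudioOutputFormat.MP3_44100_192)
    #     # -> "out/chapter_01.mp3"
    #     resp.write_audio_to_file("out/chapter_01", AudioOutputFormat.PCM_22050)
    #     # -> "out/chapter_01.wav" (raw PCM wrapped as mono 16-bit WAV at 22.05 kHz)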


class SoundEffectsParams(ExtraForbidModel):
    text: str
    duration_seconds: float | None
    prompt_influence: float | None