# mimo_audio_chat / api_schema.py

from abc import ABC
from io import BytesIO
from typing import Literal

import numpy as np
from pydantic import BaseModel, ConfigDict


class AbortController(ABC):
    def is_alive(self) -> bool:
        raise NotImplementedError


class NeverAbortedController(AbortController):
    def is_alive(self) -> bool:
        return True


def is_none_or_alive(abort_controller: AbortController | None) -> bool:
    return abort_controller is None or abort_controller.is_alive()
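

# Illustrative sketch (not part of the original schema): one way a caller
# might implement `AbortController` so that a long-running generation loop
# can be cancelled cooperatively. The event-backed controller and the demo
# loop below are assumptions made for illustration only.
import threading


class _EventAbortController(AbortController):
    """Hypothetical controller that stays alive until `abort()` is called."""

    def __init__(self) -> None:
        self._cancelled = threading.Event()

    def abort(self) -> None:
        self._cancelled.set()

    def is_alive(self) -> bool:
        return not self._cancelled.is_set()


def _example_abortable_loop(abort_controller: AbortController | None = None) -> int:
    """Sketch: poll `is_none_or_alive` before producing each chunk."""
    steps = 0
    while is_none_or_alive(abort_controller) and steps < 10:
        steps += 1  # stand-in for emitting one token chunk
    return steps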


class ModelNameResponse(BaseModel):
    model_name: str


class TokenizedMessage(BaseModel):
    role: Literal["user", "assistant"]
    content: list[list[int]]
    """[audio_channels+1, time_steps]"""

    def time_steps(self) -> int:
        return len(self.content[0])

    def append(self, chunk: list[list[int]]):
        assert len(chunk) == len(self.content), "Incompatible channel count"
        assert all(len(c) == len(chunk[0]) for c in chunk), "Incompatible chunk shape"
        for content_channel, chunk_channel in zip(self.content, chunk):
            content_channel.extend(chunk_channel)


class TokenizedConversation(BaseModel):
    messages: list[TokenizedMessage]

    def time_steps(self) -> int:
        return sum(msg.time_steps() for msg in self.messages)

    def latest_messages(self, max_time_steps: int) -> "list[TokenizedMessage]":
        sum_time_steps = 0
        selected_messages: list[TokenizedMessage] = []
        for msg in reversed(self.messages):
            cur_time_steps = msg.time_steps()
            if sum_time_steps + cur_time_steps > max_time_steps:
                break
            sum_time_steps += cur_time_steps
            selected_messages.append(msg)
        return list(reversed(selected_messages))
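

# Illustrative sketch (not part of the original schema): how the tokenized
# containers fit together. The channel layout (1 text row + 1 audio row)
# and the token ids below are made-up values for demonstration only.
def _example_tokenized_conversation() -> list[TokenizedMessage]:
    user = TokenizedMessage(role="user", content=[[101, 102, 103], [7, 8, 9]])
    assistant = TokenizedMessage(role="assistant", content=[[201, 202], [4, 5]])
    # Streaming continuation: extend each channel with a [channels, new_steps] chunk.
    assistant.append([[203], [6]])

    conversation = TokenizedConversation(messages=[user, assistant])
    assert conversation.time_steps() == 6  # 3 user steps + 3 assistant steps
    # With a 3-step budget only the newest message fits, so the user turn is dropped.
    return conversation.latest_messages(max_time_steps=3)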


class ChatAudioBytes(BaseModel):
    model_config = ConfigDict(ser_json_bytes="base64", val_json_bytes="base64")

    sample_rate: int
    audio_data: bytes
    """
    shape = (channels, samples) or (samples,);
    dtype = int16 or float32
    """

    @classmethod
    def from_audio(cls, audio: tuple[int, np.ndarray]) -> "ChatAudioBytes":
        buf = BytesIO()
        np.save(buf, audio[1])
        return ChatAudioBytes(sample_rate=audio[0], audio_data=buf.getvalue())

    def to_audio(self) -> tuple[int, np.ndarray]:
        buf = BytesIO(self.audio_data)
        audio_np = np.load(buf)
        return self.sample_rate, audio_np
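

# Illustrative sketch (not part of the original schema): round-tripping a
# waveform through `ChatAudioBytes`, including a JSON round trip where the
# raw `.npy` bytes are carried as base64 (see `model_config` above). The
# sample rate and waveform are made-up values.
def _example_audio_roundtrip() -> None:
    waveform = np.zeros(16_000, dtype=np.float32)  # one second of silence at 16 kHz
    packed = ChatAudioBytes.from_audio((16_000, waveform))

    restored = ChatAudioBytes.model_validate_json(packed.model_dump_json())

    sample_rate, decoded = restored.to_audio()
    assert sample_rate == 16_000
    assert decoded.dtype == np.float32 and decoded.shape == (16_000,)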


class ChatResponseItem(BaseModel):
    tokenized_input: TokenizedMessage | None = None
    token_chunk: list[list[int]] | None = None
    """[audio_channels+1, time_steps]"""
    text_chunk: str | None = None
    audio_chunk: ChatAudioBytes | None = None
    end_of_stream: bool | None = None
    """Represents the special token <|eostm|>"""
    end_of_transcription: bool | None = None
    """Represents the special token <|eot|> (not <|endoftext|>)"""
    stop_reason: str | None = None
    """Why generation stopped, e.g., max_new_tokens, max_length, stop_token, aborted"""


class AssistantStyle(BaseModel):
    preset_character: str | None = None
    custom_character_prompt: str | None = None
    preset_voice: str | None = None
    custom_voice: ChatAudioBytes | None = None


class SamplerConfig(BaseModel):
    """
    Sampling configuration for text/audio generation.

    - If some fields are not set, their effects are disabled.
    - If the entire config is not set (e.g., `global_sampler_config=None`), all fields are determined automatically.
    - Use `temperature=0.0` / `top_k=1` / `top_p=0.0` instead of `do_sample=False` to disable sampling.
    """

    temperature: float | None = None
    top_k: int | None = None
    top_p: float | None = None

    def normalized(self) -> tuple[float, int, float]:
        """
        Returns:
            A tuple (temperature, top_k, top_p) with normalized values.
        """
        if (
            (self.temperature is not None and self.temperature <= 0.0)
            or (self.top_k is not None and self.top_k <= 1)
            or (self.top_p is not None and self.top_p <= 0.0)
        ):
            return (1.0, 1, 1.0)

        def default_clip[T: int | float](
            value: T | None, default_value: T, min_value: T, max_value: T
        ) -> T:
            if value is None:
                return default_value
            return max(min(value, max_value), min_value)

        temperature = default_clip(self.temperature, 1.0, 0.01, 2.0)
        top_k = default_clip(self.top_k, 1_000_000, 1, 1_000_000)
        top_p = default_clip(self.top_p, 1.0, 0.01, 1.0)
        return (temperature, top_k, top_p)
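

# Illustrative sketch (not part of the original schema): what `normalized()`
# returns for a few configurations, following the clipping rules above.
def _example_sampler_normalization() -> None:
    # Unset fields fall back to permissive defaults (sampling effectively unrestricted).
    assert SamplerConfig().normalized() == (1.0, 1_000_000, 1.0)
    # Out-of-range values are clipped into [0.01, 2.0], [1, 1_000_000], [0.01, 1.0].
    assert SamplerConfig(temperature=5.0, top_p=0.005).normalized() == (2.0, 1_000_000, 0.01)
    # Any "disable sampling" setting collapses the whole config to greedy decoding.
    assert SamplerConfig(temperature=0.0).normalized() == (1.0, 1, 1.0)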


class ChatRequestBody(BaseModel):
    conversation: TokenizedConversation | None = None
    input_text: str | None = None
    input_audio: ChatAudioBytes | None = None
    assistant_style: AssistantStyle | None = None
    global_sampler_config: SamplerConfig | None = None
    local_sampler_config: SamplerConfig | None = None


class PresetOptions(BaseModel):
    options: list[str]
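

# Illustrative sketch (not part of the original schema): assembling a request
# body for a chat endpoint. The character name and sampling values are made
# up; only the field names come from the models above.
def _example_chat_request() -> str:
    body = ChatRequestBody(
        input_text="Hello!",
        assistant_style=AssistantStyle(preset_character="default"),
        global_sampler_config=SamplerConfig(temperature=0.7, top_p=0.95),
    )
    # Unset optional fields serialize as null.
    return body.model_dump_json()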