from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

import torch
from torch import nn
from transformers import MimiModel

DEFAULT_MIMI_MODEL_ID = "kyutai/mimi"

@dataclass
class MimiConfig:
    model_id: str = DEFAULT_MIMI_MODEL_ID
    dtype: Optional[torch.dtype] = None

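# A hypothetical convenience helper (not part of the original module) sketching
# how MimiConfig is presumably meant to be consumed: it simply forwards the
# config fields to MimiCodec.from_pretrained defined below.
def build_codec(cfg: MimiConfig, device: torch.device) -> "MimiCodec":
    return MimiCodec.from_pretrained(cfg.model_id, device=device, dtype=cfg.dtype)
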
class MimiCodec(nn.Module):
    """Thin wrapper around transformers' MimiModel for decoding audio tokens."""

    def __init__(self, model: MimiModel, device: torch.device) -> None:
        super().__init__()
        self.model = model
        self.device = device
        # Read audio geometry from the model config, falling back to Mimi's
        # defaults (24 kHz sample rate, 12.5 Hz frame rate -> 1920 samples per frame).
        cfg = getattr(model, "config", None)
        self.sample_rate = getattr(cfg, "sampling_rate", 24000)
        self.frame_rate = getattr(cfg, "frame_rate", 12.5)
        self.samples_per_frame = int(round(self.sample_rate / self.frame_rate)) if self.frame_rate else 0
    @classmethod
    def from_pretrained(
        cls,
        model_id: str = DEFAULT_MIMI_MODEL_ID,
        *,
        device: torch.device,
        dtype: Optional[torch.dtype] = None,
    ) -> "MimiCodec":
        model = MimiModel.from_pretrained(
            model_id,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
        )
        model = model.to(device)
        model.eval()
        return cls(model, device)
    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        codes = codes.to(self.device)
        with torch.inference_mode():
            audio, _ = self.model.decode(codes, return_dict=False)
        # Clamp to the valid [-1, 1] PCM range before handing audio downstream.
        return torch.clamp(audio, -1.0, 1.0)
    def encode(self, audio: torch.Tensor, *, return_dict: bool = False):
        audio = audio.to(self.device)
        with torch.inference_mode():
            return self.model.encode(audio, return_dict=return_dict)
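
# --- Usage sketch (not part of the original module) --------------------------
# A minimal round-trip assuming a CUDA or CPU device and the default checkpoint.
# The tuple layouts returned by encode()/decode() with return_dict=False, and the
# (batch, channels, samples) input shape, follow transformers' MimiModel
# conventions; double-check them against the installed transformers version.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    codec = MimiCodec.from_pretrained(DEFAULT_MIMI_MODEL_ID, device=device)

    # One second of silence at the codec's sample rate.
    audio = torch.zeros(1, 1, codec.sample_rate, device=device)
    audio_codes = codec.encode(audio)[0]  # first tuple element holds the audio codes
    waveform = codec.decode(audio_codes)
    print(waveform.shape, codec.sample_rate, codec.frame_rate, codec.samples_per_frame)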