# moshi_general / inference.py
# Last commit by tezuesh: 5acce69 "Update inference.py" (verified)
import torch
import numpy as np
import torchaudio
import sentencepiece
import logging
from pathlib import Path
from moshi.models import loaders, LMGen
# Root logging is configured at import time so the progress messages emitted
# by InferenceRecipe are visible when this file is run as a script.
# NOTE(review): basicConfig on import also affects any application that
# imports this module; consider moving it under the __main__ guard.
_LOG_LEVEL = logging.INFO
logging.basicConfig(level=_LOG_LEVEL)

# Module-scoped logger shared by every method of InferenceRecipe.
logger = logging.getLogger(name=__name__)
class InferenceRecipe:
    """Handles model inference for the Any-to-Any model.

    Loads three components from a local model directory: the MIMI codec
    (audio <-> discrete codes), a SentencePiece text tokenizer, and the
    Moshi language model wrapped in ``LMGen`` for step-wise generation.
    """

    def __init__(self, model_path: str, device: str = 'cuda'):
        """Initialize the model.

        Args:
            model_path (str): Path to model directory with pre-downloaded files
                (must contain ``loaders.MIMI_NAME``, ``loaders.TEXT_TOKENIZER_NAME``
                and ``loaders.MOSHI_NAME``).
            device (str): Device to run on ('cuda' or 'cpu')

        Raises:
            RuntimeError: If any required model file is missing.
        """
        self.device = torch.device(device)
        self.model_path = Path(model_path)
        # Set sample rate and frame rate
        self.sample_rate = 24000  # Based on model config in loaders.py
        self.frame_rate = 12.5  # Based on model config in loaders.py
        # Initialize all model components
        logger.info(f"Initializing models from {model_path}")
        self.mimi, self.text_tokenizer, self.lm_gen = self._initialize_models()
        self.mimi = self.mimi.to(self.device)
        self.lm_gen = self.lm_gen.to(self.device)
        logger.info("Model initialization complete")

    @property
    def _frame_size(self) -> int:
        """Number of audio samples per codec frame (1920 at 24 kHz / 12.5 Hz)."""
        return int(self.sample_rate / self.frame_rate)

    def _initialize_models(self):
        """Load the MIMI codec, the text tokenizer and the LM generator.

        Returns:
            tuple: ``(mimi, text_tokenizer, lm_gen)``.

        Raises:
            RuntimeError: If a required file is missing under ``self.model_path``.
        """
        # Use the module logger (not print) so this message follows the same
        # level/handler configuration as every other message in the class.
        logger.debug("Initializing models...")
        try:
            # Load MIMI model for encoding/decoding
            mimi_path = self.model_path / loaders.MIMI_NAME
            if not mimi_path.exists():
                raise RuntimeError(f"MIMI model not found at {mimi_path}")
            logger.info(f"Loading MIMI model from {mimi_path}")
            mimi = loaders.get_mimi(str(mimi_path), device=self.device)
            # _generate asserts exactly 8 codebooks per frame.
            mimi.set_num_codebooks(8)
            # Load text tokenizer
            tokenizer_path = self.model_path / loaders.TEXT_TOKENIZER_NAME
            if not tokenizer_path.exists():
                raise RuntimeError(f"Text tokenizer not found at {tokenizer_path}")
            logger.info(f"Loading text tokenizer from {tokenizer_path}")
            text_tokenizer = sentencepiece.SentencePieceProcessor(str(tokenizer_path))
            # Load language model
            moshi_path = self.model_path / loaders.MOSHI_NAME
            if not moshi_path.exists():
                raise RuntimeError(f"Language model not found at {moshi_path}")
            logger.info(f"Loading language model from {moshi_path}")
            moshi = loaders.get_moshi_lm(str(moshi_path), device=self.device)
            # Temperature settings passed through to LMGen.
            lm_gen = LMGen(moshi, temp=0.8, temp_text=0.7)
            return mimi, text_tokenizer, lm_gen
        except Exception as e:
            logger.error(f"Model initialization failed: {str(e)}")
            raise

    def _load_audio(self, audio_array: np.ndarray, sample_rate: int) -> torch.Tensor:
        """Load and preprocess audio.

        Args:
            audio_array (np.ndarray): Waveform, either 1-D ``(samples,)`` or
                2-D ``(channels, samples)``.
            sample_rate (int): Sample rate of ``audio_array`` in Hz.

        Returns:
            torch.Tensor: float32 tensor of shape ``(1, channels, samples)`` on
            ``self.device``, resampled to ``self.sample_rate`` and trimmed to a
            whole number of codec frames.

        Raises:
            ValueError: If ``audio_array`` is neither 1-D nor 2-D.
        """
        try:
            # Convert to tensor. ascontiguousarray guards against negative-stride
            # views (e.g. a reversed array), which torch.from_numpy rejects.
            wav = torch.from_numpy(np.ascontiguousarray(audio_array)).float()
            # MIMI consumes (batch, channels, time) tensors. A mono 1-D waveform
            # needs BOTH a batch and a channel axis; adding only one left a 2-D
            # tensor that the 3-axis frame-alignment slice below rejected.
            if wav.dim() == 1:
                wav = wav[None, None]
            elif wav.dim() == 2:
                wav = wav.unsqueeze(0)
            else:
                raise ValueError(
                    f"Expected a 1-D or 2-D audio array, got shape {tuple(audio_array.shape)}"
                )
            # Resample if needed
            if sample_rate != self.sample_rate:
                logger.info(f"Resampling from {sample_rate} to {self.sample_rate}")
                # Create resampler on same device as input will be
                resampler = torchaudio.transforms.Resample(
                    orig_freq=sample_rate,
                    new_freq=self.sample_rate
                ).to(self.device)
                # Move wav to device before resampling
                wav = resampler(wav.to(self.device))
            else:
                # If no resampling needed, still ensure wav is on correct device
                wav = wav.to(self.device)
            # Ensure frame alignment: drop the trailing partial frame, if any.
            frame_size = self._frame_size
            orig_length = wav.shape[-1]
            wav = wav[:, :, :(orig_length // frame_size) * frame_size]
            if wav.shape[-1] != orig_length:
                logger.info(f"Trimmed audio from {orig_length} to {wav.shape[-1]} samples for frame alignment")
            return wav
        except Exception as e:
            logger.error(f"Audio loading failed: {str(e)}")
            raise

    def _pad_codes(self, all_codes, time_seconds=30):
        """Append encoded silence until the code sequence reaches a minimum length.

        Args:
            all_codes (list): Per-frame code tensors from ``_encode_audio``.
                The list is extended IN PLACE.
            time_seconds (float): Minimum duration to reach, in seconds.

        Returns:
            list: The same ``all_codes`` list, now at least
            ``int(time_seconds * self.frame_rate)`` frames long.
        """
        try:
            min_frames = int(time_seconds * self.frame_rate)
            frame_size = self._frame_size
            if len(all_codes) < min_frames:
                frames_to_add = min_frames - len(all_codes)
                logger.info(f"Padding {frames_to_add} frames to reach minimum length")
                with torch.no_grad(), self.mimi.streaming(batch_size=1):
                    # Create tensor on the correct device
                    chunk = torch.zeros(1, 1, frame_size, dtype=torch.float32, device=self.device)
                    for _ in range(frames_to_add):
                        additional_code = self.mimi.encode(chunk)
                        all_codes.append(additional_code)
            return all_codes
        except Exception as e:
            logger.error(f"Code padding failed: {str(e)}")
            raise

    def _encode_audio(self, wav: torch.Tensor):
        """Convert audio to codes, one codec frame at a time.

        Args:
            wav (torch.Tensor): ``(1, channels, samples)`` tensor whose length
                is a multiple of the frame size (as produced by ``_load_audio``).

        Returns:
            list: One code tensor per frame, each with a trailing time axis of 1.
        """
        try:
            frame_size = self._frame_size
            all_codes = []
            with torch.no_grad(), self.mimi.streaming(batch_size=1):
                for offset in range(0, wav.shape[-1], frame_size):
                    frame = wav[:, :, offset: offset + frame_size]
                    codes = self.mimi.encode(frame.to(self.device))
                    assert codes.shape[-1] == 1, f"Expected code shape (*, *, 1), got {codes.shape}"
                    all_codes.append(codes)
            logger.info(f"Encoded {len(all_codes)} frames")
            return all_codes
        except Exception as e:
            logger.error(f"Audio encoding failed: {str(e)}")
            raise

    def _warmup(self):
        """Run one silent frame through encode -> LM step -> decode."""
        try:
            frame_size = self._frame_size
            # Create tensor on the correct device from the start
            chunk = torch.zeros(1, 1, frame_size, dtype=torch.float32, device=self.device)
            with torch.no_grad(), self.lm_gen.streaming(1), self.mimi.streaming(1):
                codes = self.mimi.encode(chunk)  # chunk already on correct device
                tokens = self.lm_gen.step(codes[:, :, 0:1])
                if tokens is not None:
                    _ = self.mimi.decode(tokens[:, 1:])
            if self.device.type == 'cuda':
                torch.cuda.synchronize()
            logger.info("Warmup pass completed")
        except Exception as e:
            logger.error(f"Warmup failed: {str(e)}")
            raise

    def _generate(self, all_codes):
        """Generate audio and text from codes.

        Each frame of codes is fed to ``lm_gen.step``. When a step yields
        tokens, index 0 along dim 1 is the text token and the rest are decoded
        to audio by MIMI.

        Args:
            all_codes (list): Per-frame code tensors of shape ``(1, 8, 1)``.

        Returns:
            tuple: ``(wav, text)``. ``wav`` holds the decoded chunks concatenated
            along the last axis (empty if no step produced tokens). ``text`` is
            the detokenized text output.
        """
        try:
            out_wav_chunks = []
            text_output = []
            with torch.no_grad(), self.lm_gen.streaming(1), self.mimi.streaming(1):
                for i, code in enumerate(all_codes):
                    assert code.shape == (1, 8, 1), f"Expected code shape (1, 8, 1), got {code.shape}"
                    tokens_out = self.lm_gen.step(code.to(self.device))
                    if tokens_out is not None:
                        # Generate audio
                        wav_chunk = self.mimi.decode(tokens_out[:, 1:])
                        out_wav_chunks.append(wav_chunk)
                        # Generate text if available
                        text_token = tokens_out[0, 0, 0].item()
                        # NOTE(review): ids 0 and 3 are presumably special
                        # (padding/control) tokens — confirm against the tokenizer.
                        if text_token not in (0, 3):
                            _text = self.text_tokenizer.id_to_piece(text_token)
                            _text = _text.replace("▁", " ")
                            text_output.append(_text)
                    if (i + 1) % 100 == 0:
                        logger.info(f"Processed {i + 1}/{len(all_codes)} frames")
            if out_wav_chunks:
                wav = torch.cat(out_wav_chunks, dim=-1)
            else:
                # torch.cat raises on an empty list; return an empty waveform
                # instead so callers still receive a tensor.
                logger.warning("No audio frames were generated")
                wav = torch.zeros(1, 1, 0, device=self.device)
            text = ''.join(text_output)
            logger.info(f"Generated {wav.shape[-1]} samples of audio and {len(text)} characters of text")
            return wav, text
        except Exception as e:
            logger.error(f"Generation failed: {str(e)}")
            raise

    def inference(self, audio_array: np.ndarray, sample_rate: int) -> dict:
        """Run inference on input audio.

        Args:
            audio_array (np.ndarray): Input audio, 1-D ``(samples,)`` or
                2-D ``(channels, samples)``.
            sample_rate (int): Sample rate of input audio

        Returns:
            dict: ``{"audio": np.ndarray, "text": str}``. The audio is the
            generated waveform, squeezed. The text is the generated text.
        """
        try:
            # shape[-1] is the sample count for both 1-D and 2-D inputs
            # (len() would report the channel count for 2-D input).
            logger.info(f"Starting inference on {audio_array.shape[-1]} samples at {sample_rate} Hz, self device: {self.device}")
            # Load and preprocess audio
            wav = self._load_audio(audio_array, sample_rate)
            wav = wav.to(self.device)
            # Convert to codes
            all_codes = self._encode_audio(wav)
            all_codes = self._pad_codes(all_codes)
            # Warmup pass (runs on every call, not just the first)
            self._warmup()
            # Generate output
            out_wav, text = self._generate(all_codes)
            # Convert output to numpy
            output = out_wav.cpu().numpy().squeeze()
            logger.info("Inference completed successfully")
            return {
                "audio": output,
                "text": text
            }
        except Exception as e:
            logger.error(f"Inference failed: {str(e)}")
            raise
if __name__ == "__main__":
    # Minimal end-to-end demo: read a WAV file, run the model, report results.
    import librosa

    recipe = InferenceRecipe("/path/to/models", device="cuda")
    # sr=None keeps the file's native rate; InferenceRecipe resamples if needed.
    samples, native_rate = librosa.load("test.wav", sr=None)
    outputs = recipe.inference(samples, native_rate)
    generated_audio, generated_text = outputs["audio"], outputs["text"]
    print(f"Generated {len(generated_audio)} samples of audio")
    print(f"Generated text: {generated_text}")