import logging
from pathlib import Path

import numpy as np
import sentencepiece
import torch
import torchaudio
from moshi.models import loaders, LMGen

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class InferenceRecipe:
    """Handles model inference for the Any-to-Any model.

    The recipe loads three components from ``model_path`` (the MIMI audio
    codec, a SentencePiece text tokenizer and the Moshi language model wrapped
    in ``LMGen``) and exposes :meth:`inference`, which turns an input waveform
    into a generated waveform plus generated text.
    """

    # Text token ids that are not rendered into the output text.
    # NOTE(review): presumably padding/special ids of the Moshi text
    # vocabulary -- confirm against the tokenizer/model config.
    _SKIPPED_TEXT_TOKENS = (0, 3)

    def __init__(self, model_path: str, device: str = 'cuda'):
        """Initialize the model.

        Args:
            model_path (str): Path to model directory with pre-downloaded files
            device (str): Device to run on ('cuda' or 'cpu')

        Raises:
            RuntimeError: If one of the expected model files is missing.
        """
        self.device = torch.device(device)
        self.model_path = Path(model_path)

        # Set sample rate and frame rate
        self.sample_rate = 24000  # Based on model config in loaders.py
        self.frame_rate = 12.5  # Based on model config in loaders.py

        # Initialize all model components
        logger.info(f"Initializing models from {model_path}")
        self.mimi, self.text_tokenizer, self.lm_gen = self._initialize_models()
        self.mimi = self.mimi.to(self.device)
        self.lm_gen = self.lm_gen.to(self.device)
        logger.info("Model initialization complete")

    @property
    def _frame_size(self) -> int:
        """Number of audio samples per codec frame (sample_rate / frame_rate)."""
        return int(self.sample_rate / self.frame_rate)

    def _component_path(self, filename: str, label: str) -> Path:
        """Return the path of a model file, raising if it does not exist.

        Args:
            filename: File name relative to ``self.model_path``.
            label: Human-readable component name used in log/error messages.

        Raises:
            RuntimeError: If the file is not present on disk.
        """
        path = self.model_path / filename
        if not path.exists():
            # Error messages start with a capital letter, log messages do not.
            raise RuntimeError(f"{label[0].upper()}{label[1:]} not found at {path}")
        logger.info(f"Loading {label} from {path}")
        return path

    def _initialize_models(self):
        """Initialize all required model components.

        Returns:
            tuple: ``(mimi, text_tokenizer, lm_gen)``.

        Raises:
            RuntimeError: If a required model file is missing.
        """
        logger.info("Initializing models...")
        try:
            # Load MIMI model for encoding/decoding
            mimi_path = self._component_path(loaders.MIMI_NAME, "MIMI model")
            mimi = loaders.get_mimi(str(mimi_path), device=self.device)
            mimi.set_num_codebooks(8)

            # Load text tokenizer
            tokenizer_path = self._component_path(
                loaders.TEXT_TOKENIZER_NAME, "text tokenizer"
            )
            text_tokenizer = sentencepiece.SentencePieceProcessor(str(tokenizer_path))

            # Load language model
            moshi_path = self._component_path(loaders.MOSHI_NAME, "language model")
            moshi = loaders.get_moshi_lm(str(moshi_path), device=self.device)
            lm_gen = LMGen(moshi, temp=0.8, temp_text=0.7)

            return mimi, text_tokenizer, lm_gen
        except Exception as e:
            logger.error(f"Model initialization failed: {e}")
            raise

    def _load_audio(self, audio_array: np.ndarray, sample_rate: int) -> torch.Tensor:
        """Load and preprocess audio.

        Converts the array to a float tensor on ``self.device``, resamples it
        to ``self.sample_rate`` when needed and trims trailing samples so the
        length is a whole number of codec frames.

        Args:
            audio_array: 1-D array of samples, or a 2-D array whose last axis
                is time (it becomes ``(1, rows, samples)``).
            sample_rate: Sample rate of ``audio_array`` in Hz.

        Returns:
            A 3-D tensor whose last dimension is time.

        Raises:
            ValueError: If ``audio_array`` is neither 1-D nor 2-D.
        """
        try:
            samples = np.asarray(audio_array)
            if samples.ndim not in (1, 2):
                raise ValueError(
                    f"Expected 1-D or 2-D audio array, got shape {samples.shape}"
                )

            # torch.from_numpy cannot wrap negatively-strided arrays.
            wav = torch.from_numpy(np.ascontiguousarray(samples)).float()
            if wav.dim() == 1:
                # Mono input: add the channel axis so that, together with the
                # batch axis below, the 3-D slicing used downstream is valid.
                wav = wav.unsqueeze(0)
            wav = wav.unsqueeze(0).to(self.device)

            # Resample if needed
            if sample_rate != self.sample_rate:
                logger.info(f"Resampling from {sample_rate} to {self.sample_rate}")
                resampler = torchaudio.transforms.Resample(
                    orig_freq=sample_rate,
                    new_freq=self.sample_rate,
                ).to(self.device)
                wav = resampler(wav)

            # Ensure frame alignment
            frame_size = self._frame_size
            orig_length = wav.shape[-1]
            wav = wav[:, :, :(orig_length // frame_size) * frame_size]
            if wav.shape[-1] != orig_length:
                logger.info(
                    f"Trimmed audio from {orig_length} to {wav.shape[-1]} samples for frame alignment"
                )

            return wav
        except Exception as e:
            logger.error(f"Audio loading failed: {e}")
            raise

    def _pad_codes(self, all_codes: list, time_seconds: float = 30) -> list:
        """Pad the code list with encoded silence up to a minimum duration.

        ``all_codes`` is extended in place (and also returned) until it holds
        at least ``time_seconds * self.frame_rate`` frames. Each padding frame
        is the MIMI encoding of one frame of zeros.

        Args:
            all_codes: List of per-frame code tensors.
            time_seconds: Minimum duration, in seconds, to pad to.

        Returns:
            The same list object, possibly extended.
        """
        try:
            min_frames = int(time_seconds * self.frame_rate)
            frames_to_add = min_frames - len(all_codes)

            if frames_to_add > 0:
                logger.info(f"Padding {frames_to_add} frames to reach minimum length")
                silence = torch.zeros(
                    1, 1, self._frame_size, dtype=torch.float32, device=self.device
                )
                with torch.no_grad(), self.mimi.streaming(batch_size=1):
                    for _ in range(frames_to_add):
                        all_codes.append(self.mimi.encode(silence))

            return all_codes
        except Exception as e:
            logger.error(f"Code padding failed: {e}")
            raise

    def _encode_audio(self, wav: torch.Tensor) -> list:
        """Convert audio to codes.

        The waveform is fed to MIMI one frame at a time inside a streaming
        session; each call is expected to yield exactly one code step.

        Args:
            wav: 3-D waveform tensor whose last dimension is time.

        Returns:
            List of code tensors, one per frame.
        """
        try:
            frame_size = self._frame_size
            all_codes = []

            with torch.no_grad(), self.mimi.streaming(batch_size=1):
                for offset in range(0, wav.shape[-1], frame_size):
                    frame = wav[:, :, offset: offset + frame_size]
                    codes = self.mimi.encode(frame.to(self.device))
                    assert codes.shape[-1] == 1, f"Expected code shape (*, *, 1), got {codes.shape}"
                    all_codes.append(codes)

            logger.info(f"Encoded {len(all_codes)} frames")
            return all_codes
        except Exception as e:
            logger.error(f"Audio encoding failed: {e}")
            raise

    def _warmup(self):
        """Run a warmup pass.

        Encodes one frame of silence, runs a single language-model step on it
        and, if the step returns tokens, decodes them. On CUDA devices the
        method waits for all queued kernels to finish before returning.
        """
        try:
            # Create tensor on the correct device from the start
            chunk = torch.zeros(
                1, 1, self._frame_size, dtype=torch.float32, device=self.device
            )

            with torch.no_grad(), self.lm_gen.streaming(1), self.mimi.streaming(1):
                codes = self.mimi.encode(chunk)
                tokens = self.lm_gen.step(codes[:, :, 0:1])
                if tokens is not None:
                    # Index 0 along dim 1 is the text stream; the rest is audio.
                    _ = self.mimi.decode(tokens[:, 1:])

            if self.device.type == 'cuda':
                torch.cuda.synchronize()

            logger.info("Warmup pass completed")
        except Exception as e:
            logger.error(f"Warmup failed: {e}")
            raise

    def _generate(self, all_codes):
        """Generate audio and text from codes.

        Args:
            all_codes: List of code tensors, each of shape ``(1, 8, 1)``.

        Returns:
            tuple: ``(wav, text)`` where ``wav`` is the concatenation of all
            decoded audio chunks along the last dimension and ``text`` is the
            concatenation of the decoded text pieces.

        Raises:
            RuntimeError: If the model produced no audio for the given codes.
        """
        try:
            out_wav_chunks = []
            text_output = []

            with torch.no_grad(), self.lm_gen.streaming(1), self.mimi.streaming(1):
                for i, code in enumerate(all_codes):
                    assert code.shape == (1, 8, 1), f"Expected code shape (1, 8, 1), got {code.shape}"
                    tokens_out = self.lm_gen.step(code.to(self.device))

                    # A step may yield nothing; only process actual outputs.
                    if tokens_out is not None:
                        # Generate audio
                        out_wav_chunks.append(self.mimi.decode(tokens_out[:, 1:]))

                        # Generate text if available
                        text_token = tokens_out[0, 0, 0].item()
                        if text_token not in self._SKIPPED_TEXT_TOKENS:
                            piece = self.text_tokenizer.id_to_piece(text_token)
                            # SentencePiece marks word boundaries with U+2581.
                            text_output.append(piece.replace("▁", " "))

                    if (i + 1) % 100 == 0:
                        logger.info(f"Processed {i + 1}/{len(all_codes)} frames")

            if not out_wav_chunks:
                # torch.cat would fail on an empty list with an opaque message.
                raise RuntimeError(
                    f"No audio was generated from {len(all_codes)} input frames"
                )

            wav = torch.cat(out_wav_chunks, dim=-1)
            text = ''.join(text_output)

            logger.info(f"Generated {wav.shape[-1]} samples of audio and {len(text)} characters of text")
            return wav, text
        except Exception as e:
            logger.error(f"Generation failed: {e}")
            raise

    def inference(self, audio_array: np.ndarray, sample_rate: int) -> dict:
        """Run inference on input audio.

        Args:
            audio_array (np.ndarray): Input audio as numpy array
            sample_rate (int): Sample rate of input audio

        Returns:
            dict: ``{"audio": np.ndarray, "text": str}`` holding the generated
            audio (squeezed of singleton dimensions) and the generated text.
        """
        try:
            logger.info(f"Starting inference on {len(audio_array)} samples at {sample_rate} Hz, self device: {self.device}")

            # Load and preprocess audio (already placed on self.device)
            wav = self._load_audio(audio_array, sample_rate)

            # Convert to codes
            all_codes = self._encode_audio(wav)
            all_codes = self._pad_codes(all_codes)

            # Warmup pass
            self._warmup()

            # Generate output
            out_wav, text = self._generate(all_codes)

            # Convert output to numpy
            output = out_wav.cpu().numpy().squeeze()

            logger.info("Inference completed successfully")
            return {
                "audio": output,
                "text": text,
            }
        except Exception as e:
            logger.error(f"Inference failed: {e}")
            raise


if __name__ == "__main__":
    # Example usage
    import librosa

    # Initialize model
    model = InferenceRecipe("/path/to/models", device="cuda")

    # Load test audio
    audio, sr = librosa.load("test.wav", sr=None)

    # Run inference
    result = model.inference(audio, sr)
    print(f"Generated {len(result['audio'])} samples of audio")
    print(f"Generated text: {result['text']}")