import dataclasses
import logging
from pathlib import Path
from typing import Optional

import torch
from colorlog import ColoredFormatter
from torchvision.transforms import v2
from torio.io import StreamingMediaDecoder, StreamingMediaEncoder

from mmaudio.model.flow_matching import FlowMatching
from mmaudio.model.networks import MMAudio
from mmaudio.model.sequence_config import (CONFIG_16K, CONFIG_44K, SequenceConfig)
from mmaudio.model.utils.features_utils import FeaturesUtils
from mmaudio.utils.download_utils import download_model_if_needed

log = logging.getLogger()


@dataclasses.dataclass
class ModelConfig:
    model_name: str
    model_path: Path
    vae_path: Path
    bigvgan_16k_path: Optional[Path]
    mode: str
    synchformer_ckpt: Path = Path('./ext_weights/synchformer_state_dict.pth')

    @property
    def seq_cfg(self) -> SequenceConfig:
        if self.mode == '16k':
            return CONFIG_16K
        elif self.mode == '44k':
            return CONFIG_44K
        raise ValueError(f'Unknown mode: {self.mode}')

    def download_if_needed(self):
        download_model_if_needed(self.model_path)
        download_model_if_needed(self.vae_path)
        if self.bigvgan_16k_path is not None:
            download_model_if_needed(self.bigvgan_16k_path)
        download_model_if_needed(self.synchformer_ckpt)


small_16k = ModelConfig(model_name='small_16k',
                        model_path=Path('./weights/mmaudio_small_16k.pth'),
                        vae_path=Path('./ext_weights/v1-16.pth'),
                        bigvgan_16k_path=Path('./ext_weights/best_netG.pt'),
                        mode='16k')
small_44k = ModelConfig(model_name='small_44k',
                        model_path=Path('./weights/mmaudio_small_44k.pth'),
                        vae_path=Path('./ext_weights/v1-44.pth'),
                        bigvgan_16k_path=None,
                        mode='44k')
medium_44k = ModelConfig(model_name='medium_44k',
                         model_path=Path('./weights/mmaudio_medium_44k.pth'),
                         vae_path=Path('./ext_weights/v1-44.pth'),
                         bigvgan_16k_path=None,
                         mode='44k')
large_44k = ModelConfig(model_name='large_44k',
                        model_path=Path('./weights/mmaudio_large_44k.pth'),
                        vae_path=Path('./ext_weights/v1-44.pth'),
                        bigvgan_16k_path=None,
                        mode='44k')
large_44k_v2 = ModelConfig(model_name='large_44k_v2',
                           model_path=Path('./weights/mmaudio_large_44k_v2.pth'),
                           vae_path=Path('./ext_weights/v1-44.pth'),
                           bigvgan_16k_path=None,
                           mode='44k')

all_model_cfg: dict[str, ModelConfig] = {
    'small_16k': small_16k,
    'small_44k': small_44k,
    'medium_44k': medium_44k,
    'large_44k': large_44k,
    'large_44k_v2': large_44k_v2,
}


def generate(clip_video: Optional[torch.Tensor],
             sync_video: Optional[torch.Tensor],
             text: Optional[list[str]],
             *,
             negative_text: Optional[list[str]] = None,
             feature_utils: FeaturesUtils,
             net: MMAudio,
             fm: FlowMatching,
             rng: torch.Generator,
             cfg_strength: float):
    device = feature_utils.device
    dtype = feature_utils.dtype

    # Batch size follows the number of text prompts.
    bs = len(text)

    # Encode each conditioning stream, falling back to the network's learned
    # "empty" sequences when a modality is absent.
    if clip_video is not None:
        clip_video = clip_video.to(device, dtype, non_blocking=True)
        clip_features = feature_utils.encode_video_with_clip(clip_video, batch_size=bs)
    else:
        clip_features = net.get_empty_clip_sequence(bs)

    if sync_video is not None:
        sync_video = sync_video.to(device, dtype, non_blocking=True)
        sync_features = feature_utils.encode_video_with_sync(sync_video, batch_size=bs)
    else:
        sync_features = net.get_empty_sync_sequence(bs)

    if text is not None:
        text_features = feature_utils.encode_text(text)
    else:
        text_features = net.get_empty_string_sequence(bs)

    if negative_text is not None:
        assert len(negative_text) == bs
        negative_text_features = feature_utils.encode_text(negative_text)
    else:
        negative_text_features = net.get_empty_string_sequence(bs)

    # Initial latent noise for the flow-matching ODE.
    x0 = torch.randn(bs,
                     net.latent_seq_len,
                     net.latent_dim,
                     device=device,
                     dtype=dtype,
                     generator=rng)
    preprocessed_conditions = net.preprocess_conditions(clip_features, sync_features,
                                                        text_features)
    # Unconditional ("empty") conditions for classifier-free guidance; the negative
    # prompt, if given, replaces the empty text embedding.
    empty_conditions = net.get_empty_conditions(
        bs, negative_text_features=negative_text_features if negative_text is not None else None)

    cfg_ode_wrapper = lambda t, x: net.ode_wrapper(t, x, preprocessed_conditions, empty_conditions,
                                                   cfg_strength)
    # Integrate noise -> data, then decode the latents to a spectrogram and vocode to a waveform.
    x1 = fm.to_data(cfg_ode_wrapper, x0)
    x1 = net.unnormalize(x1)
    spec = feature_utils.decode(x1)
    audio = feature_utils.vocode(spec)
    return audio


LOGFORMAT = " %(log_color)s%(levelname)-8s%(reset)s | %(log_color)s%(message)s%(reset)s"


def setup_eval_logging(log_level: int = logging.INFO):
    logging.root.setLevel(log_level)
    formatter = ColoredFormatter(LOGFORMAT)
    stream = logging.StreamHandler()
    stream.setLevel(log_level)
    stream.setFormatter(formatter)
    log = logging.getLogger()
    log.setLevel(log_level)
    log.addHandler(stream)


def load_video(video_path: Path, duration_sec: float) -> tuple[torch.Tensor, torch.Tensor, float]:
    # CLIP features are extracted at 8 fps / 384x384; Synchformer at 25 fps / 224x224.
    _CLIP_SIZE = 384
    _CLIP_FPS = 8.0

    _SYNC_SIZE = 224
    _SYNC_FPS = 25.0

    clip_transform = v2.Compose([
        v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
    ])

    sync_transform = v2.Compose([
        v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
        v2.CenterCrop(_SYNC_SIZE),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    # Decode the video twice, once at each frame rate required by the feature extractors.
    reader = StreamingMediaDecoder(video_path)
    reader.add_basic_video_stream(
        frames_per_chunk=int(_CLIP_FPS * duration_sec),
        frame_rate=_CLIP_FPS,
        format='rgb24',
    )
    reader.add_basic_video_stream(
        frames_per_chunk=int(_SYNC_FPS * duration_sec),
        frame_rate=_SYNC_FPS,
        format='rgb24',
    )

    reader.fill_buffer()
    data_chunk = reader.pop_chunks()
    clip_chunk = data_chunk[0]
    sync_chunk = data_chunk[1]
    assert clip_chunk is not None
    assert sync_chunk is not None

    clip_frames = clip_transform(clip_chunk)
    sync_frames = sync_transform(sync_chunk)

    # If the video is shorter than requested, shrink the duration to what is available.
    clip_length_sec = clip_frames.shape[0] / _CLIP_FPS
    sync_length_sec = sync_frames.shape[0] / _SYNC_FPS

    if clip_length_sec < duration_sec:
        log.warning(f'Clip video is too short: {clip_length_sec:.2f} < {duration_sec:.2f}')
        log.warning(f'Truncating to {clip_length_sec:.2f} sec')
        duration_sec = clip_length_sec

    if sync_length_sec < duration_sec:
        log.warning(f'Sync video is too short: {sync_length_sec:.2f} < {duration_sec:.2f}')
        log.warning(f'Truncating to {sync_length_sec:.2f} sec')
        duration_sec = sync_length_sec

    clip_frames = clip_frames[:int(_CLIP_FPS * duration_sec)]
    sync_frames = sync_frames[:int(_SYNC_FPS * duration_sec)]

    return clip_frames, sync_frames, duration_sec


def make_video(video_path: Path, output_path: Path, audio: torch.Tensor, sampling_rate: int,
               duration_sec: float):
    # Rough upper bound on the number of frames to read, assuming at most 60 fps.
    approx_max_length = int(duration_sec * 60)
    reader = StreamingMediaDecoder(video_path)
    reader.add_basic_video_stream(
        frames_per_chunk=approx_max_length,
        format='rgb24',
    )
    reader.fill_buffer()
    video_chunk = reader.pop_chunks()[0]
    assert video_chunk is not None

    fps = int(reader.get_out_stream_info(0).frame_rate)
    if fps > 60:
        log.warning(f'This code supports only up to 60 fps, but the video has {fps} fps')
        log.warning('Increase the "* 60" multiplier in approx_max_length above if needed')

    h, w = video_chunk.shape[-2:]
    video_chunk = video_chunk[:int(fps * duration_sec)]

    # Mux the generated audio with the original video frames.
    writer = StreamingMediaEncoder(output_path)
    writer.add_audio_stream(
        sample_rate=sampling_rate,
        num_channels=audio.shape[0],
        encoder='aac',  # 'flac' does not work for some reason?
    )
    writer.add_video_stream(frame_rate=fps,
                            width=w,
                            height=h,
                            format='rgb24',
                            encoder='libx264',
                            encoder_format='yuv420p')
    with writer.open():
        writer.write_audio_chunk(0, audio.float().transpose(0, 1))
        writer.write_video_chunk(1, video_chunk)
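

if __name__ == '__main__':
    # Minimal end-to-end sketch of how the helpers above fit together.
    # Assumptions (illustrative only, not part of this module's API): a local `video.mp4`
    # exists, and the network / feature extractors are constructed the way the project's
    # demo script does it (`get_my_mmaudio`, `FeaturesUtils(...)`, `FlowMatching(...)`);
    # exact constructor arguments may differ between releases, so check the demo script.
    from mmaudio.model.networks import get_my_mmaudio

    setup_eval_logging()

    model = all_model_cfg['large_44k_v2']
    model.download_if_needed()

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    dtype = torch.bfloat16 if device == 'cuda' else torch.float32

    net: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval()
    net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))

    feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
                                  synchformer_ckpt=model.synchformer_ckpt,
                                  enable_conditions=True,
                                  mode=model.mode,
                                  bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
                                  need_vae_encoder=False).to(device, dtype).eval()
    fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=25)

    video_path = Path('./video.mp4')  # hypothetical input path
    clip_frames, sync_frames, duration_sec = load_video(video_path, duration_sec=8.0)

    # Match the network's sequence lengths to the (possibly truncated) clip duration.
    seq_cfg = model.seq_cfg
    seq_cfg.duration = duration_sec
    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)

    audio = generate(clip_frames.unsqueeze(0),
                     sync_frames.unsqueeze(0), ['water splashing'],
                     feature_utils=feature_utils,
                     net=net,
                     fm=fm,
                     rng=torch.Generator(device=device).manual_seed(42),
                     cfg_strength=4.5)
    make_video(video_path, Path('./output.mp4'), audio[0].float().cpu(),
               seq_cfg.sampling_rate, duration_sec)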