Rex Cheng commited on
Commit
b0ec3f5
β€’
1 Parent(s): 164c335
Files changed (3) hide show
  1. app.py +5 -6
  2. demo.py +9 -9
  3. mmaudio/eval_utils.py +20 -58
app.py CHANGED
@@ -67,7 +67,10 @@ def video_to_audio(video: gr.Video, prompt: str, negative_prompt: str, seed: int
67
  rng.manual_seed(seed)
68
  fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
69
 
70
- clip_frames, sync_frames, duration = load_video(video, duration)
 
 
 
71
  clip_frames = clip_frames.unsqueeze(0)
72
  sync_frames = sync_frames.unsqueeze(0)
73
  seq_cfg.duration = duration
@@ -87,11 +90,7 @@ def video_to_audio(video: gr.Video, prompt: str, negative_prompt: str, seed: int
87
  video_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
88
  # output_dir.mkdir(exist_ok=True, parents=True)
89
  # video_save_path = output_dir / f'{current_time_string}.mp4'
90
- make_video(video,
91
- video_save_path,
92
- audio,
93
- sampling_rate=seq_cfg.sampling_rate,
94
- duration_sec=seq_cfg.duration)
95
  log.info(f'Saved video to {video_save_path}')
96
  return video_save_path
97
 
 
67
  rng.manual_seed(seed)
68
  fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
69
 
70
+ video_info = load_video(video, duration)
71
+ clip_frames = video_info.clip_frames
72
+ sync_frames = video_info.sync_frames
73
+ duration = video_info.duration_sec
74
  clip_frames = clip_frames.unsqueeze(0)
75
  sync_frames = sync_frames.unsqueeze(0)
76
  seq_cfg.duration = duration
 
90
  video_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
91
  # output_dir.mkdir(exist_ok=True, parents=True)
92
  # video_save_path = output_dir / f'{current_time_string}.mp4'
93
+ make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
 
 
 
 
94
  log.info(f'Saved video to {video_save_path}')
95
  return video_save_path
96
 
demo.py CHANGED
@@ -5,8 +5,8 @@ from pathlib import Path
5
  import torch
6
  import torchaudio
7
 
8
- from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate,
9
- load_video, make_video, setup_eval_logging)
10
  from mmaudio.model.flow_matching import FlowMatching
11
  from mmaudio.model.networks import MMAudio, get_my_mmaudio
12
  from mmaudio.model.utils.features_utils import FeaturesUtils
@@ -81,12 +81,16 @@ def main():
81
  synchformer_ckpt=model.synchformer_ckpt,
82
  enable_conditions=True,
83
  mode=model.mode,
84
- bigvgan_vocoder_ckpt=model.bigvgan_16k_path)
 
85
  feature_utils = feature_utils.to(device, dtype).eval()
86
 
87
  if video_path is not None:
88
  log.info(f'Using video {video_path}')
89
- clip_frames, sync_frames, duration = load_video(video_path, duration)
 
 
 
90
  if mask_away_clip:
91
  clip_frames = None
92
  else:
@@ -121,11 +125,7 @@ def main():
121
  log.info(f'Audio saved to {save_path}')
122
  if video_path is not None and not skip_video_composite:
123
  video_save_path = output_dir / f'{video_path.stem}.mp4'
124
- make_video(video_path,
125
- video_save_path,
126
- audio,
127
- sampling_rate=seq_cfg.sampling_rate,
128
- duration_sec=seq_cfg.duration)
129
  log.info(f'Video saved to {output_dir / video_save_path}')
130
 
131
  log.info('Memory usage: %.2f GB', torch.cuda.max_memory_allocated() / (2**30))
 
5
  import torch
6
  import torchaudio
7
 
8
+ from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video, make_video,
9
+ setup_eval_logging)
10
  from mmaudio.model.flow_matching import FlowMatching
11
  from mmaudio.model.networks import MMAudio, get_my_mmaudio
12
  from mmaudio.model.utils.features_utils import FeaturesUtils
 
81
  synchformer_ckpt=model.synchformer_ckpt,
82
  enable_conditions=True,
83
  mode=model.mode,
84
+ bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
85
+ need_vae_encoder=False)
86
  feature_utils = feature_utils.to(device, dtype).eval()
87
 
88
  if video_path is not None:
89
  log.info(f'Using video {video_path}')
90
+ video_info = load_video(video_path, duration)
91
+ clip_frames = video_info.clip_frames
92
+ sync_frames = video_info.sync_frames
93
+ duration = video_info.duration_sec
94
  if mask_away_clip:
95
  clip_frames = None
96
  else:
 
125
  log.info(f'Audio saved to {save_path}')
126
  if video_path is not None and not skip_video_composite:
127
  video_save_path = output_dir / f'{video_path.stem}.mp4'
128
+ make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
 
 
 
 
129
  log.info(f'Video saved to {output_dir / video_save_path}')
130
 
131
  log.info('Memory usage: %.2f GB', torch.cuda.max_memory_allocated() / (2**30))
mmaudio/eval_utils.py CHANGED
@@ -3,12 +3,11 @@ import logging
3
  from pathlib import Path
4
  from typing import Optional
5
 
6
- import av
7
  import torch
8
  from colorlog import ColoredFormatter
9
  from torchvision.transforms import v2
10
- from torio.io import StreamingMediaDecoder, StreamingMediaEncoder
11
 
 
12
  from mmaudio.model.flow_matching import FlowMatching
13
  from mmaudio.model.networks import MMAudio
14
  from mmaudio.model.sequence_config import (CONFIG_16K, CONFIG_44K, SequenceConfig)
@@ -154,7 +153,7 @@ def setup_eval_logging(log_level: int = logging.INFO):
154
  log.addHandler(stream)
155
 
156
 
157
- def load_video(video_path: Path, duration_sec: float) -> tuple[torch.Tensor, torch.Tensor, float]:
158
  _CLIP_SIZE = 384
159
  _CLIP_FPS = 8.0
160
 
@@ -175,26 +174,15 @@ def load_video(video_path: Path, duration_sec: float) -> tuple[torch.Tensor, tor
175
  v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
176
  ])
177
 
178
- reader = StreamingMediaDecoder(video_path)
179
- reader.add_basic_video_stream(
180
- frames_per_chunk=int(_CLIP_FPS * duration_sec),
181
- buffer_chunk_size=-1,
182
- frame_rate=_CLIP_FPS,
183
- format='rgb24',
184
- )
185
- reader.add_basic_video_stream(
186
- frames_per_chunk=int(_SYNC_FPS * duration_sec),
187
- buffer_chunk_size=-1,
188
- frame_rate=_SYNC_FPS,
189
- format='rgb24',
190
- )
191
 
192
- reader.fill_buffer()
193
- data_chunk = reader.pop_chunks()
194
- clip_chunk = data_chunk[0]
195
- sync_chunk = data_chunk[1]
196
- assert clip_chunk is not None
197
- assert sync_chunk is not None
198
 
199
  clip_frames = clip_transform(clip_chunk)
200
  sync_frames = sync_transform(sync_chunk)
@@ -215,41 +203,15 @@ def load_video(video_path: Path, duration_sec: float) -> tuple[torch.Tensor, tor
215
  clip_frames = clip_frames[:int(_CLIP_FPS * duration_sec)]
216
  sync_frames = sync_frames[:int(_SYNC_FPS * duration_sec)]
217
 
218
- return clip_frames, sync_frames, duration_sec
219
-
220
-
221
- def make_video(video_path: Path, output_path: Path, audio: torch.Tensor, sampling_rate: int,
222
- duration_sec: float):
 
 
 
223
 
224
- av_video = av.open(video_path)
225
- frame_rate = av_video.streams.video[0].guessed_rate
226
 
227
- approx_max_length = int(duration_sec * frame_rate) + 1
228
- reader = StreamingMediaDecoder(video_path)
229
- reader.add_basic_video_stream(
230
- frames_per_chunk=approx_max_length,
231
- buffer_chunk_size=-1,
232
- format='rgb24',
233
- )
234
- reader.fill_buffer()
235
- video_chunk = reader.pop_chunks()[0]
236
- assert video_chunk is not None
237
-
238
- h, w = video_chunk.shape[-2:]
239
- video_chunk = video_chunk[:int(frame_rate * duration_sec)]
240
-
241
- writer = StreamingMediaEncoder(output_path)
242
- writer.add_audio_stream(
243
- sample_rate=sampling_rate,
244
- num_channels=audio.shape[0],
245
- encoder='aac', # 'flac' does not work for some reason?
246
- )
247
- writer.add_video_stream(frame_rate=frame_rate,
248
- width=w,
249
- height=h,
250
- format='rgb24',
251
- encoder='libx264',
252
- encoder_format='yuv420p')
253
- with writer.open():
254
- writer.write_audio_chunk(0, audio.float().transpose(0, 1))
255
- writer.write_video_chunk(1, video_chunk)
 
3
  from pathlib import Path
4
  from typing import Optional
5
 
 
6
  import torch
7
  from colorlog import ColoredFormatter
8
  from torchvision.transforms import v2
 
9
 
10
+ from mmaudio.data.av_utils import VideoInfo, read_frames, reencode_with_audio
11
  from mmaudio.model.flow_matching import FlowMatching
12
  from mmaudio.model.networks import MMAudio
13
  from mmaudio.model.sequence_config import (CONFIG_16K, CONFIG_44K, SequenceConfig)
 
153
  log.addHandler(stream)
154
 
155
 
156
+ def load_video(video_path: Path, duration_sec: float, load_all_frames: bool = True) -> VideoInfo:
157
  _CLIP_SIZE = 384
158
  _CLIP_FPS = 8.0
159
 
 
174
  v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
175
  ])
176
 
177
+ output_frames, all_frames, orig_fps = read_frames(video_path,
178
+ list_of_fps=[_CLIP_FPS, _SYNC_FPS],
179
+ start_sec=0,
180
+ end_sec=duration_sec,
181
+ need_all_frames=load_all_frames)
 
 
 
 
 
 
 
 
182
 
183
+ clip_chunk, sync_chunk = output_frames
184
+ clip_chunk = torch.from_numpy(clip_chunk).permute(0, 3, 1, 2)
185
+ sync_chunk = torch.from_numpy(sync_chunk).permute(0, 3, 1, 2)
 
 
 
186
 
187
  clip_frames = clip_transform(clip_chunk)
188
  sync_frames = sync_transform(sync_chunk)
 
203
  clip_frames = clip_frames[:int(_CLIP_FPS * duration_sec)]
204
  sync_frames = sync_frames[:int(_SYNC_FPS * duration_sec)]
205
 
206
+ video_info = VideoInfo(
207
+ duration_sec=duration_sec,
208
+ fps=orig_fps,
209
+ clip_frames=clip_frames,
210
+ sync_frames=sync_frames,
211
+ all_frames=all_frames if load_all_frames else None,
212
+ )
213
+ return video_info
214
 
 
 
215
 
216
+ def make_video(video_info: VideoInfo, output_path: Path, audio: torch.Tensor, sampling_rate: int):
217
+ reencode_with_audio(video_info, output_path, audio, sampling_rate)