Rex Cheng committed on
Commit
9ac63db
β€’
1 Parent(s): c8ca0bd
app.py CHANGED
@@ -83,14 +83,15 @@ def video_to_audio(video: gr.Video, prompt: str, negative_prompt: str, seed: int
83
  audio = audios.float().cpu()[0]
84
 
85
  # current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
86
- video_save_path = tempfile.mktemp(suffix='.mp4')
87
  # output_dir.mkdir(exist_ok=True, parents=True)
88
  # video_save_path = output_dir / f'{current_time_string}.mp4'
89
- # make_video(video,
90
- # video_save_path,
91
- # audio,
92
- # sampling_rate=seq_cfg.sampling_rate,
93
- # duration_sec=seq_cfg.duration)
 
94
  return video_save_path
95
 
96
 
@@ -116,11 +117,9 @@ def text_to_audio(prompt: str, negative_prompt: str, seed: int, num_steps: int,
116
  cfg_strength=cfg_strength)
117
  audio = audios.float().cpu()[0]
118
 
119
- # current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
120
- # output_dir.mkdir(exist_ok=True, parents=True)
121
- # audio_save_path = output_dir / f'{current_time_string}.flac'
122
- audio_save_path = tempfile.mktemp(suffix='.flac')
123
  torchaudio.save(audio_save_path, audio, seq_cfg.sampling_rate)
 
124
  return audio_save_path
125
 
126
 
@@ -140,8 +139,8 @@ video_to_audio_tab = gr.Interface(
140
 title='MMAudio — Video-to-Audio Synthesis',
141
  examples=[
142
  [
143
- 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_nyc.mp4',
144
- '',
145
  '',
146
  0,
147
  25,
@@ -185,8 +184,8 @@ video_to_audio_tab = gr.Interface(
185
  10,
186
  ],
187
  [
188
- 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_beach.mp4',
189
- 'waves, seagulls',
190
  '',
191
  0,
192
  25,
@@ -194,8 +193,8 @@ video_to_audio_tab = gr.Interface(
194
  10,
195
  ],
196
  [
197
- 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_kraken.mp4',
198
- 'waves, storm',
199
  '',
200
  0,
201
  25,
 
83
  audio = audios.float().cpu()[0]
84
 
85
  # current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
86
+ video_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
87
  # output_dir.mkdir(exist_ok=True, parents=True)
88
  # video_save_path = output_dir / f'{current_time_string}.mp4'
89
+ make_video(video,
90
+ video_save_path,
91
+ audio,
92
+ sampling_rate=seq_cfg.sampling_rate,
93
+ duration_sec=seq_cfg.duration)
94
+ log.info(f'Saved video to {video_save_path}')
95
  return video_save_path
96
 
97
 
 
117
  cfg_strength=cfg_strength)
118
  audio = audios.float().cpu()[0]
119
 
120
+ audio_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.flac').name
 
 
 
121
  torchaudio.save(audio_save_path, audio, seq_cfg.sampling_rate)
122
+ log.info(f'Saved audio to {audio_save_path}')
123
  return audio_save_path
124
 
125
 
 
139
 title='MMAudio — Video-to-Audio Synthesis',
140
  examples=[
141
  [
142
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_beach.mp4',
143
+ 'waves, seagulls',
144
  '',
145
  0,
146
  25,
 
184
  10,
185
  ],
186
  [
187
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_kraken.mp4',
188
+ 'waves, storm',
189
  '',
190
  0,
191
  25,
 
193
  10,
194
  ],
195
  [
196
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_nyc.mp4',
197
+ '',
198
  '',
199
  0,
200
  25,
mmaudio/eval_utils.py CHANGED
@@ -7,6 +7,7 @@ import torch
7
  from colorlog import ColoredFormatter
8
  from torchvision.transforms import v2
9
  from torio.io import StreamingMediaDecoder, StreamingMediaEncoder
 
10
 
11
  from mmaudio.model.flow_matching import FlowMatching
12
  from mmaudio.model.networks import MMAudio
@@ -169,11 +170,13 @@ def load_video(video_path: Path, duration_sec: float) -> tuple[torch.Tensor, tor
169
  reader = StreamingMediaDecoder(video_path)
170
  reader.add_basic_video_stream(
171
  frames_per_chunk=int(_CLIP_FPS * duration_sec),
 
172
  frame_rate=_CLIP_FPS,
173
  format='rgb24',
174
  )
175
  reader.add_basic_video_stream(
176
  frames_per_chunk=int(_SYNC_FPS * duration_sec),
 
177
  frame_rate=_SYNC_FPS,
178
  format='rgb24',
179
  )
@@ -182,9 +185,14 @@ def load_video(video_path: Path, duration_sec: float) -> tuple[torch.Tensor, tor
182
  data_chunk = reader.pop_chunks()
183
  clip_chunk = data_chunk[0]
184
  sync_chunk = data_chunk[1]
 
 
185
  assert clip_chunk is not None
186
  assert sync_chunk is not None
187
 
 
 
 
188
  clip_frames = clip_transform(clip_chunk)
189
  sync_frames = sync_transform(sync_chunk)
190
 
@@ -210,17 +218,26 @@ def load_video(video_path: Path, duration_sec: float) -> tuple[torch.Tensor, tor
210
  def make_video(video_path: Path, output_path: Path, audio: torch.Tensor, sampling_rate: int,
211
  duration_sec: float):
212
 
 
 
 
 
213
  approx_max_length = int(duration_sec * 60)
214
  reader = StreamingMediaDecoder(video_path)
215
  reader.add_basic_video_stream(
216
  frames_per_chunk=approx_max_length,
 
217
  format='rgb24',
218
  )
219
  reader.fill_buffer()
220
  video_chunk = reader.pop_chunks()[0]
 
221
  assert video_chunk is not None
222
 
223
- fps = int(reader.get_out_stream_info(0).frame_rate)
 
 
 
224
  if fps > 60:
225
  log.warning(f'This code supports only up to 60 fps, but the video has {fps} fps')
226
  log.warning(f'Just change the *60 above me')
 
7
  from colorlog import ColoredFormatter
8
  from torchvision.transforms import v2
9
  from torio.io import StreamingMediaDecoder, StreamingMediaEncoder
10
+ import av
11
 
12
  from mmaudio.model.flow_matching import FlowMatching
13
  from mmaudio.model.networks import MMAudio
 
170
  reader = StreamingMediaDecoder(video_path)
171
  reader.add_basic_video_stream(
172
  frames_per_chunk=int(_CLIP_FPS * duration_sec),
173
+ buffer_chunk_size=1,
174
  frame_rate=_CLIP_FPS,
175
  format='rgb24',
176
  )
177
  reader.add_basic_video_stream(
178
  frames_per_chunk=int(_SYNC_FPS * duration_sec),
179
+ buffer_chunk_size=1,
180
  frame_rate=_SYNC_FPS,
181
  format='rgb24',
182
  )
 
185
  data_chunk = reader.pop_chunks()
186
  clip_chunk = data_chunk[0]
187
  sync_chunk = data_chunk[1]
188
+ print('clip', clip_chunk.shape, clip_chunk.dtype, clip_chunk.max())
189
+ print('sync', sync_chunk.shape, sync_chunk.dtype, sync_chunk.max())
190
  assert clip_chunk is not None
191
  assert sync_chunk is not None
192
 
193
+ for i in range(reader.num_out_streams):
194
+ print(reader.get_out_stream_info(i))
195
+
196
  clip_frames = clip_transform(clip_chunk)
197
  sync_frames = sync_transform(sync_chunk)
198
 
 
218
  def make_video(video_path: Path, output_path: Path, audio: torch.Tensor, sampling_rate: int,
219
  duration_sec: float):
220
 
221
+ av_video = av.open(video_path)
222
+ frame_rate = av_video.streams.video[0].guessed_rate
223
+ print('av frame rate', frame_rate)
224
+
225
  approx_max_length = int(duration_sec * 60)
226
  reader = StreamingMediaDecoder(video_path)
227
  reader.add_basic_video_stream(
228
  frames_per_chunk=approx_max_length,
229
+ buffer_chunk_size=1,
230
  format='rgb24',
231
  )
232
  reader.fill_buffer()
233
  video_chunk = reader.pop_chunks()[0]
234
+ print(video_chunk.shape, video_chunk.dtype, video_chunk.max())
235
  assert video_chunk is not None
236
 
237
+ # fps = int(reader.get_out_stream_info(0).frame_rate)
238
+ fps = frame_rate
239
+ for i in range(reader.num_out_streams):
240
+ print(reader.get_out_stream_info(i))
241
  if fps > 60:
242
  log.warning(f'This code supports only up to 60 fps, but the video has {fps} fps')
243
  log.warning(f'Just change the *60 above me')
mmaudio/utils/download_utils.py CHANGED
@@ -30,7 +30,8 @@ links = [
30
  },
31
  {
32
  'name': 'mmaudio_large_44k_v2.pth',
33
- 'url': 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_large_44k_v2.pth',
 
34
  'md5': '01ad4464f049b2d7efdaa4c1a59b8dfe'
35
  },
36
  {
 
30
  },
31
  {
32
  'name': 'mmaudio_large_44k_v2.pth',
33
+ # 'url': 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_large_44k_v2.pth',
34
+ 'url': 'https://databank.illinois.edu/datafiles/i1pd9/download',
35
  'md5': '01ad4464f049b2d7efdaa4c1a59b8dfe'
36
  },
37
  {
requirements.txt CHANGED
@@ -1,6 +1,6 @@
1
- torch >= 2.5.1
2
- torchaudio
3
  torchvision
 
4
  python-dotenv
5
  cython
6
  gitpython >= 3.1
@@ -23,4 +23,5 @@ hydra_colorlog
23
  tensordict
24
  colorlog
25
  open_clip_torch
26
- soundfile
 
 
1
+ torch == 2.4.0
 
2
  torchvision
3
+ torchaudio
4
  python-dotenv
5
  cython
6
  gitpython >= 3.1
 
23
  tensordict
24
  colorlog
25
  open_clip_torch
26
+ soundfile
27
+ av