Rex Cheng committed on
Commit
35d9519
1 Parent(s): 9ac63db

fix streamreader

Browse files
Files changed (1) hide show
  1. mmaudio/eval_utils.py +7 -17
mmaudio/eval_utils.py CHANGED
@@ -3,11 +3,11 @@ import logging
3
  from pathlib import Path
4
  from typing import Optional
5
 
 
6
  import torch
7
  from colorlog import ColoredFormatter
8
  from torchvision.transforms import v2
9
  from torio.io import StreamingMediaDecoder, StreamingMediaEncoder
10
- import av
11
 
12
  from mmaudio.model.flow_matching import FlowMatching
13
  from mmaudio.model.networks import MMAudio
@@ -170,13 +170,13 @@ def load_video(video_path: Path, duration_sec: float) -> tuple[torch.Tensor, tor
170
  reader = StreamingMediaDecoder(video_path)
171
  reader.add_basic_video_stream(
172
  frames_per_chunk=int(_CLIP_FPS * duration_sec),
173
- buffer_chunk_size=1,
174
  frame_rate=_CLIP_FPS,
175
  format='rgb24',
176
  )
177
  reader.add_basic_video_stream(
178
  frames_per_chunk=int(_SYNC_FPS * duration_sec),
179
- buffer_chunk_size=1,
180
  frame_rate=_SYNC_FPS,
181
  format='rgb24',
182
  )
@@ -220,30 +220,20 @@ def make_video(video_path: Path, output_path: Path, audio: torch.Tensor, samplin
220
 
221
  av_video = av.open(video_path)
222
  frame_rate = av_video.streams.video[0].guessed_rate
223
- print('av frame rate', frame_rate)
224
 
225
- approx_max_length = int(duration_sec * 60)
226
  reader = StreamingMediaDecoder(video_path)
227
  reader.add_basic_video_stream(
228
  frames_per_chunk=approx_max_length,
229
- buffer_chunk_size=1,
230
  format='rgb24',
231
  )
232
  reader.fill_buffer()
233
  video_chunk = reader.pop_chunks()[0]
234
- print(video_chunk.shape, video_chunk.dtype, video_chunk.max())
235
  assert video_chunk is not None
236
 
237
- # fps = int(reader.get_out_stream_info(0).frame_rate)
238
- fps = frame_rate
239
- for i in range(reader.num_out_streams):
240
- print(reader.get_out_stream_info(i))
241
- if fps > 60:
242
- log.warning(f'This code supports only up to 60 fps, but the video has {fps} fps')
243
- log.warning(f'Just change the *60 above me')
244
-
245
  h, w = video_chunk.shape[-2:]
246
- video_chunk = video_chunk[:int(fps * duration_sec)]
247
 
248
  writer = StreamingMediaEncoder(output_path)
249
  writer.add_audio_stream(
@@ -251,7 +241,7 @@ def make_video(video_path: Path, output_path: Path, audio: torch.Tensor, samplin
251
  num_channels=audio.shape[0],
252
  encoder='aac', # 'flac' does not work for some reason?
253
  )
254
- writer.add_video_stream(frame_rate=fps,
255
  width=w,
256
  height=h,
257
  format='rgb24',
 
3
  from pathlib import Path
4
  from typing import Optional
5
 
6
+ import av
7
  import torch
8
  from colorlog import ColoredFormatter
9
  from torchvision.transforms import v2
10
  from torio.io import StreamingMediaDecoder, StreamingMediaEncoder
 
11
 
12
  from mmaudio.model.flow_matching import FlowMatching
13
  from mmaudio.model.networks import MMAudio
 
170
  reader = StreamingMediaDecoder(video_path)
171
  reader.add_basic_video_stream(
172
  frames_per_chunk=int(_CLIP_FPS * duration_sec),
173
+ buffer_chunk_size=-1,
174
  frame_rate=_CLIP_FPS,
175
  format='rgb24',
176
  )
177
  reader.add_basic_video_stream(
178
  frames_per_chunk=int(_SYNC_FPS * duration_sec),
179
+ buffer_chunk_size=-1,
180
  frame_rate=_SYNC_FPS,
181
  format='rgb24',
182
  )
 
220
 
221
  av_video = av.open(video_path)
222
  frame_rate = av_video.streams.video[0].guessed_rate
 
223
 
224
+ approx_max_length = int(duration_sec * frame_rate) + 1
225
  reader = StreamingMediaDecoder(video_path)
226
  reader.add_basic_video_stream(
227
  frames_per_chunk=approx_max_length,
228
+ buffer_chunk_size=-1,
229
  format='rgb24',
230
  )
231
  reader.fill_buffer()
232
  video_chunk = reader.pop_chunks()[0]
 
233
  assert video_chunk is not None
234
 
 
 
 
 
 
 
 
 
235
  h, w = video_chunk.shape[-2:]
236
+ video_chunk = video_chunk[:int(frame_rate * duration_sec)]
237
 
238
  writer = StreamingMediaEncoder(output_path)
239
  writer.add_audio_stream(
 
241
  num_channels=audio.shape[0],
242
  encoder='aac', # 'flac' does not work for some reason?
243
  )
244
+ writer.add_video_stream(frame_rate=frame_rate,
245
  width=w,
246
  height=h,
247
  format='rgb24',