# mmaudio/data/av_utils.py (from the MMAudio repository)
from dataclasses import dataclass
from fractions import Fraction
from pathlib import Path
from typing import Optional

import av
import numpy as np
import torch
from av import AudioFrame


@dataclass
class VideoInfo:
    duration_sec: float
    fps: Fraction
    clip_frames: torch.Tensor
    sync_frames: torch.Tensor
    all_frames: Optional[list[np.ndarray]]

    @property
    def height(self):
        return self.all_frames[0].shape[0]

    @property
    def width(self):
        return self.all_frames[0].shape[1]
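

# Hedged construction sketch (not part of the original repo): all values below
# are illustrative assumptions. height/width are read from all_frames, so those
# properties are only usable when all frames were requested upstream.
def _example_video_info() -> VideoInfo:
    frames = [np.zeros((240, 320, 3), dtype=np.uint8)]  # one black 320x240 RGB frame
    clip = torch.from_numpy(np.stack(frames))
    return VideoInfo(duration_sec=1.0,
                     fps=Fraction(25),
                     clip_frames=clip,
                     sync_frames=clip,
                     all_frames=frames)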


def read_frames(video_path: Path, list_of_fps: list[float], start_sec: float, end_sec: float,
                need_all_frames: bool) -> tuple[list[np.ndarray], list[np.ndarray], Fraction]:
    # For each requested fps, resample a single decode pass onto that fps's time grid.
    # NOTE: the grids start at t=0, so with start_sec > 0 the first kept frame is
    # appended repeatedly until the grid catches up to start_sec.
    output_frames = [[] for _ in list_of_fps]
    next_frame_time_for_each_fps = [0.0 for _ in list_of_fps]
    time_delta_for_each_fps = [1 / fps for fps in list_of_fps]
    all_frames = []

    with av.open(video_path) as container:
        stream = container.streams.video[0]
        fps = stream.guessed_rate
        stream.thread_type = 'AUTO'
        for packet in container.demux(stream):
            for frame in packet.decode():
                frame_time = frame.time
                if frame_time < start_sec:
                    continue
                if frame_time > end_sec:
                    # This only ends the current packet's frames; later packets are
                    # still demuxed, but their frames are skipped by this check.
                    break
                frame_np = None
                if need_all_frames:
                    frame_np = frame.to_ndarray(format='rgb24')
                    all_frames.append(frame_np)
                for i, _ in enumerate(list_of_fps):
                    this_time = frame_time
                    # Repeat the current frame until the i-th sampling grid catches up.
                    while this_time >= next_frame_time_for_each_fps[i]:
                        if frame_np is None:
                            # Convert lazily: frames outside every grid are never materialized.
                            frame_np = frame.to_ndarray(format='rgb24')
                        output_frames[i].append(frame_np)
                        next_frame_time_for_each_fps[i] += time_delta_for_each_fps[i]

    output_frames = [np.stack(frames) for frames in output_frames]
    return output_frames, all_frames, fps
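

# Hedged usage sketch for read_frames: the fps values (8 and 25) and the
# 8-second window are assumptions for illustration, not values taken from this
# repository. Both sampling grids share a single decode pass.
def _example_read_frames(video_path: Path) -> VideoInfo:
    output_frames, all_frames, fps = read_frames(video_path,
                                                 list_of_fps=[8.0, 25.0],
                                                 start_sec=0.0,
                                                 end_sec=8.0,
                                                 need_all_frames=True)
    clip_np, sync_np = output_frames  # ordered like list_of_fps
    return VideoInfo(duration_sec=8.0,
                     fps=fps,
                     clip_frames=torch.from_numpy(clip_np),
                     sync_frames=torch.from_numpy(sync_np),
                     all_frames=all_frames)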


def reencode_with_audio(video_info: VideoInfo, output_path: Path, audio: torch.Tensor,
                        sampling_rate: int):
    container = av.open(output_path, 'w')
    output_video_stream = container.add_stream('h264', video_info.fps)
    output_video_stream.codec_context.bit_rate = int(10e6)  # 10 Mbps; must be an int
    output_video_stream.width = video_info.width
    output_video_stream.height = video_info.height
    output_video_stream.pix_fmt = 'yuv420p'

    output_audio_stream = container.add_stream('aac', sampling_rate)

    # encode video
    for image in video_info.all_frames:
        image = av.VideoFrame.from_ndarray(image)
        packet = output_video_stream.encode(image)
        container.mux(packet)
    # flush the video encoder
    for packet in output_video_stream.encode():
        container.mux(packet)

    # convert the float tensor audio to a numpy array; for the packed 'flt'
    # format with a mono layout, from_ndarray expects shape (1, num_samples)
    audio_np = audio.numpy().astype(np.float32)
    audio_frame = AudioFrame.from_ndarray(audio_np, format='flt', layout='mono')
    audio_frame.sample_rate = sampling_rate
    for packet in output_audio_stream.encode(audio_frame):
        container.mux(packet)
    # flush the audio encoder
    for packet in output_audio_stream.encode():
        container.mux(packet)

    container.close()
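

# Hedged sketch pairing read_frames with reencode_with_audio: the 16 kHz rate
# and the 440 Hz test tone are illustrative assumptions. The audio tensor is
# shaped (1, num_samples) to match the packed 'flt' mono layout used above.
def _example_reencode(video_path: Path, output_path: Path):
    info = _example_read_frames(video_path)
    t = torch.arange(int(info.duration_sec * 16000)) / 16000
    audio = torch.sin(2 * torch.pi * 440.0 * t).unsqueeze(0)  # (1, num_samples)
    reencode_with_audio(info, output_path, audio, sampling_rate=16000)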


def remux_with_audio(video_path: Path, audio: torch.Tensor, output_path: Path, sampling_rate: int):
    """
    NOTE: I don't think we can get the exact video duration right without re-encoding,
    so we are not using this; it is kept here for reference.
    """
    video = av.open(video_path)
    output = av.open(output_path, 'w')
    input_video_stream = video.streams.video[0]
    output_video_stream = output.add_stream(template=input_video_stream)
    output_audio_stream = output.add_stream('aac', sampling_rate)

    duration_sec = audio.shape[-1] / sampling_rate  # unused; kept for reference

    for packet in video.demux(input_video_stream):
        # We need to skip the "flushing" packets that `demux` generates.
        if packet.dts is None:
            continue
        # We need to assign the packet to the new stream.
        packet.stream = output_video_stream
        output.mux(packet)

    # convert the float tensor audio to a numpy array; packed 'flt' mono expects (1, num_samples)
    audio_np = audio.numpy().astype(np.float32)
    audio_frame = av.AudioFrame.from_ndarray(audio_np, format='flt', layout='mono')
    audio_frame.sample_rate = sampling_rate
    for packet in output_audio_stream.encode(audio_frame):
        output.mux(packet)
    # flush the audio encoder
    for packet in output_audio_stream.encode():
        output.mux(packet)

    video.close()
    output.close()
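

# Hedged smoke test: the file names below are placeholders, not paths from this
# repository. remux_with_audio is left unexercised since the author keeps it
# only for reference.
if __name__ == '__main__':
    info = _example_read_frames(Path('input.mp4'))
    print(f'{info.width}x{info.height} @ {info.fps} fps, '
          f'clip {tuple(info.clip_frames.shape)}, sync {tuple(info.sync_frames.shape)}')
    _example_reencode(Path('input.mp4'), Path('output.mp4'))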