Commit c58ca4b
Rex Cheng committed
1 Parent(s): b0ec3f5
mmaudio/data/__init__.py ADDED
File without changes
mmaudio/data/av_utils.py ADDED
@@ -0,0 +1,136 @@
+ from dataclasses import dataclass
+ from fractions import Fraction
+ from pathlib import Path
+ from typing import Optional
+
+ import av
+ import numpy as np
+ import torch
+ from av import AudioFrame
+
+
+ @dataclass
+ class VideoInfo:
+     duration_sec: float
+     fps: Fraction
+     clip_frames: torch.Tensor
+     sync_frames: torch.Tensor
+     all_frames: Optional[list[np.ndarray]]
+
+     @property
+     def height(self):
+         return self.all_frames[0].shape[0]
+
+     @property
+     def width(self):
+         return self.all_frames[0].shape[1]
+
+
+ def read_frames(video_path: Path, list_of_fps: list[float], start_sec: float, end_sec: float,
+                 need_all_frames: bool) -> tuple[list[np.ndarray], list[np.ndarray], Fraction]:
+     """Decode the frames of video_path that fall in [start_sec, end_sec], sampling one stack
+     of frames per rate in list_of_fps; when need_all_frames is set, also return every decoded
+     frame in that window, together with the stream's guessed frame rate."""
+     output_frames = [[] for _ in list_of_fps]
+     next_frame_time_for_each_fps = [0.0 for _ in list_of_fps]
+     time_delta_for_each_fps = [1 / fps for fps in list_of_fps]
+     all_frames = []
+
+     with av.open(video_path) as container:
+         stream = container.streams.video[0]
+         fps = stream.guessed_rate
+         stream.thread_type = 'AUTO'
+         for packet in container.demux(stream):
+             for frame in packet.decode():
+                 frame_time = frame.time
+                 if frame_time < start_sec:
+                     continue
+                 if frame_time > end_sec:
+                     break
+
+                 frame_np = None
+                 if need_all_frames:
+                     frame_np = frame.to_ndarray(format='rgb24')
+                     all_frames.append(frame_np)
+
+                 # emit this frame for every target fps whose next sample time has passed
+                 for i, _ in enumerate(list_of_fps):
+                     this_time = frame_time
+                     while this_time >= next_frame_time_for_each_fps[i]:
+                         if frame_np is None:
+                             frame_np = frame.to_ndarray(format='rgb24')
+
+                         output_frames[i].append(frame_np)
+                         next_frame_time_for_each_fps[i] += time_delta_for_each_fps[i]
+
+     output_frames = [np.stack(frames) for frames in output_frames]
+     return output_frames, all_frames, fps
+
+
+ def reencode_with_audio(video_info: VideoInfo, output_path: Path, audio: torch.Tensor,
+                         sampling_rate: int):
+     container = av.open(output_path, 'w')
+     output_video_stream = container.add_stream('h264', video_info.fps)
+     output_video_stream.codec_context.bit_rate = 10_000_000  # 10 Mbps
+     output_video_stream.width = video_info.width
+     output_video_stream.height = video_info.height
+     output_video_stream.pix_fmt = 'yuv420p'
+
+     output_audio_stream = container.add_stream('aac', sampling_rate)
+
+     # encode video
+     for image in video_info.all_frames:
+         image = av.VideoFrame.from_ndarray(image)
+         packet = output_video_stream.encode(image)
+         container.mux(packet)
+
+     # flush the video encoder
+     for packet in output_video_stream.encode():
+         container.mux(packet)
+
+     # convert float tensor audio to a float32 numpy array
+     # (the packed 'flt' mono frame expects shape (1, num_samples))
+     audio_np = audio.numpy().astype(np.float32)
+     audio_frame = AudioFrame.from_ndarray(audio_np, format='flt', layout='mono')
+     audio_frame.sample_rate = sampling_rate
+
+     # encode audio, then flush the audio encoder
+     for packet in output_audio_stream.encode(audio_frame):
+         container.mux(packet)
+
+     for packet in output_audio_stream.encode():
+         container.mux(packet)
+
+     container.close()
+
+
+ def remux_with_audio(video_path: Path, audio: torch.Tensor, output_path: Path, sampling_rate: int):
+     """
+     NOTE: I don't think we can get the exact video duration right without re-encoding,
+     so we are not using this, but it is kept here for reference.
+     """
+     video = av.open(video_path)
+     output = av.open(output_path, 'w')
+     input_video_stream = video.streams.video[0]
+     output_video_stream = output.add_stream(template=input_video_stream)
+     output_audio_stream = output.add_stream('aac', sampling_rate)
+
+     duration_sec = audio.shape[-1] / sampling_rate
+
+     for packet in video.demux(input_video_stream):
+         # We need to skip the "flushing" packets that `demux` generates.
+         if packet.dts is None:
+             continue
+         # We need to assign the packet to the new stream.
+         packet.stream = output_video_stream
+         output.mux(packet)
+
+     # convert float tensor audio to numpy array
+     audio_np = audio.numpy().astype(np.float32)
+     audio_frame = av.AudioFrame.from_ndarray(audio_np, format='flt', layout='mono')
+     audio_frame.sample_rate = sampling_rate
+
+     for packet in output_audio_stream.encode(audio_frame):
+         output.mux(packet)
+
+     # flush the audio encoder
+     for packet in output_audio_stream.encode():
+         output.mux(packet)
+
+     video.close()
+     output.close()
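
For context, a minimal usage sketch of the helpers added in this file (not part of the commit; the example path, the 8-second window, the 8/25 fps targets, and the 16 kHz mono waveform are all assumptions made for illustration):

from pathlib import Path

import torch

from mmaudio.data.av_utils import VideoInfo, read_frames, reencode_with_audio

# hypothetical input clip and target frame rates for the two frame streams
video_path = Path('example.mp4')
output_frames, all_frames, fps = read_frames(video_path,
                                             list_of_fps=[8.0, 25.0],
                                             start_sec=0.0,
                                             end_sec=8.0,
                                             need_all_frames=True)
clip_np, sync_np = output_frames

info = VideoInfo(duration_sec=8.0,
                 fps=fps,
                 clip_frames=torch.from_numpy(clip_np),
                 sync_frames=torch.from_numpy(sync_np),
                 all_frames=all_frames)

# mux a placeholder mono waveform, shape (1, num_samples) at an assumed 16 kHz,
# back with the decoded frames
audio = torch.zeros(1, int(8.0 * 16_000))
reencode_with_audio(info, Path('output.mp4'), audio, sampling_rate=16_000)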