Spaces:
Sleeping
Sleeping
Update change_video.py
Browse files- change_video.py +359 -0
change_video.py
CHANGED
@@ -1,7 +1,366 @@
|
|
|
|
1 |
import os
|
2 |
import sys
|
3 |
|
4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
def change(
    path='/home/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torchvision/io/video.py',
    content=None,
):
    """Overwrite the file at ``path`` with replacement source text.

    Args:
        path: File to overwrite. Defaults to the ``torchvision/io/video.py``
            location this script was written for.
        content: Text to write. When ``None`` the module-level ``video_file``
            string is written.
    """
    if content is None:
        # Resolved lazily so that callers passing their own text do not need
        # the module-level payload to exist.
        content = video_file
    # Explicit encoding keeps the written bytes independent of the locale.
    with open(path, 'w', encoding='utf-8') as f:
        f.write(content)
|
|
|
1 |
+
|
2 |
import os
|
3 |
import sys
|
4 |
|
5 |
|
6 |
+
video_file = """
|
7 |
+
import gc
|
8 |
+
import math
|
9 |
+
import os
|
10 |
+
import re
|
11 |
+
import warnings
|
12 |
+
from fractions import Fraction
|
13 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
14 |
+
|
15 |
+
import numpy as np
|
16 |
+
import torch
|
17 |
+
|
18 |
+
from ..utils import _log_api_usage_once
|
19 |
+
from . import _video_opt
|
20 |
+
|
21 |
+
# PyAV is an optional dependency. Instead of failing at import time, an
# ImportError instance is stored in `av`; the helpers below inspect it and
# raise it only when a PyAV-backed function is actually used.
# The messages are plain single-line strings on purpose: no triple-quoted
# literals are used in this section.
try:
    import av

    av.logging.set_level(av.logging.ERROR)
    if not hasattr(av.video.frame.VideoFrame, "pict_type"):
        av = ImportError(
            "Your version of PyAV is too old for the necessary video operations in torchvision. "
            "See https://github.com/mikeboers/PyAV#installation for instructions on how to "
            "install PyAV on your system."
        )
except ImportError:
    av = ImportError(
        "PyAV is not installed, and is necessary for the video operations in torchvision. "
        "See https://github.com/mikeboers/PyAV#installation for instructions on how to "
        "install PyAV on your system."
    )
|
27 |
+
|
28 |
+
|
29 |
+
|
30 |
+
def _check_av_available() -> None:
    # Raise the exception stored in the module-level `av` name, if there is
    # one; return normally when `av` is not an exception instance.
    av_is_usable = not isinstance(av, Exception)
    if av_is_usable:
        return
    raise av
|
33 |
+
|
34 |
+
|
35 |
+
def _av_available() -> bool:
    # Report whether the module-level `av` name holds something other than
    # an exception instance.
    if isinstance(av, Exception):
        return False
    return True
|
37 |
+
|
38 |
+
|
39 |
+
# PyAV has some reference cycles
|
40 |
+
_CALLED_TIMES = 0
|
41 |
+
_GC_COLLECTION_INTERVAL = 10
|
42 |
+
|
43 |
+
|
44 |
+
def write_video(
    filename: str,
    video_array: torch.Tensor,
    fps: float,
    video_codec: str = "libx264",
    options: Optional[Dict[str, Any]] = None,
    audio_array: Optional[torch.Tensor] = None,
    audio_fps: Optional[float] = None,
    audio_codec: Optional[str] = None,
    audio_options: Optional[Dict[str, Any]] = None,
) -> None:
    # Encode a stack of frames (and optionally one audio track) into `filename`
    # through PyAV.
    #
    # Args:
    #     filename: path of the container to create.
    #     video_array: frames indexed as [T, H, W, C]; converted to uint8 and
    #         handed to PyAV frame by frame as "rgb24".
    #     fps: video frame rate; a Python float is rounded to an int first.
    #     video_codec: codec name passed to `container.add_stream`.
    #     options: codec options for the video stream (empty dict when None).
    #     audio_array: samples indexed as [channels, samples]; no audio stream
    #         is created when None.
    #     audio_fps: audio sample rate.
    #     audio_codec: codec name for the audio stream.
    #     audio_options: codec options for the audio stream (empty dict when None).
    #
    # Raises:
    #     whatever `_check_av_available` raises when PyAV is unusable.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(write_video)
    _check_av_available()
    video_array = torch.as_tensor(video_array, dtype=torch.uint8).cpu().numpy()

    # PyAV does not support floating point numbers with decimal point
    # and will throw OverflowException in case this is not the case.
    # Cast to a plain int so that no numpy float scalar reaches PyAV.
    if isinstance(fps, float):
        fps = int(np.round(fps))

    with av.open(filename, mode="w") as container:
        stream = container.add_stream(video_codec, rate=fps)
        stream.width = video_array.shape[2]
        stream.height = video_array.shape[1]
        stream.pix_fmt = "yuv420p" if video_codec != "libx264rgb" else "rgb24"
        stream.options = options or {}

        if audio_array is not None:
            # Sample-format name reported by PyAV -> numpy dtype string.
            audio_format_dtypes = {
                "dbl": "<f8",
                "dblp": "<f8",
                "flt": "<f4",
                "fltp": "<f4",
                "s16": "<i2",
                "s16p": "<i2",
                "s32": "<i4",
                "s32p": "<i4",
                "u8": "u1",
                "u8p": "u1",
            }
            a_stream = container.add_stream(audio_codec, rate=audio_fps)
            a_stream.options = audio_options or {}

            num_channels = audio_array.shape[0]
            audio_layout = "stereo" if num_channels > 1 else "mono"
            audio_sample_fmt = container.streams.audio[0].format.name

            format_dtype = np.dtype(audio_format_dtypes[audio_sample_fmt])
            audio_array = torch.as_tensor(audio_array).cpu().numpy().astype(format_dtype)

            frame = av.AudioFrame.from_ndarray(audio_array, format=audio_sample_fmt, layout=audio_layout)

            frame.sample_rate = audio_fps

            for packet in a_stream.encode(frame):
                container.mux(packet)

            # Flush the audio encoder
            for packet in a_stream.encode():
                container.mux(packet)

        for img in video_array:
            frame = av.VideoFrame.from_ndarray(img, format="rgb24")
            try:
                frame.pict_type = "NONE"
            except TypeError:
                # Some PyAV releases reject the string form and expect the
                # PictureType enum member instead.
                from av.video.frame import PictureType  # noqa

                frame.pict_type = PictureType.NONE
            for packet in stream.encode(frame):
                container.mux(packet)

        # Flush stream
        for packet in stream.encode():
            container.mux(packet)
|
115 |
+
|
116 |
+
|
117 |
+
def _read_from_stream(
    container: "av.container.Container",
    start_offset: float,
    end_offset: float,
    pts_unit: str,
    stream: "av.stream.Stream",
    stream_name: Dict[str, Optional[Union[int, Tuple[int, ...], List[int]]]],
) -> List["av.frame.Frame"]:
    # Decode the frames of one stream whose pts lie in [start_offset, end_offset].
    #
    # Args:
    #     container: opened PyAV container to seek and decode from.
    #     start_offset / end_offset: range bounds, in seconds when
    #         pts_unit == "sec" (converted with stream.time_base), otherwise
    #         already expressed in pts.
    #     pts_unit: "sec" or anything else (treated as pts, with a warning).
    #     stream: stream used for seeking and for the time base.
    #     stream_name: keyword arguments forwarded to `container.decode`.
    #
    # Returns:
    #     frames sorted by pts; when no frame sits exactly on start_offset, the
    #     closest earlier frame is prepended. An empty list if seeking fails.
    global _CALLED_TIMES
    _CALLED_TIMES += 1
    # Run a garbage collection once every _GC_COLLECTION_INTERVAL calls.
    if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1:
        gc.collect()

    if pts_unit == "sec":
        # TODO: we should change all of this from ground up to simply take
        # sec and convert to MS in C++
        start_offset = int(math.floor(start_offset * (1 / stream.time_base)))
        if end_offset != float("inf"):
            end_offset = int(math.ceil(end_offset * (1 / stream.time_base)))
    else:
        warnings.warn("The pts_unit 'pts' gives wrong results. Please use pts_unit 'sec'.")

    frames = {}
    should_buffer = True
    max_buffer_size = 5
    if stream.type == "video":
        # DivX-style packed B-frames can have out-of-order pts (2 frames in a single pkt)
        # so need to buffer some extra frames to sort everything
        # properly
        extradata = stream.codec_context.extradata
        # overly complicated way of finding if `divx_packed` is set, following
        # https://github.com/FFmpeg/FFmpeg/commit/d5a21172283572af587b3d939eba0091484d3263
        if extradata and b"DivX" in extradata:
            # can't use regex directly because of some weird characters sometimes...
            pos = extradata.find(b"DivX")
            d = extradata[pos:]
            o = re.search(rb"DivX(\d+)Build(\d+)(\w)", d)
            if o is None:
                o = re.search(rb"DivX(\d+)b(\d+)(\w)", d)
            if o is not None:
                should_buffer = o.group(3) == b"p"
    seek_offset = start_offset
    # some files don't seek to the right location, so better be safe here
    seek_offset = max(seek_offset - 1, 0)
    if should_buffer:
        # FIXME this is kind of a hack, but we will jump to the previous keyframe
        # so this will be safe
        seek_offset = max(seek_offset - max_buffer_size, 0)
    try:
        # TODO check if stream needs to always be the video stream here or not
        container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
    except av.AVError:
        # TODO add some warnings in this case
        # print("Corrupted file?", container.name)
        return []
    buffer_count = 0
    try:
        for frame in container.decode(**stream_name):
            frames[frame.pts] = frame
            if frame.pts >= end_offset:
                # Past the end: keep decoding up to max_buffer_size extra
                # frames when buffering is enabled, then stop.
                if should_buffer and buffer_count < max_buffer_size:
                    buffer_count += 1
                    continue
                break
    except av.AVError:
        # TODO add a warning
        pass
    # ensure that the results are sorted wrt the pts
    result = [frames[i] for i in sorted(frames) if start_offset <= frames[i].pts <= end_offset]
    if frames and start_offset > 0 and start_offset not in frames:
        # if there is no frame that exactly matches the pts of start_offset
        # add the last frame smaller than start_offset, to guarantee that
        # we will have all the necessary data. This is most useful for audio
        preceding_frames = [i for i in frames if i < start_offset]
        if preceding_frames:
            first_frame_pts = max(preceding_frames)
            result.insert(0, frames[first_frame_pts])
    return result
|
195 |
+
|
196 |
+
|
197 |
+
def _align_audio_frames(
|
198 |
+
aframes: torch.Tensor, audio_frames: List["av.frame.Frame"], ref_start: int, ref_end: float
|
199 |
+
) -> torch.Tensor:
|
200 |
+
start, end = audio_frames[0].pts, audio_frames[-1].pts
|
201 |
+
total_aframes = aframes.shape[1]
|
202 |
+
step_per_aframe = (end - start + 1) / total_aframes
|
203 |
+
s_idx = 0
|
204 |
+
e_idx = total_aframes
|
205 |
+
if start < ref_start:
|
206 |
+
s_idx = int((ref_start - start) / step_per_aframe)
|
207 |
+
if end > ref_end:
|
208 |
+
e_idx = int((ref_end - end) / step_per_aframe)
|
209 |
+
return aframes[:, s_idx:e_idx]
|
210 |
+
|
211 |
+
|
212 |
+
def read_video(
    filename: str,
    start_pts: Union[float, Fraction] = 0,
    end_pts: Optional[Union[float, Fraction]] = None,
    pts_unit: str = "pts",
    output_format: str = "THWC",
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
    # Read the video and audio frames of `filename` between start_pts and end_pts.
    #
    # Args:
    #     filename: path of the file to read; must exist.
    #     start_pts / end_pts: range to read; end_pts=None means "until the end"
    #         on the PyAV path.
    #     pts_unit: unit of start_pts / end_pts, "pts" or "sec".
    #     output_format: "THWC" or "TCHW" (case-insensitive).
    #
    # Returns:
    #     (vframes, aframes, info); info may hold "video_fps" and "audio_fps".
    #
    # Raises:
    #     ValueError: unsupported output_format, or end_pts < start_pts.
    #     RuntimeError: filename does not exist.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(read_video)

    output_format = output_format.upper()
    if output_format not in ("THWC", "TCHW"):
        raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")

    from torchvision import get_video_backend

    if not os.path.exists(filename):
        raise RuntimeError(f"File not found: {filename}")

    if get_video_backend() != "pyav":
        vframes, aframes, info = _video_opt._read_video(filename, start_pts, end_pts, pts_unit)
    else:
        _check_av_available()

        if end_pts is None:
            end_pts = float("inf")

        if end_pts < start_pts:
            raise ValueError(
                f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}"
            )

        info = {}
        video_frames = []
        audio_frames = []
        audio_timebase = _video_opt.default_timebase

        try:
            with av.open(filename, metadata_errors="ignore") as container:
                if container.streams.audio:
                    audio_timebase = container.streams.audio[0].time_base
                if container.streams.video:
                    video_frames = _read_from_stream(
                        container,
                        start_pts,
                        end_pts,
                        pts_unit,
                        container.streams.video[0],
                        {"video": 0},
                    )
                    video_fps = container.streams.video[0].average_rate
                    # guard against potentially corrupted files
                    if video_fps is not None:
                        info["video_fps"] = float(video_fps)

                if container.streams.audio:
                    audio_frames = _read_from_stream(
                        container,
                        start_pts,
                        end_pts,
                        pts_unit,
                        container.streams.audio[0],
                        {"audio": 0},
                    )
                    info["audio_fps"] = container.streams.audio[0].rate

        except av.AVError as e:
            # Still best-effort: whatever was decoded before the failure is
            # returned, but the failure is reported instead of being hidden.
            warnings.warn(f"Failed to read {filename}; Caught error: {e}", RuntimeWarning)

        vframes_list = [frame.to_rgb().to_ndarray() for frame in video_frames]
        aframes_list = [frame.to_ndarray() for frame in audio_frames]

        if vframes_list:
            vframes = torch.as_tensor(np.stack(vframes_list))
        else:
            vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)

        if aframes_list:
            aframes = np.concatenate(aframes_list, 1)
            aframes = torch.as_tensor(aframes)
            if pts_unit == "sec":
                # Re-express the requested range in the audio time base
                # before trimming the audio samples.
                start_pts = int(math.floor(start_pts * (1 / audio_timebase)))
                if end_pts != float("inf"):
                    end_pts = int(math.ceil(end_pts * (1 / audio_timebase)))
            aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
        else:
            aframes = torch.empty((1, 0), dtype=torch.float32)

    if output_format == "TCHW":
        # [T,H,W,C] --> [T,C,H,W]
        vframes = vframes.permute(0, 3, 1, 2)

    return vframes, aframes, info
|
307 |
+
|
308 |
+
|
309 |
+
def _can_read_timestamps_from_packets(container: "av.container.Container") -> bool:
|
310 |
+
extradata = container.streams[0].codec_context.extradata
|
311 |
+
if extradata is None:
|
312 |
+
return False
|
313 |
+
if b"Lavc" in extradata:
|
314 |
+
return True
|
315 |
+
return False
|
316 |
+
|
317 |
+
|
318 |
+
def _decode_video_timestamps(container: "av.container.Container") -> List[int]:
    # Collect every non-None pts of video stream 0, in the order produced.
    if _can_read_timestamps_from_packets(container):
        # fast path: read the pts from demuxed packets, without decoding
        items = container.demux(video=0)
    else:
        items = container.decode(video=0)
    return [item.pts for item in items if item.pts is not None]
|
324 |
+
|
325 |
+
|
326 |
+
def read_video_timestamps(filename: str, pts_unit: str = "pts") -> Tuple[List[int], Optional[float]]:
    # List the presentation timestamps of the first video stream of `filename`.
    #
    # Args:
    #     filename: path of the video file.
    #     pts_unit: "pts" returns the raw values; "sec" multiplies each value
    #         by the stream time base.
    #
    # Returns:
    #     (pts, video_fps): sorted timestamps, and the stream's average rate as
    #     a float, or None when it is unknown or the file has no video stream.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(read_video_timestamps)
    from torchvision import get_video_backend

    if get_video_backend() != "pyav":
        return _video_opt._read_video_timestamps(filename, pts_unit)

    _check_av_available()

    video_fps = None
    pts = []

    try:
        with av.open(filename, metadata_errors="ignore") as container:
            if container.streams.video:
                video_stream = container.streams.video[0]
                video_time_base = video_stream.time_base
                try:
                    pts = _decode_video_timestamps(container)
                except av.AVError:
                    warnings.warn(f"Failed decoding frames for file {filename}")
                # guard against potentially corrupted files, as read_video does:
                # the average rate may be None and float(None) would raise
                average_rate = video_stream.average_rate
                if average_rate is not None:
                    video_fps = float(average_rate)
    except av.AVError as e:
        msg = f"Failed to open container for {filename}; Caught error: {e}"
        warnings.warn(msg, RuntimeWarning)

    pts.sort()

    if pts_unit == "sec":
        # video_time_base is only evaluated when pts is non-empty, which
        # implies the video-stream branch above has assigned it.
        pts = [x * video_time_base for x in pts]

    return pts, video_fps
|
360 |
+
|
361 |
+
|
362 |
+
"""
|
363 |
+
|
364 |
def change(
    path='/home/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torchvision/io/video.py',
    content=None,
):
    """Overwrite the file at ``path`` with replacement source text.

    Args:
        path: File to overwrite. Defaults to the ``torchvision/io/video.py``
            location this script was written for.
        content: Text to write. When ``None`` the module-level ``video_file``
            string is written.
    """
    if content is None:
        # Resolved lazily so that callers passing their own text do not need
        # the module-level payload to exist.
        content = video_file
    # Explicit encoding keeps the written bytes independent of the locale.
    with open(path, 'w', encoding='utf-8') as f:
        f.write(content)
|