NewLabs committed on
Commit
7dcb446
1 Parent(s): 5f5939a

Update change_video.py

Browse files
Files changed (1) hide show
  1. change_video.py +359 -0
change_video.py CHANGED
@@ -1,7 +1,366 @@
 
1
  import os
2
  import sys
3
 
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  def change():
6
  with open('/home/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torchvision/io/video.py','w') as f:
7
  f.write(video_file)
 
1
+
2
  import os
3
  import sys
4
 
5
 
6
# Replacement source for ``torchvision/io/video.py``; ``change()`` writes this
# string over the installed module.
#
# The literal is a raw string so that the backslashes in the embedded regexes
# (``\d``, ``\w``) reach the generated file unchanged without relying on
# Python's handling of unrecognised escape sequences.
#
# The embedded module cannot contain triple double-quoted strings (they would
# terminate this literal), so the PyAV ``ImportError`` messages are written as
# adjacent ordinary string literals. The ``try`` block is closed with an
# ``except ImportError`` branch so that ``av`` is always bound, either to the
# module or to an exception instance that ``_check_av_available`` re-raises.
video_file = r"""
import gc
import math
import os
import re
import warnings
from fractions import Fraction
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch

from ..utils import _log_api_usage_once
from . import _video_opt

try:
    import av

    av.logging.set_level(av.logging.ERROR)
    if not hasattr(av.video.frame.VideoFrame, "pict_type"):
        av = ImportError(
            "Your version of PyAV is too old for the necessary video operations in torchvision. "
            "See https://github.com/mikeboers/PyAV#installation for instructions on how to "
            "install PyAV on your system."
        )
except ImportError:
    av = ImportError(
        "PyAV is not installed, and is necessary for the video operations in torchvision. "
        "See https://github.com/mikeboers/PyAV#installation for instructions on how to "
        "install PyAV on your system."
    )


def _check_av_available() -> None:
    if isinstance(av, Exception):
        raise av


def _av_available() -> bool:
    return not isinstance(av, Exception)


# PyAV has some reference cycles
_CALLED_TIMES = 0
_GC_COLLECTION_INTERVAL = 10


def write_video(
    filename: str,
    video_array: torch.Tensor,
    fps: float,
    video_codec: str = "libx264",
    options: Optional[Dict[str, Any]] = None,
    audio_array: Optional[torch.Tensor] = None,
    audio_fps: Optional[float] = None,
    audio_codec: Optional[str] = None,
    audio_options: Optional[Dict[str, Any]] = None,
) -> None:

    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(write_video)
    _check_av_available()
    video_array = torch.as_tensor(video_array, dtype=torch.uint8).cpu().numpy()

    # PyAV does not support floating point numbers with decimal point
    # and will throw OverflowException in case this is not the case
    if isinstance(fps, float):
        fps = np.round(fps)

    with av.open(filename, mode="w") as container:
        stream = container.add_stream(video_codec, rate=fps)
        stream.width = video_array.shape[2]
        stream.height = video_array.shape[1]
        stream.pix_fmt = "yuv420p" if video_codec != "libx264rgb" else "rgb24"
        stream.options = options or {}

        if audio_array is not None:
            audio_format_dtypes = {
                "dbl": "<f8",
                "dblp": "<f8",
                "flt": "<f4",
                "fltp": "<f4",
                "s16": "<i2",
                "s16p": "<i2",
                "s32": "<i4",
                "s32p": "<i4",
                "u8": "u1",
                "u8p": "u1",
            }
            a_stream = container.add_stream(audio_codec, rate=audio_fps)
            a_stream.options = audio_options or {}

            num_channels = audio_array.shape[0]
            audio_layout = "stereo" if num_channels > 1 else "mono"
            audio_sample_fmt = container.streams.audio[0].format.name

            format_dtype = np.dtype(audio_format_dtypes[audio_sample_fmt])
            audio_array = torch.as_tensor(audio_array).cpu().numpy().astype(format_dtype)

            frame = av.AudioFrame.from_ndarray(audio_array, format=audio_sample_fmt, layout=audio_layout)

            frame.sample_rate = audio_fps

            for packet in a_stream.encode(frame):
                container.mux(packet)

            for packet in a_stream.encode():
                container.mux(packet)

        for img in video_array:
            frame = av.VideoFrame.from_ndarray(img, format="rgb24")
            frame.pict_type = "NONE"
            for packet in stream.encode(frame):
                container.mux(packet)

        # Flush stream
        for packet in stream.encode():
            container.mux(packet)


def _read_from_stream(
    container: "av.container.Container",
    start_offset: float,
    end_offset: float,
    pts_unit: str,
    stream: "av.stream.Stream",
    stream_name: Dict[str, Optional[Union[int, Tuple[int, ...], List[int]]]],
) -> List["av.frame.Frame"]:
    global _CALLED_TIMES, _GC_COLLECTION_INTERVAL
    _CALLED_TIMES += 1
    if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1:
        gc.collect()

    if pts_unit == "sec":
        # TODO: we should change all of this from ground up to simply take
        # sec and convert to MS in C++
        start_offset = int(math.floor(start_offset * (1 / stream.time_base)))
        if end_offset != float("inf"):
            end_offset = int(math.ceil(end_offset * (1 / stream.time_base)))
    else:
        warnings.warn("The pts_unit 'pts' gives wrong results. Please use pts_unit 'sec'.")

    frames = {}
    should_buffer = True
    max_buffer_size = 5
    if stream.type == "video":
        # DivX-style packed B-frames can have out-of-order pts (2 frames in a single pkt)
        # so need to buffer some extra frames to sort everything
        # properly
        extradata = stream.codec_context.extradata
        # overly complicated way of finding if `divx_packed` is set, following
        # https://github.com/FFmpeg/FFmpeg/commit/d5a21172283572af587b3d939eba0091484d3263
        if extradata and b"DivX" in extradata:
            # can't use regex directly because of some weird characters sometimes...
            pos = extradata.find(b"DivX")
            d = extradata[pos:]
            o = re.search(rb"DivX(\d+)Build(\d+)(\w)", d)
            if o is None:
                o = re.search(rb"DivX(\d+)b(\d+)(\w)", d)
            if o is not None:
                should_buffer = o.group(3) == b"p"
    seek_offset = start_offset
    # some files don't seek to the right location, so better be safe here
    seek_offset = max(seek_offset - 1, 0)
    if should_buffer:
        # FIXME this is kind of a hack, but we will jump to the previous keyframe
        # so this will be safe
        seek_offset = max(seek_offset - max_buffer_size, 0)
    try:
        # TODO check if stream needs to always be the video stream here or not
        container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
    except av.AVError:
        # TODO add some warnings in this case
        # print("Corrupted file?", container.name)
        return []
    buffer_count = 0
    try:
        for _idx, frame in enumerate(container.decode(**stream_name)):
            frames[frame.pts] = frame
            if frame.pts >= end_offset:
                if should_buffer and buffer_count < max_buffer_size:
                    buffer_count += 1
                    continue
                break
    except av.AVError:
        # TODO add a warning
        pass
    # ensure that the results are sorted wrt the pts
    result = [frames[i] for i in sorted(frames) if start_offset <= frames[i].pts <= end_offset]
    if len(frames) > 0 and start_offset > 0 and start_offset not in frames:
        # if there is no frame that exactly matches the pts of start_offset
        # add the last frame smaller than start_offset, to guarantee that
        # we will have all the necessary data. This is most useful for audio
        preceding_frames = [i for i in frames if i < start_offset]
        if len(preceding_frames) > 0:
            first_frame_pts = max(preceding_frames)
            result.insert(0, frames[first_frame_pts])
    return result


def _align_audio_frames(
    aframes: torch.Tensor, audio_frames: List["av.frame.Frame"], ref_start: int, ref_end: float
) -> torch.Tensor:
    start, end = audio_frames[0].pts, audio_frames[-1].pts
    total_aframes = aframes.shape[1]
    step_per_aframe = (end - start + 1) / total_aframes
    s_idx = 0
    e_idx = total_aframes
    if start < ref_start:
        s_idx = int((ref_start - start) / step_per_aframe)
    if end > ref_end:
        e_idx = int((ref_end - end) / step_per_aframe)
    return aframes[:, s_idx:e_idx]


def read_video(
    filename: str,
    start_pts: Union[float, Fraction] = 0,
    end_pts: Optional[Union[float, Fraction]] = None,
    pts_unit: str = "pts",
    output_format: str = "THWC",
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:

    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(read_video)

    output_format = output_format.upper()
    if output_format not in ("THWC", "TCHW"):
        raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")

    from torchvision import get_video_backend

    if not os.path.exists(filename):
        raise RuntimeError(f"File not found: {filename}")

    if get_video_backend() != "pyav":
        vframes, aframes, info = _video_opt._read_video(filename, start_pts, end_pts, pts_unit)
    else:
        _check_av_available()

        if end_pts is None:
            end_pts = float("inf")

        if end_pts < start_pts:
            raise ValueError(
                f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}"
            )

        info = {}
        video_frames = []
        audio_frames = []
        audio_timebase = _video_opt.default_timebase

        try:
            with av.open(filename, metadata_errors="ignore") as container:
                if container.streams.audio:
                    audio_timebase = container.streams.audio[0].time_base
                if container.streams.video:
                    video_frames = _read_from_stream(
                        container,
                        start_pts,
                        end_pts,
                        pts_unit,
                        container.streams.video[0],
                        {"video": 0},
                    )
                    video_fps = container.streams.video[0].average_rate
                    # guard against potentially corrupted files
                    if video_fps is not None:
                        info["video_fps"] = float(video_fps)

                if container.streams.audio:
                    audio_frames = _read_from_stream(
                        container,
                        start_pts,
                        end_pts,
                        pts_unit,
                        container.streams.audio[0],
                        {"audio": 0},
                    )
                    info["audio_fps"] = container.streams.audio[0].rate

        except av.AVError:
            # TODO raise a warning?
            pass

        vframes_list = [frame.to_rgb().to_ndarray() for frame in video_frames]
        aframes_list = [frame.to_ndarray() for frame in audio_frames]

        if vframes_list:
            vframes = torch.as_tensor(np.stack(vframes_list))
        else:
            vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)

        if aframes_list:
            aframes = np.concatenate(aframes_list, 1)
            aframes = torch.as_tensor(aframes)
            if pts_unit == "sec":
                start_pts = int(math.floor(start_pts * (1 / audio_timebase)))
                if end_pts != float("inf"):
                    end_pts = int(math.ceil(end_pts * (1 / audio_timebase)))
            aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
        else:
            aframes = torch.empty((1, 0), dtype=torch.float32)

    if output_format == "TCHW":
        # [T,H,W,C] --> [T,C,H,W]
        vframes = vframes.permute(0, 3, 1, 2)

    return vframes, aframes, info


def _can_read_timestamps_from_packets(container: "av.container.Container") -> bool:
    extradata = container.streams[0].codec_context.extradata
    if extradata is None:
        return False
    if b"Lavc" in extradata:
        return True
    return False


def _decode_video_timestamps(container: "av.container.Container") -> List[int]:
    if _can_read_timestamps_from_packets(container):
        # fast path
        return [x.pts for x in container.demux(video=0) if x.pts is not None]
    else:
        return [x.pts for x in container.decode(video=0) if x.pts is not None]


def read_video_timestamps(filename: str, pts_unit: str = "pts") -> Tuple[List[int], Optional[float]]:

    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(read_video_timestamps)
    from torchvision import get_video_backend

    if get_video_backend() != "pyav":
        return _video_opt._read_video_timestamps(filename, pts_unit)

    _check_av_available()

    video_fps = None
    pts = []

    try:
        with av.open(filename, metadata_errors="ignore") as container:
            if container.streams.video:
                video_stream = container.streams.video[0]
                video_time_base = video_stream.time_base
                try:
                    pts = _decode_video_timestamps(container)
                except av.AVError:
                    warnings.warn(f"Failed decoding frames for file {filename}")
                video_fps = float(video_stream.average_rate)
    except av.AVError as e:
        msg = f"Failed to open container for {filename}; Caught error: {e}"
        warnings.warn(msg, RuntimeWarning)

    pts.sort()

    if pts_unit == "sec":
        pts = [x * video_time_base for x in pts]

    return pts, video_fps


"""
363
+
364
  def change():
365
  with open('/home/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torchvision/io/video.py','w') as f:
366
  f.write(video_file)