# NOTE(review): the three header lines below were paste artifacts
# ("Spaces:" / "Runtime error"), not Python code; kept as comments.
# Spaces:
# Runtime error
# Runtime error
import os | |
import sys | |
video_file = """ | |
import gc | |
import math | |
import os | |
import re | |
import warnings | |
from fractions import Fraction | |
from typing import Any, Dict, List, Optional, Tuple, Union | |
import numpy as np | |
import torch | |
from ..utils import _log_api_usage_once | |
from . import _video_opt | |
# Import PyAV defensively: on failure (or a too-old version) bind `av` to an
# ImportError instance so _check_av_available() can re-raise it lazily.
# NOTE(review): the original paste was truncated after `av = ImportError(`;
# the except clause and messages are reconstructed from stock torchvision.
try:
    import av

    av.logging.set_level(av.logging.ERROR)
    if not hasattr(av.video.frame.VideoFrame, "pict_type"):
        av = ImportError(
            """\
Your version of PyAV is too old. We encourage you to upgrade it with
`pip install av --upgrade`. Please open an issue if this did not solve the problem.
"""
        )
except ImportError:
    av = ImportError(
        """\
PyAV is not installed, and is necessary for the video operations in torchvision.
See https://github.com/mikeboers/PyAV#installation for instructions on how to
install PyAV on your system.
"""
    )
def _check_av_available() -> None:
    """Raise the deferred import error if PyAV failed to import."""
    # `av` is rebound to an Exception instance when the import guard fails.
    if isinstance(av, Exception):
        raise av
def _av_available() -> bool:
    """Return True when PyAV imported successfully (see the import guard)."""
    return not isinstance(av, Exception)
# PyAV has some reference cycles | |
_CALLED_TIMES = 0 | |
_GC_COLLECTION_INTERVAL = 10 | |
def write_video(
    filename: str,
    video_array: torch.Tensor,
    fps: float,
    video_codec: str = "libx264",
    options: Optional[Dict[str, Any]] = None,
    audio_array: Optional[torch.Tensor] = None,
    audio_fps: Optional[float] = None,
    audio_codec: Optional[str] = None,
    audio_options: Optional[Dict[str, Any]] = None,
) -> None:
    """Write a [T, H, W, C] uint8 video tensor (and optional audio) to a file.

    Args:
        filename: path of the output video file.
        video_array: frames as a [T, H, W, C] tensor; cast to uint8.
        fps: frames per second; floats are rounded (see note below).
        video_codec: video codec name (default "libx264").
        options: codec options forwarded to PyAV.
        audio_array: optional [C, N] tensor of audio samples.
        audio_fps: audio sample rate (e.g. 44100 or 48000).
        audio_codec: audio codec name.
        audio_options: audio codec options forwarded to PyAV.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(write_video)
    _check_av_available()
    video_array = torch.as_tensor(video_array, dtype=torch.uint8).cpu().numpy()

    # PyAV does not support floating point numbers with decimal point
    # and will throw OverflowException in case this is not the case
    if isinstance(fps, float):
        fps = np.round(fps)

    with av.open(filename, mode="w") as container:
        stream = container.add_stream(video_codec, rate=fps)
        stream.width = video_array.shape[2]
        stream.height = video_array.shape[1]
        # libx264rgb encodes RGB directly; everything else gets yuv420p.
        stream.pix_fmt = "yuv420p" if video_codec != "libx264rgb" else "rgb24"
        stream.options = options or {}

        if audio_array is not None:
            # Map PyAV sample-format names to numpy dtypes (little-endian).
            audio_format_dtypes = {
                "dbl": "<f8",
                "dblp": "<f8",
                "flt": "<f4",
                "fltp": "<f4",
                "s16": "<i2",
                "s16p": "<i2",
                "s32": "<i4",
                "s32p": "<i4",
                "u8": "u1",
                "u8p": "u1",
            }
            a_stream = container.add_stream(audio_codec, rate=audio_fps)
            a_stream.options = audio_options or {}

            num_channels = audio_array.shape[0]
            audio_layout = "stereo" if num_channels > 1 else "mono"
            audio_sample_fmt = container.streams.audio[0].format.name

            format_dtype = np.dtype(audio_format_dtypes[audio_sample_fmt])
            audio_array = torch.as_tensor(audio_array).cpu().numpy().astype(format_dtype)

            frame = av.AudioFrame.from_ndarray(audio_array, format=audio_sample_fmt, layout=audio_layout)
            frame.sample_rate = audio_fps

            for packet in a_stream.encode(frame):
                container.mux(packet)

            # Flush the audio encoder.
            for packet in a_stream.encode():
                container.mux(packet)

        for img in video_array:
            frame = av.VideoFrame.from_ndarray(img, format="rgb24")
            # Let the encoder pick the picture type itself.
            frame.pict_type = "NONE"
            for packet in stream.encode(frame):
                container.mux(packet)

        # Flush stream
        for packet in stream.encode():
            container.mux(packet)
def _read_from_stream(
    container: "av.container.Container",
    start_offset: float,
    end_offset: float,
    pts_unit: str,
    stream: "av.stream.Stream",
    stream_name: Dict[str, Optional[Union[int, Tuple[int, ...], List[int]]]],
) -> List["av.frame.Frame"]:
    """Decode frames of ``stream`` whose pts lie in [start_offset, end_offset].

    Returns frames sorted by pts; for streams where no frame lands exactly on
    ``start_offset``, the closest preceding frame is prepended so callers do
    not lose data (most useful for audio).
    """
    global _CALLED_TIMES, _GC_COLLECTION_INTERVAL
    # PyAV has reference cycles, so collect garbage periodically.
    _CALLED_TIMES += 1
    if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1:
        gc.collect()

    if pts_unit == "sec":
        # TODO: we should change all of this from ground up to simply take
        # sec and convert to MS in C++
        start_offset = int(math.floor(start_offset * (1 / stream.time_base)))
        if end_offset != float("inf"):
            end_offset = int(math.ceil(end_offset * (1 / stream.time_base)))
    else:
        warnings.warn("The pts_unit 'pts' gives wrong results. Please use pts_unit 'sec'.")

    frames = {}
    should_buffer = True
    max_buffer_size = 5
    if stream.type == "video":
        # DivX-style packed B-frames can have out-of-order pts (2 frames in a single pkt)
        # so need to buffer some extra frames to sort everything
        # properly
        extradata = stream.codec_context.extradata
        # overly complicated way of finding if `divx_packed` is set, following
        # https://github.com/FFmpeg/FFmpeg/commit/d5a21172283572af587b3d939eba0091484d3263
        if extradata and b"DivX" in extradata:
            # can't use regex directly because of some weird characters sometimes...
            pos = extradata.find(b"DivX")
            d = extradata[pos:]
            o = re.search(rb"DivX(\d+)Build(\d+)(\w)", d)
            if o is None:
                o = re.search(rb"DivX(\d+)b(\d+)(\w)", d)
            if o is not None:
                should_buffer = o.group(3) == b"p"

    seek_offset = start_offset
    # some files don't seek to the right location, so better be safe here
    seek_offset = max(seek_offset - 1, 0)
    if should_buffer:
        # FIXME this is kind of a hack, but we will jump to the previous keyframe
        # so this will be safe
        seek_offset = max(seek_offset - max_buffer_size, 0)
    try:
        # TODO check if stream needs to always be the video stream here or not
        container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
    except av.AVError:
        # TODO add some warnings in this case
        # print("Corrupted file?", container.name)
        return []

    buffer_count = 0
    try:
        for _idx, frame in enumerate(container.decode(**stream_name)):
            frames[frame.pts] = frame
            if frame.pts >= end_offset:
                # keep decoding a few extra frames when packed B-frames may
                # deliver out-of-order pts
                if should_buffer and buffer_count < max_buffer_size:
                    buffer_count += 1
                    continue
                break
    except av.AVError:
        # TODO add a warning
        pass

    # ensure that the results are sorted wrt the pts
    result = [frames[i] for i in sorted(frames) if start_offset <= frames[i].pts <= end_offset]
    if len(frames) > 0 and start_offset > 0 and start_offset not in frames:
        # if there is no frame that exactly matches the pts of start_offset
        # add the last frame smaller than start_offset, to guarantee that
        # we will have all the necessary data. This is most useful for audio
        preceding_frames = [i for i in frames if i < start_offset]
        if len(preceding_frames) > 0:
            first_frame_pts = max(preceding_frames)
            result.insert(0, frames[first_frame_pts])
    return result
def _align_audio_frames( | |
aframes: torch.Tensor, audio_frames: List["av.frame.Frame"], ref_start: int, ref_end: float | |
) -> torch.Tensor: | |
start, end = audio_frames[0].pts, audio_frames[-1].pts | |
total_aframes = aframes.shape[1] | |
step_per_aframe = (end - start + 1) / total_aframes | |
s_idx = 0 | |
e_idx = total_aframes | |
if start < ref_start: | |
s_idx = int((ref_start - start) / step_per_aframe) | |
if end > ref_end: | |
e_idx = int((ref_end - end) / step_per_aframe) | |
return aframes[:, s_idx:e_idx] | |
def read_video(
    filename: str,
    start_pts: Union[float, Fraction] = 0,
    end_pts: Optional[Union[float, Fraction]] = None,
    pts_unit: str = "pts",
    output_format: str = "THWC",
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
    """Read video (and audio) frames from a file.

    Args:
        filename: path to the video file.
        start_pts: start presentation timestamp of the clip.
        end_pts: end presentation timestamp; ``None`` reads to the end.
        pts_unit: unit of ``start_pts``/``end_pts``, "pts" or "sec".
        output_format: "THWC" (default) or "TCHW" for the video tensor.

    Returns:
        Tuple of (video frames uint8 tensor, audio samples tensor, info dict
        possibly containing "video_fps" and "audio_fps").

    Raises:
        RuntimeError: if ``filename`` does not exist.
        ValueError: on bad ``output_format`` or ``end_pts`` < ``start_pts``.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(read_video)

    output_format = output_format.upper()
    if output_format not in ("THWC", "TCHW"):
        raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")

    from torchvision import get_video_backend

    if not os.path.exists(filename):
        # fix: the f-string had lost its placeholder ("File not found: (unknown)")
        raise RuntimeError(f"File not found: {filename}")

    if get_video_backend() != "pyav":
        vframes, aframes, info = _video_opt._read_video(filename, start_pts, end_pts, pts_unit)
    else:
        _check_av_available()

        if end_pts is None:
            end_pts = float("inf")

        if end_pts < start_pts:
            raise ValueError(
                f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}"
            )

        info = {}
        video_frames = []
        audio_frames = []
        audio_timebase = _video_opt.default_timebase

        try:
            with av.open(filename, metadata_errors="ignore") as container:
                if container.streams.audio:
                    audio_timebase = container.streams.audio[0].time_base
                if container.streams.video:
                    video_frames = _read_from_stream(
                        container,
                        start_pts,
                        end_pts,
                        pts_unit,
                        container.streams.video[0],
                        {"video": 0},
                    )
                    video_fps = container.streams.video[0].average_rate
                    # guard against potentially corrupted files
                    if video_fps is not None:
                        info["video_fps"] = float(video_fps)

                if container.streams.audio:
                    audio_frames = _read_from_stream(
                        container,
                        start_pts,
                        end_pts,
                        pts_unit,
                        container.streams.audio[0],
                        {"audio": 0},
                    )
                    info["audio_fps"] = container.streams.audio[0].rate
        except av.AVError:
            # TODO raise a warning?
            pass

        vframes_list = [frame.to_rgb().to_ndarray() for frame in video_frames]
        aframes_list = [frame.to_ndarray() for frame in audio_frames]

        if vframes_list:
            vframes = torch.as_tensor(np.stack(vframes_list))
        else:
            vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)

        if aframes_list:
            aframes = np.concatenate(aframes_list, 1)
            aframes = torch.as_tensor(aframes)
            if pts_unit == "sec":
                # convert the pts window to the audio stream's timebase
                start_pts = int(math.floor(start_pts * (1 / audio_timebase)))
                if end_pts != float("inf"):
                    end_pts = int(math.ceil(end_pts * (1 / audio_timebase)))
            aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
        else:
            aframes = torch.empty((1, 0), dtype=torch.float32)

    if output_format == "TCHW":
        # [T,H,W,C] --> [T,C,H,W]
        vframes = vframes.permute(0, 3, 1, 2)

    return vframes, aframes, info
def _can_read_timestamps_from_packets(container: "av.container.Container") -> bool: | |
extradata = container.streams[0].codec_context.extradata | |
if extradata is None: | |
return False | |
if b"Lavc" in extradata: | |
return True | |
return False | |
def _decode_video_timestamps(container: "av.container.Container") -> List[int]:
    """Return the pts of every video frame in ``container``."""
    if _can_read_timestamps_from_packets(container):
        # fast path: read pts off the packets without decoding frames
        return [x.pts for x in container.demux(video=0) if x.pts is not None]
    else:
        return [x.pts for x in container.decode(video=0) if x.pts is not None]
def read_video_timestamps(filename: str, pts_unit: str = "pts") -> Tuple[List[int], Optional[float]]:
    """List the pts of each video frame in ``filename`` plus the frame rate.

    Args:
        filename: path to the video file.
        pts_unit: unit of the returned timestamps, "pts" or "sec".

    Returns:
        Tuple of (sorted timestamps, average frame rate or None).
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(read_video_timestamps)
    from torchvision import get_video_backend

    if get_video_backend() != "pyav":
        return _video_opt._read_video_timestamps(filename, pts_unit)

    _check_av_available()

    video_fps = None
    pts = []

    try:
        with av.open(filename, metadata_errors="ignore") as container:
            if container.streams.video:
                video_stream = container.streams.video[0]
                video_time_base = video_stream.time_base
                try:
                    pts = _decode_video_timestamps(container)
                except av.AVError:
                    # fix: restore the file path lost to the "(unknown)" artifact
                    warnings.warn(f"Failed decoding frames for file {filename}")
                video_fps = float(video_stream.average_rate)
    except av.AVError as e:
        msg = f"Failed to open container for {filename}; Caught error: {e}"
        warnings.warn(msg, RuntimeWarning)

    pts.sort()

    if pts_unit == "sec":
        pts = [x * video_time_base for x in pts]

    return pts, video_fps
""" | |
def change():
    """Overwrite the installed torchvision ``io/video.py`` with ``video_file``.

    NOTE(review): this patches a file inside site-packages in place, and the
    interpreter path is hard-coded — environment-specific and fragile;
    consider monkey-patching at import time instead.
    """
    # `with` closes the file on exit; the original's explicit f.close() was redundant.
    with open('/home/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torchvision/io/video.py', 'w') as f:
        f.write(video_file)


change()