# File size: 3,480 Bytes
# b4754db
import importlib
import os
import os.path as osp
import shutil
import sys
from pathlib import Path
# import av
import numpy as np
import torch
import torchvision
from einops import rearrange
from PIL import Image
def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed. ``random`` and ``torch`` receive it unchanged;
            NumPy receives ``seed % 2**32``.
    """
    import random

    import numpy as np

    # Wrap the seed into the 32-bit range before handing it to NumPy.
    numpy_seed = seed % (1 << 32)

    random.seed(seed)
    np.random.seed(numpy_seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
def import_filename(filename):
    """Execute a Python source file and return it as a module object.

    The module is registered in ``sys.modules`` under the fixed name
    ``"mymodule"``, so each call replaces the entry left by a previous call.

    Args:
        filename: Path of the Python file to load.

    Returns:
        The module created by executing ``filename``.

    Raises:
        ImportError: If no loader can be determined for ``filename``
            (for example because of an unrecognised file extension).
    """
    # ``import importlib`` alone does not guarantee that the ``util``
    # submodule has been loaded, so import it explicitly here.
    import importlib.util

    spec = importlib.util.spec_from_file_location("mymodule", filename)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load a Python module from {filename!r}")

    module = importlib.util.module_from_spec(spec)
    # Register before executing so the module can be found by name while
    # its own body is running.
    sys.modules[spec.name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module behind in sys.modules.
        sys.modules.pop(spec.name, None)
        raise
    return module
def delete_additional_ckpt(base_path, num_keep):
    """Remove the oldest ``checkpoint-<step>`` directories under ``base_path``.

    Only directories whose name starts with ``"checkpoint-"`` and ends in an
    integer after the last ``"-"`` are considered. Other entries (plain
    files, names such as ``checkpoint-best``) are left untouched and are not
    counted.

    Args:
        base_path: Directory that contains the checkpoint directories.
        num_keep: Number of checkpoints with the highest step to keep.
    """
    candidates = []
    for entry in os.listdir(base_path):
        if not entry.startswith("checkpoint-"):
            continue
        if not osp.isdir(osp.join(base_path, entry)):
            continue
        try:
            step = int(entry.split("-")[-1])
        except ValueError:
            # Not a numbered checkpoint; it cannot be ordered, so skip it.
            continue
        candidates.append((step, entry))

    num_delete = len(candidates) - num_keep
    if num_delete <= 0:
        return

    # Sort by step so that the earliest checkpoints are deleted first.
    candidates.sort()
    for _, entry in candidates[:num_delete]:
        path_to_dir = osp.join(base_path, entry)
        if osp.exists(path_to_dir):
            shutil.rmtree(path_to_dir)
def save_videos_from_pil(pil_images, path, fps=8):
    """Write a sequence of PIL images to ``path`` as an MP4 or a GIF.

    The output format is chosen from the suffix of ``path``. The frame size
    of an MP4 is taken from the first image.

    Args:
        pil_images: Non-empty sequence of PIL images, one per frame.
        path: Output file path ending in ``.mp4`` or ``.gif``.
        fps: Frames per second of the output.

    Raises:
        ValueError: If the suffix of ``path`` is neither ``.mp4`` nor
            ``.gif``, or if ``pil_images`` is empty.
    """
    save_fmt = Path(path).suffix
    # Validate before touching the file system or importing PyAV.
    if save_fmt not in (".mp4", ".gif"):
        raise ValueError("Unsupported file type. Use .mp4 or .gif.")
    if len(pil_images) == 0:
        raise ValueError("pil_images must contain at least one frame.")

    # ``os.makedirs("")`` raises, so only create a directory when the path
    # actually has a directory component.
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    if save_fmt == ".mp4":
        # PyAV is only needed for MP4 output, so import it lazily here.
        import av

        width, height = pil_images[0].size
        # The context manager closes the container even if encoding fails.
        with av.open(path, "w") as container:
            stream = container.add_stream("libx264", rate=fps)
            stream.width = width
            stream.height = height

            for pil_image in pil_images:
                av_frame = av.VideoFrame.from_image(pil_image)
                container.mux(stream.encode(av_frame))
            # Flush the frames still buffered inside the encoder.
            container.mux(stream.encode())
    else:
        pil_images[0].save(
            fp=path,
            format="GIF",
            append_images=pil_images[1:],
            save_all=True,
            duration=(1 / fps * 1000),  # per-frame duration in milliseconds
            loop=0,
        )
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
    """Tile a batch of videos into one grid video and save it to ``path``.

    For every time step the frames of all videos in the batch are arranged
    with ``torchvision.utils.make_grid`` and converted to a PIL image; the
    resulting frames are written with ``save_videos_from_pil``.

    Args:
        videos: Tensor laid out as ``(b, c, t, h, w)``.
        path: Output file path; the suffix selects the format handled by
            ``save_videos_from_pil``.
        rescale: If true, map values from ``[-1, 1]`` to ``[0, 1]`` before
            converting to 8-bit.
        n_rows: Passed to ``make_grid`` as ``nrow``.
        fps: Frames per second of the output.
    """
    # Work on a detached float CPU copy so ``.numpy()`` below cannot fail for
    # CUDA tensors, tensors that require grad, or half-precision inputs.
    videos = rearrange(videos.detach().cpu().float(), "b c t h w -> t b c h w")

    outputs = []
    for frame_batch in videos:
        grid = torchvision.utils.make_grid(frame_batch, nrow=n_rows)  # (c h w)
        grid = grid.permute(1, 2, 0).squeeze(-1)  # (h w c)
        if rescale:
            grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1
        # Clamp so out-of-range values saturate instead of wrapping around
        # in the uint8 cast.
        grid = grid.clamp(0.0, 1.0)
        frame = (grid * 255).numpy().astype(np.uint8)
        outputs.append(Image.fromarray(frame))

    # ``os.makedirs("")`` raises, so skip it for bare file names.
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    save_videos_from_pil(outputs, path, fps)
def read_frames(video_path):
    """Decode all frames of the first video stream into RGB PIL images.

    Args:
        video_path: Path of the video file to open with PyAV.

    Returns:
        A list of ``PIL.Image.Image`` objects in ``"RGB"`` mode, in decode
        order.

    Raises:
        ValueError: If the file contains no video stream.
    """
    # ``av`` is not imported at module level in this file, so import it here.
    import av

    frames = []
    # The context manager closes the container even if decoding fails.
    with av.open(video_path) as container:
        video_stream = next(
            (s for s in container.streams if s.type == "video"), None
        )
        if video_stream is None:
            raise ValueError(f"No video stream found in {video_path!r}")

        for packet in container.demux(video_stream):
            for frame in packet.decode():
                image = Image.frombytes(
                    "RGB",
                    (frame.width, frame.height),
                    frame.to_rgb().to_ndarray(),
                )
                frames.append(image)
    return frames
def get_fps(video_path):
    """Return the average frame rate of the first video stream.

    Args:
        video_path: Path of the video file to open with PyAV.

    Returns:
        The stream's ``average_rate`` as reported by PyAV.

    Raises:
        ValueError: If the file contains no video stream.
    """
    # ``av`` is not imported at module level in this file, so import it here.
    import av

    # The context manager closes the container even if the lookup fails.
    with av.open(video_path) as container:
        video_stream = next(
            (s for s in container.streams if s.type == "video"), None
        )
        if video_stream is None:
            raise ValueError(f"No video stream found in {video_path!r}")
        fps = video_stream.average_rate
    return fps