tokenflow / utils.py
multimodalart's picture
revert to working state
88a2ed3
raw
history blame contribute delete
No virus
3.97 kB
from pathlib import Path
from PIL import Image
import torch
import yaml
import math
import torchvision.transforms as T
from torchvision.io import read_video,write_video
import os
import random
import numpy as np
from torchvision.io import write_video
# from kornia.filters import joint_bilateral_blur
from kornia.geometry.transform import remap
from kornia.utils.grid import create_meshgrid
import cv2
def save_video_frames(video_path, img_size=(512,512)):
    """
    Decode a video and dump every frame as a PNG under ``data/<video stem>/``.

    Frames are read with torchvision's ``read_video`` in TCHW layout, resized
    to ``img_size`` with Lanczos resampling, and written as zero-padded,
    five-digit file names (``00000.png``, ``00001.png``, ...).
    """
    frames, _, _ = read_video(video_path, output_format="TCHW")
    # .mov files come out of torchvision rotated (a weird torchvision bug),
    # so turn them back by -90 degrees.
    if video_path.endswith('.mov'):
        frames = T.functional.rotate(frames, -90)
    video_name = Path(video_path).stem
    os.makedirs(f'data/{video_name}', exist_ok=True)
    to_pil = T.ToPILImage()
    for i, frame in enumerate(frames):
        ind = f'{i:05d}'
        resized = to_pil(frame).resize((img_size), resample=Image.Resampling.LANCZOS)
        resized.save(f'data/{video_name}/{ind}.png')
def video_to_frames(video_path, img_size=(512,512)):
    """
    Decode a video and return its frames as a list of resized PIL images.

    Frames are read with torchvision's ``read_video`` in TCHW layout and each
    one is resized to ``img_size`` with Lanczos resampling. Nothing is written
    to disk. An empty video yields an empty list.
    """
    video, _, _ = read_video(video_path, output_format="TCHW")
    # rotate video -90 degree if video is .mov format. this is a weird bug in torchvision
    if video_path.endswith('.mov'):
        video = T.functional.rotate(video, -90)
    to_pil = T.ToPILImage()
    # Iterating the tensor walks its first (frame) dimension.
    return [
        to_pil(frame).resize(img_size, resample=Image.Resampling.LANCZOS)
        for frame in video
    ]
def add_dict_to_yaml_file(file_path, key, value):
    """
    Set ``key`` to ``value`` in the YAML mapping stored at ``file_path``.

    If the file exists, its contents are loaded first and the key is added or
    overwritten; otherwise a new mapping is started. The whole mapping is then
    written back to ``file_path``.
    """
    data = {}
    # If the file already exists, load its contents into the data dictionary
    if os.path.exists(file_path):
        with open(file_path, 'r') as file:
            loaded = yaml.safe_load(file)
        # safe_load returns None for an empty file; keep the fresh dict in
        # that case instead of failing on the item assignment below.
        if loaded is not None:
            data = loaded
    # Add or update the key-value pair
    data[key] = value
    # Save the data back to the YAML file
    with open(file_path, 'w') as file:
        yaml.dump(data, file)
def isinstance_str(x: object, cls_name: str):
    """
    Report whether any class in x's MRO is *named* cls_name.

    The comparison is purely by class name, so the class object itself never
    has to be imported. Handy when patching third-party modules.
    """
    return any(ancestor.__name__ == cls_name for ancestor in x.__class__.__mro__)
def batch_cosine_sim(x, y):
    """
    Pairwise cosine similarity between the rows of x and the rows of y.

    Each argument is either a tensor or a list/tuple of tensors; sequences are
    concatenated along dim 0 first. Rows are L2-normalised over the last
    dimension and the result is ``x @ y.T``.

    Note: no epsilon is added to the norm, so an all-zero row divides by zero.
    """
    # isinstance (rather than an exact type check) also admits tuples and
    # list subclasses, which torch.cat handles the same way.
    if isinstance(x, (list, tuple)):
        x = torch.cat(x, dim=0)
    if isinstance(y, (list, tuple)):
        y = torch.cat(y, dim=0)
    x = x / x.norm(dim=-1, keepdim=True)
    y = y / y.norm(dim=-1, keepdim=True)
    return x @ y.T
def load_imgs(data_path, n_frames, device='cuda', pil=False):
    """
    Load the first ``n_frames`` numbered frames from ``data_path``.

    Frame ``i`` is looked up as ``%05d.jpg``; when that file does not exist
    the ``.png`` name is used instead. Each image is converted with
    torchvision's ``ToTensor`` and the results are concatenated along dim 0
    and moved to ``device``.

    Returns the batched tensor, or ``(tensor, list_of_PIL_images)`` when
    ``pil`` is true.
    """
    to_tensor = T.ToTensor()
    tensors = []
    pil_images = []
    for idx in range(n_frames):
        frame_path = os.path.join(data_path, "%05d.jpg" % idx)
        if not os.path.exists(frame_path):
            # No JPEG for this index; use the PNG name instead.
            frame_path = os.path.join(data_path, "%05d.png" % idx)
        frame = Image.open(frame_path)
        pil_images.append(frame)
        tensors.append(to_tensor(frame).unsqueeze(0))
    batch = torch.cat(tensors).to(device)
    return (batch, pil_images) if pil else batch
def save_video(raw_frames, save_path, fps=10):
    """
    Encode a batch of frames to an H.264 video file at ``save_path``.

    ``raw_frames`` is presumably a float tensor in [0, 1] laid out as
    (frames, channels, height, width) -- TODO confirm against callers; the
    code scales by 255, casts to uint8 and moves dim 1 to the end before
    handing the result to torchvision's ``write_video``.
    """
    video_codec = "libx264"
    video_options = {
        "crf": "18",  # Constant Rate Factor (lower value = higher quality, 18 is a good balance)
        "preset": "slow",  # Encoding preset (e.g., ultrafast, superfast, veryfast, faster, fast, medium, slow, slower, veryslow)
    }
    # Clamp before the uint8 cast: values slightly outside [0, 1] would
    # otherwise land outside 0..255, where the float->uint8 conversion is not
    # well defined. In-range values are unaffected.
    frames = (raw_frames.clamp(0, 1) * 255).to(torch.uint8).cpu().permute(0, 2, 3, 1)
    write_video(save_path, frames, fps=fps, video_codec=video_codec, options=video_options)
def seed_everything(seed):
    """
    Seed torch (CPU and CUDA), Python's ``random`` and NumPy with ``seed``.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        random.seed,
        np.random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)