# STAR / inference_utils.py
# xierui.0097
# Add application file
# f0e9666
# raw
# history blame
# 4.72 kB
import os
import subprocess
import tempfile
import cv2
import torch
from PIL import Image
from typing import Mapping
from einops import rearrange
import numpy as np
import torchvision.transforms.functional as transforms_F
from video_to_video.utils.logger import get_logger
logger = get_logger()
def tensor2vid(video, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Undo per-channel normalisation and return the first clip as frames.

    Computes ``video * std + mean``, clamps the result to [0, 1] and scales
    it by 255. The input tensor is left unmodified.

    Args:
        video: 5-D tensor indexed as (batch, channel, frame, height, width),
            as implied by the permutation applied below.
        mean: Per-channel offsets added back after scaling (any sequence).
        std: Per-channel factors the input is multiplied by (any sequence).

    Returns:
        Tensor of shape (frame, height, width, channel) for batch item 0,
        with values in [0, 255]. The values are not cast to an integer
        dtype; floating-point inputs keep their dtype.
    """
    # Match the statistics to the input's floating dtype so the result keeps
    # that dtype instead of being promoted by the out-of-place multiply.
    if video.is_floating_point():
        stat_dtype = video.dtype
    else:
        stat_dtype = torch.get_default_dtype()
    mean = torch.tensor(
        mean, device=video.device, dtype=stat_dtype).reshape(1, -1, 1, 1, 1)
    std = torch.tensor(
        std, device=video.device, dtype=stat_dtype).reshape(1, -1, 1, 1, 1)

    # The multiply allocates the only new buffer; the remaining steps reuse
    # it in place, so the caller's tensor is never written to.
    frames = video * std
    frames.add_(mean).clamp_(0, 1).mul_(255.0)

    # 'b c f h w -> b f h w c', then keep the first batch element.
    return frames.permute(0, 2, 3, 4, 1)[0]
def preprocess(input_frames):
    """Convert a sequence of image arrays into one normalised tensor batch.

    Every frame has its last axis reversed, is cast to uint8, passed through
    PIL as an RGB image and converted with torchvision's ``to_tensor``. The
    stacked batch is clamped to [0, 1] and then shifted and scaled with a
    per-channel mean and std of 0.5.

    Args:
        input_frames: Sequence of array frames indexable as
            [row, col, channel] (presumably BGR frames from OpenCV; confirm
            against the caller).

    Returns:
        Tensor with the converted frames stacked along dim 0.
    """
    def _frame_to_tensor(array):
        # Reverse the channel axis before handing the pixels to PIL.
        flipped = array[:, :, ::-1]
        picture = Image.fromarray(flipped.astype('uint8')).convert('RGB')
        return transforms_F.to_tensor(picture)

    batch = torch.stack(
        [_frame_to_tensor(array) for array in input_frames], dim=0)
    batch.clamp_(0, 1)

    # Shape (1, C, 1, 1) so the statistics broadcast over the whole batch.
    channel_mean = batch.new_tensor([0.5, 0.5, 0.5]).view(1, -1, 1, 1)
    channel_std = batch.new_tensor([0.5, 0.5, 0.5]).view(1, -1, 1, 1)
    batch.sub_(channel_mean).div_(channel_std)
    return batch
def adjust_resolution(h, w, up_scale):
    """Pick an even-sized output resolution for an upscaled frame.

    The scale factor is chosen as follows:
      * if ``h * up_scale`` is below 720, scale so the height reaches 720;
      * otherwise, if the upscaled area exceeds 1280 * 2048 pixels, scale so
        the area equals that budget;
      * otherwise use ``up_scale`` unchanged.

    Args:
        h: Source height.
        w: Source width.
        up_scale: Requested upscaling factor.

    Returns:
        ``(target_h, target_w)`` as ints, each rounded down to a multiple
        of two.
    """
    min_height = 720
    max_pixels = 1280 * 2048

    if h * up_scale < min_height:
        factor = min_height / h
    elif h * w * up_scale * up_scale > max_pixels:
        factor = np.sqrt(max_pixels / (h * w))
    else:
        factor = up_scale

    # Floor-divide by two and double again to force even dimensions.
    return tuple(int(factor * side // 2 * 2) for side in (h, w))
def make_mask_cond(in_f_num, interp_f_num):
    """Build the frame-index list with placeholders between input frames.

    Args:
        in_f_num: Number of input frames; their indices 0..in_f_num-1 appear
            in order.
        interp_f_num: Number of ``-1`` placeholders inserted between each
            pair of consecutive input indices.

    Returns:
        List such as ``[0, -1, -1, 1, -1, -1, 2]`` for ``(3, 2)``; nothing is
        added after the last index.
    """
    gap = [-1] * len(range(interp_f_num))
    sequence = []
    for frame_idx in range(in_f_num):
        # A gap precedes every index except the first one, which is the same
        # as following every index except the last.
        if frame_idx > 0:
            sequence.extend(gap)
        sequence.append(frame_idx)
    return sequence
def load_video(vid_path):
    """Read the frames of a video file with OpenCV.

    Args:
        vid_path: Path handed to ``cv2.VideoCapture``.

    Returns:
        ``(frame_list, fps)`` where ``frame_list`` holds the frames exactly
        as returned by ``capture.read()`` and ``fps`` is the value reported
        for ``cv2.CAP_PROP_FPS``.
    """
    capture = cv2.VideoCapture(vid_path)
    try:
        fps = capture.get(cv2.CAP_PROP_FPS)
        total_frame_num = capture.get(cv2.CAP_PROP_FRAME_COUNT)

        frame_list = []
        # Never read more frames than the reported count; a reported count
        # of zero therefore yields an empty list.
        while len(frame_list) < total_frame_num:
            ret, frame = capture.read()
            if not ret or frame is None:
                # Stream ended (or failed) before the reported frame count.
                break
            frame_list.append(frame)
    finally:
        # Release the capture handle even if reading raised.
        capture.release()
    return frame_list, fps
def save_video(video, save_dir, file_name, fps=16.0):
    """Encode a sequence of frames into an H.264 video using ffmpeg.

    Frames are written as numbered PNG files into a scratch directory,
    encoded into ``tmp.mp4`` inside ``save_dir`` and, only if ffmpeg
    succeeded, moved to ``save_dir/file_name``. The scratch directory is
    always removed.

    Args:
        video: Iterable of frames, each exposing ``.numpy()`` and indexable
            as [row, col, channel]; the channel axis is reversed before
            ``cv2.imwrite`` (presumably RGB -> BGR; confirm with the caller).
        save_dir: Directory receiving the output; created if missing.
        file_name: Name of the final video file inside ``save_dir``.
        fps: Frame rate passed to ffmpeg's ``-framerate`` option.

    Returns:
        None. On ffmpeg failure the error is logged and no output file is
        produced.
    """
    output_path = os.path.join(save_dir, file_name)
    # NOTE(review): the fixed name means concurrent saves into the same
    # directory would collide; kept for compatibility with existing usage.
    tmp_path = os.path.join(save_dir, 'tmp.mp4')
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # The context manager removes the PNG scratch directory on every exit
    # path, including exceptions.
    with tempfile.TemporaryDirectory() as temp_dir:
        for fid, img in enumerate(video, start=1):
            frame = img.numpy().astype('uint8')
            tpth = os.path.join(temp_dir, '%06d.png' % fid)
            cv2.imwrite(tpth, frame[:, :, ::-1])

        # Argument list instead of a shell string: paths containing spaces
        # or shell metacharacters are passed through verbatim.
        cmd = [
            'ffmpeg', '-y',
            '-f', 'image2',
            '-framerate', str(fps),
            '-i', os.path.join(temp_dir, '%06d.png'),
            '-vcodec', 'libx264',
            '-preset', 'ultrafast',
            '-crf', '0',
            '-pix_fmt', 'yuv420p',
            tmp_path,
        ]
        try:
            result = subprocess.run(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                universal_newlines=True,
            )
            status, output = result.returncode, result.stdout
        except OSError as err:
            # e.g. the ffmpeg executable is not installed / not on PATH.
            status, output = -1, str(err)

    if status != 0:
        logger.error('Save Video Error with {}'.format(output))
        # Do not move a missing or partially written file into place.
        return

    # os.replace overwrites an existing destination on every platform.
    os.replace(tmp_path, output_path)
def collate_fn(data, device):
    """Recursively move a (possibly nested) batch onto ``device``.

    Handling by type:
      * mappings: rebuilt with the same type, every value processed
        recursively except the ``'img_metas'`` entry, which is kept as is;
      * tuples / lists: an empty one becomes an empty ``torch.Tensor``; one
        starting with an int or float is collated with ``default_collate``
        and moved to ``device``; otherwise each element is processed
        recursively and the container type is preserved;
      * numpy arrays: string arrays are returned unchanged, others are
        converted with ``torch.from_numpy`` and moved;
      * tensors: moved with ``.to(device)``;
      * bytes, str, int, float, bool and None: returned unchanged.

    Args:
        data: The data out of the dataloader.
        device: The device to move data to.

    Returns:
        The processed data.

    Raises:
        ValueError: If ``data`` is of any other type.
    """
    from torch.utils.data.dataloader import default_collate

    if isinstance(data, (dict, Mapping)):
        moved = {}
        for key, value in data.items():
            # Metadata is passed through without touching its contents.
            moved[key] = value if key == 'img_metas' else collate_fn(value, device)
        return type(data)(moved)

    if isinstance(data, (tuple, list)):
        if len(data) == 0:
            return torch.Tensor([])
        if isinstance(data[0], (int, float)):
            return default_collate(data).to(device)
        return type(data)(collate_fn(item, device) for item in data)

    if isinstance(data, np.ndarray):
        if data.dtype.type is np.str_:
            return data
        return collate_fn(torch.from_numpy(data), device)

    if isinstance(data, torch.Tensor):
        return data.to(device)

    if isinstance(data, (bytes, str, int, float, bool, type(None))):
        return data

    raise ValueError(f'Unsupported data type {type(data)}')