# Author: Evgeny Zhukov
# Origin: https://github.com/ali-vilab/UniAnimate/commit/d7814fa44a0a1154524b92fce0e3133a2604d333
# Revision: 2ba4412
# Standard library
import base64
import binascii
import copy
import glob
import gzip
import hashlib
import json
import math
import os
import os.path as osp
import pickle
import shutil
import subprocess
import sys
import time
import urllib.request
import zipfile
from io import BytesIO
from multiprocessing.pool import ThreadPool as Pool

# Third-party
import cv2
import imageio
import numpy as np
import requests
import torch
import torch.nn.functional as F
import torchvision.utils as tvutils
from einops import rearrange
from PIL import Image, ImageDraw, ImageFont
# import skvideo.io
def gen_text_image(captions, text_size):
    """Render each caption as black text on a white square image.

    Args:
        captions: iterable of caption strings, one per sample.
        text_size: side length in pixels of each square text image; the font
            size is derived from it (text_size / 20).

    Returns:
        uint8 torch tensor of shape (len(captions), text_size, text_size, 3).
    """
    # Bug fix: the original `int(38 * (text_size / text_size))` is a tautology
    # that always evaluates to 38 (presumably `text_size / 256` was intended —
    # TODO confirm with upstream). Kept as the constant 38 to preserve the
    # observed behavior for every input.
    num_char = 38
    font_size = int(text_size / 20)
    font = ImageFont.truetype('data/font/DejaVuSans.ttf', size=font_size)
    text_image_list = []
    for text in captions:
        txt_img = Image.new("RGB", (text_size, text_size), color="white")
        draw = ImageDraw.Draw(txt_img)
        # Hard character wrap every `num_char` characters (not word-aware).
        lines = "\n".join(text[start:start + num_char] for start in range(0, len(text), num_char))
        draw.text((0, 0), lines, fill="black", font=font)
        text_image_list.append(np.array(txt_img))
    text_images = torch.from_numpy(np.stack(text_image_list, axis=0))
    return text_images
@torch.no_grad()
def save_video_refimg_and_text(
        local_path,
        ref_frame,
        gen_video,
        captions,
        mean=[0.5, 0.5, 0.5],
        std=[0.5, 0.5, 0.5],
        text_size=256,
        nrow=4,
        save_fps=8,
        retry=5):
    '''
    Save a tiled grid of [reference frame | generated video | caption image].

    Args:
        local_path: output path WITHOUT extension ('.png'/'.mp4' is appended).
        ref_frame: normalized reference image, BxCxHxW.
        gen_video: normalized generated video, BxCxFxHxW.
        captions: caption strings, one per batch element.
        mean, std: per-channel stats used to de-normalize the tensors.
        text_size: side length of the rendered caption image.
        nrow: unused; recomputed from the batch size below.
        save_fps: frame rate of the saved mp4.
        retry: number of save attempts before giving up.

    Raises:
        the last exception if all `retry` attempts fail.

    NOTE: de-normalization uses in-place ops, so `ref_frame` and `gen_video`
    are mutated by this call.
    '''
    nrow = max(int(gen_video.size(0) / 2), 1)
    vid_mean = torch.tensor(mean, device=gen_video.device).view(1, -1, 1, 1, 1)  # ncfhw
    vid_std = torch.tensor(std, device=gen_video.device).view(1, -1, 1, 1, 1)  # ncfhw

    text_images = gen_text_image(captions, text_size)  # B x H x W x 3
    text_images = text_images.unsqueeze(1)  # B x 1 x H x W x 3
    text_images = text_images.repeat_interleave(repeats=gen_video.size(2), dim=1)  # B x F x H x W x 3

    ref_frame = ref_frame.unsqueeze(2)  # B x C x 1 x H x W
    ref_frame = ref_frame.mul_(vid_std).add_(vid_mean)  # de-normalize to [0, 1]
    ref_frame = ref_frame.repeat_interleave(repeats=gen_video.size(2), dim=2)
    ref_frame.clamp_(0, 1)
    ref_frame = ref_frame * 255.0
    ref_frame = rearrange(ref_frame, 'b c f h w -> b f h w c')

    gen_video = gen_video.mul_(vid_std).add_(vid_mean)  # de-normalize to [0, 1]
    gen_video.clamp_(0, 1)
    gen_video = gen_video * 255.0
    images = rearrange(gen_video, 'b c f h w -> b f h w c')

    # Concatenate [reference | generation | caption] along width, then tile
    # the batch into an (r x j) grid per frame.
    images = torch.cat([ref_frame, images, text_images], dim=3)
    images = rearrange(images, '(r j) f h w c -> f (r h) (j w) c', r=nrow)
    images = [img.numpy().astype('uint8') for img in images]

    # Bug fix: choose the extension once, OUTSIDE the retry loop; the original
    # appended it inside the loop, producing '.mp4.mp4' on a second attempt.
    local_path = local_path + ('.png' if len(images) == 1 else '.mp4')
    exception = None
    for _ in range(retry):
        try:
            if len(images) == 1:
                # RGB -> BGR for OpenCV. (The original passed a JPEG-quality
                # flag to a PNG write; cv2 ignores it, so it is dropped.)
                cv2.imwrite(local_path, images[0][:, :, ::-1])
            else:
                frame_dir = os.path.join(os.path.dirname(local_path), '%s_frames' % os.path.basename(local_path))
                # shutil.rmtree replaces the unsafe `os.system('rm -rf ...')`.
                shutil.rmtree(frame_dir, ignore_errors=True)
                os.makedirs(frame_dir, exist_ok=True)
                for fid, frame in enumerate(images):
                    tpth = os.path.join(frame_dir, '%04d.png' % (fid + 1))
                    cv2.imwrite(tpth, frame[:, :, ::-1])
                # Argument-list subprocess call: no shell quoting issues.
                cmd = ['ffmpeg', '-y', '-f', 'image2', '-loglevel', 'quiet',
                       '-framerate', str(save_fps), '-i', os.path.join(frame_dir, '%04d.png'),
                       '-vcodec', 'libx264', '-crf', '17', '-pix_fmt', 'yuv420p',
                       local_path]
                subprocess.run(cmd, check=False)
                shutil.rmtree(frame_dir, ignore_errors=True)
            exception = None
            break
        except Exception as e:
            exception = e
            continue
    # Bug fix: the original tracked `exception` but never used it, silently
    # swallowing every failure; surface it like the sibling save_* helpers do.
    if exception is not None:
        raise exception
@torch.no_grad()
def save_i2vgen_video(
        local_path,
        image_id,
        gen_video,
        captions,
        mean=[0.5, 0.5, 0.5],
        std=[0.5, 0.5, 0.5],
        text_size=256,
        retry=5,
        save_fps=8
    ):
    '''
    Save the generated video together with its input conditions: the
    conditioning image, the generation and the rendered caption are
    concatenated along width, dumped as PNG frames and encoded with ffmpeg.

    Args:
        local_path: output video path (extension included, e.g. '.mp4').
        image_id: normalized conditioning image, BxCxHxW.
        gen_video: normalized generated video, BxCxFxHxW; only batch 0 saved.
        captions: caption strings rendered via gen_text_image.
        mean, std: per-channel stats used to de-normalize the tensors.
        text_size: side length of the rendered caption image.
        retry: number of save attempts before giving up.
        save_fps: frame rate of the saved video.

    Raises:
        the last exception if all `retry` attempts fail.

    NOTE: `image_id` and `gen_video` are de-normalized in place.
    '''
    vid_mean = torch.tensor(mean, device=gen_video.device).view(1, -1, 1, 1, 1)  # ncfhw
    vid_std = torch.tensor(std, device=gen_video.device).view(1, -1, 1, 1, 1)  # ncfhw

    text_images = gen_text_image(captions, text_size)  # B x H x W x 3
    text_images = text_images.unsqueeze(1)  # B x 1 x H x W x 3
    text_images = text_images.repeat_interleave(repeats=gen_video.size(2), dim=1)

    image_id = image_id.unsqueeze(2)  # B x C x 1 x H x W
    image_id = image_id.repeat_interleave(repeats=gen_video.size(2), dim=2)
    image_id = image_id.mul_(vid_std).add_(vid_mean)  # de-normalize to [0, 1]
    image_id.clamp_(0, 1)
    image_id = image_id * 255.0
    image_id = rearrange(image_id, 'b c f h w -> b f h w c')

    gen_video = gen_video.mul_(vid_std).add_(vid_mean)  # de-normalize to [0, 1]
    gen_video.clamp_(0, 1)
    gen_video = gen_video * 255.0
    images = rearrange(gen_video, 'b c f h w -> b f h w c')

    # [condition image | generation | caption] side by side; batch 0 only.
    images = torch.cat([image_id, images, text_images], dim=3)
    images = images[0]
    images = [img.numpy().astype('uint8') for img in images]

    exception = None
    for _ in range(retry):
        try:
            frame_dir = os.path.join(os.path.dirname(local_path), '%s_frames' % os.path.basename(local_path))
            # shutil.rmtree replaces the unsafe `os.system('rm -rf ...')`.
            shutil.rmtree(frame_dir, ignore_errors=True)
            os.makedirs(frame_dir, exist_ok=True)
            for fid, frame in enumerate(images):
                tpth = os.path.join(frame_dir, '%04d.png' % (fid + 1))
                cv2.imwrite(tpth, frame[:, :, ::-1])  # RGB -> BGR for OpenCV
            # Argument-list subprocess call: no shell quoting issues.
            cmd = ['ffmpeg', '-y', '-f', 'image2', '-loglevel', 'quiet',
                   '-framerate', str(save_fps), '-i', os.path.join(frame_dir, '%04d.png'),
                   '-vcodec', 'libx264', '-crf', '17', '-pix_fmt', 'yuv420p',
                   local_path]
            subprocess.run(cmd, check=False)
            shutil.rmtree(frame_dir, ignore_errors=True)
            # Bug fix: clear a stale failure from an earlier attempt; the
            # original re-raised it after the loop even when a later attempt
            # had succeeded.
            exception = None
            break
        except Exception as e:
            exception = e
            continue
    if exception is not None:
        raise exception
@torch.no_grad()
def save_i2vgen_video_safe(
        local_path,
        gen_video,
        captions,
        mean=[0.5, 0.5, 0.5],
        std=[0.5, 0.5, 0.5],
        text_size=256,
        retry=5,
        save_fps=8
    ):
    '''
    Save only the generated video (no reference conditions), dropping the
    last frame when it matches a known uniform-gray failure artifact.

    Args:
        local_path: output path; '.png' is appended for a single frame, for
            videos it is passed to imageio as-is (expected to end in '.mp4').
        gen_video: normalized video tensor, BxCxFxHxW; only batch 0 is saved.
        captions: unused; kept for signature compatibility with siblings.
        mean, std: per-channel stats used to de-normalize.
        text_size: unused; kept for signature compatibility with siblings.
        retry: number of save attempts before giving up.
        save_fps: frame rate of the saved video.

    Raises:
        the last exception if all `retry` attempts fail.

    NOTE: `gen_video` is de-normalized in place.
    '''
    vid_mean = torch.tensor(mean, device=gen_video.device).view(1, -1, 1, 1, 1)  # ncfhw
    vid_std = torch.tensor(std, device=gen_video.device).view(1, -1, 1, 1, 1)  # ncfhw
    gen_video = gen_video.mul_(vid_std).add_(vid_mean)  # de-normalize to [0, 1]
    gen_video.clamp_(0, 1)
    gen_video = gen_video * 255.0
    images = rearrange(gen_video, 'b c f h w -> b f h w c')
    images = images[0]
    images = [img.numpy().astype('uint8') for img in images]
    num_image = len(images)

    # Bug fix: append '.png' once, OUTSIDE the retry loop; the original
    # appended it on every attempt ('.png.png' after a failure).
    if num_image == 1:
        local_path = local_path + '.png'
    exception = None
    for _ in range(retry):
        try:
            if num_image == 1:
                # RGB -> BGR for OpenCV. (Dropped the JPEG-quality flag the
                # original passed to a PNG write; cv2 ignores it.)
                cv2.imwrite(local_path, images[0][:, :, ::-1])
            else:
                # `with` closes the writer even if append_data raises; the
                # original leaked the writer on failure.
                with imageio.get_writer(local_path, fps=save_fps, codec='libx264', quality=8) as writer:
                    for fid, frame in enumerate(images):
                        if fid == num_image - 1:
                            # Known failure mode: skip a last frame that is
                            # mostly uniform gray (values within [117, 137]).
                            ratio = np.sum((frame >= 117) & (frame <= 137)) / frame.size
                            if ratio > 0.4:
                                continue
                        writer.append_data(frame)
            # Bug fix: clear a stale failure from an earlier attempt; the
            # original re-raised it after the loop even on later success.
            exception = None
            break
        except Exception as e:
            exception = e
            continue
    if exception is not None:
        raise exception
@torch.no_grad()
def save_t2vhigen_video_safe(
        local_path,
        gen_video,
        captions,
        mean=[0.5, 0.5, 0.5],
        std=[0.5, 0.5, 0.5],
        text_size=256,
        retry=5,
        save_fps=8
    ):
    '''
    Save only the generated video (no reference conditions) by dumping PNG
    frames and encoding them with ffmpeg, dropping the last frame when it
    matches a known uniform-gray failure artifact.

    Args:
        local_path: output path; '.png' is appended for a single frame, for
            videos it is passed to ffmpeg as-is.
        gen_video: normalized video tensor, BxCxFxHxW; only batch 0 is saved.
        captions: unused; kept for signature compatibility with siblings.
        mean, std: per-channel stats used to de-normalize.
        text_size: unused; kept for signature compatibility with siblings.
        retry: number of save attempts before giving up.
        save_fps: frame rate of the saved video.

    Raises:
        the last exception if all `retry` attempts fail.

    NOTE: `gen_video` is de-normalized in place.
    '''
    vid_mean = torch.tensor(mean, device=gen_video.device).view(1, -1, 1, 1, 1)  # ncfhw
    vid_std = torch.tensor(std, device=gen_video.device).view(1, -1, 1, 1, 1)  # ncfhw
    gen_video = gen_video.mul_(vid_std).add_(vid_mean)  # de-normalize to [0, 1]
    gen_video.clamp_(0, 1)
    gen_video = gen_video * 255.0
    images = rearrange(gen_video, 'b c f h w -> b f h w c')
    images = images[0]
    images = [img.numpy().astype('uint8') for img in images]
    num_image = len(images)

    # Bug fix: append '.png' once, OUTSIDE the retry loop; the original
    # appended it on every attempt ('.png.png' after a failure).
    if num_image == 1:
        local_path = local_path + '.png'
    exception = None
    for _ in range(retry):
        try:
            if num_image == 1:
                # RGB -> BGR for OpenCV. (Dropped the JPEG-quality flag the
                # original passed to a PNG write; cv2 ignores it.)
                cv2.imwrite(local_path, images[0][:, :, ::-1])
            else:
                frame_dir = os.path.join(os.path.dirname(local_path), '%s_frames' % os.path.basename(local_path))
                # shutil.rmtree replaces the unsafe `os.system('rm -rf ...')`.
                shutil.rmtree(frame_dir, ignore_errors=True)
                os.makedirs(frame_dir, exist_ok=True)
                for fid, frame in enumerate(images):
                    if fid == num_image - 1:
                        # Known failure mode: skip a last frame that is
                        # mostly uniform gray (values within [117, 137]).
                        ratio = np.sum((frame >= 117) & (frame <= 137)) / frame.size
                        if ratio > 0.4:
                            continue
                    tpth = os.path.join(frame_dir, '%04d.png' % (fid + 1))
                    cv2.imwrite(tpth, frame[:, :, ::-1])
                # Argument-list subprocess call: no shell quoting issues.
                cmd = ['ffmpeg', '-y', '-f', 'image2', '-loglevel', 'quiet',
                       '-framerate', str(save_fps), '-i', os.path.join(frame_dir, '%04d.png'),
                       '-vcodec', 'libx264', '-crf', '17', '-pix_fmt', 'yuv420p',
                       local_path]
                subprocess.run(cmd, check=False)
                shutil.rmtree(frame_dir, ignore_errors=True)
            # Bug fix: clear a stale failure from an earlier attempt; the
            # original re-raised it after the loop even on later success.
            exception = None
            break
        except Exception as e:
            exception = e
            continue
    if exception is not None:
        raise exception
@torch.no_grad()
def save_video_multiple_conditions_not_gif_horizontal_3col(local_path, video_tensor, model_kwargs, source_imgs,
                                                           mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], nrow=8, retry=5, save_fps=8):
    '''
    Save the condition maps and the generated video side by side (conditions
    left, generation right) as an .mp4 or .gif, tiling the batch into a grid.

    Args:
        local_path: output path ending in '.mp4' or '.gif' (for a
            single-frame result '.mp4' is replaced by '.png').
        video_tensor: normalized generated video, BxCxFxHxW.
        model_kwargs: sequence whose first element is a dict of condition
            tensors (1/2/3/4 channels), each BxCxFxHxW.
        source_imgs: source images, pooled to the video resolution (currently
            computed but not included in the saved grid).
        mean, std: per-channel stats used to de-normalize `video_tensor`.
        nrow: number of grid rows used when tiling the batch.
        retry: number of save attempts before giving up.
        save_fps: frame rate of the saved video/gif.

    Failures are logged (printed), not raised — this is a best-effort saver.

    NOTE: `video_tensor` is de-normalized in place.
    '''
    mean = torch.tensor(mean, device=video_tensor.device).view(1, -1, 1, 1, 1)  # ncfhw
    std = torch.tensor(std, device=video_tensor.device).view(1, -1, 1, 1, 1)  # ncfhw
    video_tensor = video_tensor.mul_(std).add_(mean)  # unnormalize back to [0, 1]
    video_tensor.clamp_(0, 1)
    b, c, n, h, w = video_tensor.shape
    source_imgs = F.adaptive_avg_pool3d(source_imgs, (n, h, w))
    source_imgs = source_imgs.cpu()

    model_kwargs_channel3 = {}
    for key, conditions in model_kwargs[0].items():
        # Normalize every condition to 3 channels, then pool once to
        # (n, h, w). The original pooled 1-channel inputs twice; adaptive
        # pooling to the same size is idempotent, so one pass is equivalent.
        if conditions.size(1) == 1:
            conditions = torch.cat([conditions, conditions, conditions], dim=1)
        elif conditions.size(1) == 2:
            conditions = torch.cat([conditions, conditions[:, :1]], dim=1)
        elif conditions.size(1) == 4:  # RGBA mask: composite over white
            color = (conditions[:, 0:3] + 1.) / 2.
            alpha = conditions[:, 3:4]
            conditions = color * alpha + 1.0 * (1.0 - alpha)
        if conditions.size(1) == 3:  # other channel counts pass through untouched
            conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
        model_kwargs_channel3[key] = conditions.cpu() if conditions.is_cuda else conditions

    exception = None  # fix: defined even if retry <= 0 (original would NameError)
    for _ in range(retry):
        try:
            vid_gif = rearrange(video_tensor, '(i j) c f h w -> c f (i h) (j w)', i=nrow)
            cons_list = [rearrange(con, '(i j) c f h w -> c f (i h) (j w)', i=nrow)
                         for _, con in model_kwargs_channel3.items()]
            vid_gif = torch.cat(cons_list + [vid_gif], dim=3)  # conditions left, video right
            vid_gif = vid_gif.permute(1, 2, 3, 0)  # -> f h w c
            images = vid_gif * 255.0
            images = [img.numpy().astype('uint8') for img in images]
            if len(images) == 1:
                local_path = local_path.replace('.mp4', '.png')
                # RGB -> BGR for OpenCV. (Dropped the JPEG-quality flag the
                # original passed to a PNG write; cv2 ignores it.)
                cv2.imwrite(local_path, images[0][:, :, ::-1])
            else:
                frames = [Image.fromarray(arr) for arr in images]
                # os.path.splitext replaces the per-iteration pathlib import.
                save_fmt = os.path.splitext(local_path)[1]
                if save_fmt == ".mp4":
                    with imageio.get_writer(local_path, fps=save_fps) as writer:
                        for img in frames:
                            writer.append_data(np.array(img))
                elif save_fmt == ".gif":
                    frames[0].save(
                        fp=local_path,
                        format="GIF",
                        append_images=frames[1:],
                        save_all=True,
                        duration=(1 / save_fps * 1000),
                        loop=0,
                    )
                else:
                    raise ValueError("Unsupported file type. Use .mp4 or .gif.")
            exception = None
            break
        except Exception as e:
            exception = e
            continue
    if exception is not None:
        # Deliberately best-effort: log the failure instead of raising.
        print('save video to {} failed, error: {}'.format(local_path, exception), flush=True)