# Source: https://github.com/ali-vilab/UniAnimate/commit/d7814fa44a0a1154524b92fce0e3133a2604d333
# Commit 2ba4412 by Evgeny Zhukov
import os | |
import os.path as osp | |
import sys | |
import cv2 | |
import glob | |
import math | |
import torch | |
import gzip | |
import copy | |
import time | |
import json | |
import pickle | |
import base64 | |
import imageio | |
import hashlib | |
import requests | |
import binascii | |
import zipfile | |
# import skvideo.io | |
import numpy as np | |
from io import BytesIO | |
import urllib.request | |
import torch.nn.functional as F | |
import torchvision.utils as tvutils | |
from multiprocessing.pool import ThreadPool as Pool | |
from einops import rearrange | |
from PIL import Image, ImageDraw, ImageFont | |
def gen_text_image(captions, text_size):
    """Render each caption onto a white square image and stack the results.

    Args:
        captions: iterable of caption strings, one per sample.
        text_size: side length in pixels of each square text panel.

    Returns:
        torch.Tensor of shape (len(captions), text_size, text_size, 3),
        dtype uint8, containing the rendered captions.
    """
    # 38 characters fit one line at the reference size of 256 px; scale the
    # hard-wrap width with the actual panel size.  (The original expression
    # `text_size / text_size` always evaluated to 1, freezing num_char at 38
    # regardless of text_size.)
    num_char = max(1, int(38 * (text_size / 256)))
    font_size = int(text_size / 20)
    try:
        font = ImageFont.truetype('data/font/DejaVuSans.ttf', size=font_size)
    except OSError:
        # Fall back to PIL's built-in bitmap font when the bundled TTF is
        # missing, so callers still get a usable (if plainer) text panel
        # instead of a crash.
        font = ImageFont.load_default()
    text_image_list = []
    for text in captions:
        txt_img = Image.new("RGB", (text_size, text_size), color="white")
        draw = ImageDraw.Draw(txt_img)
        # Hard-wrap the caption every num_char characters.
        lines = "\n".join(text[start:start + num_char] for start in range(0, len(text), num_char))
        draw.text((0, 0), lines, fill="black", font=font)
        text_image_list.append(np.array(txt_img))
    return torch.from_numpy(np.stack(text_image_list, axis=0))
def save_video_refimg_and_text(
        local_path,
        ref_frame,
        gen_video,
        captions,
        mean=[0.5, 0.5, 0.5],
        std=[0.5, 0.5, 0.5],
        text_size=256,
        nrow=4,
        save_fps=8,
        retry=5):
    '''
    Save a grid that concatenates, per sample: the reference frame, the
    generated video, and a rendered caption panel.

    gen_video: BxCxFxHxW, normalized with `mean`/`std`.

    NOTE: `ref_frame` and `gen_video` are un-normalized IN PLACE (mul_/add_),
    so the caller's tensors are modified.  The `nrow` parameter is accepted
    for interface compatibility but immediately overwritten below.
    '''
    nrow = max(int(gen_video.size(0) / 2), 1)
    vid_mean = torch.tensor(mean, device=gen_video.device).view(1, -1, 1, 1, 1)  # ncfhw
    vid_std = torch.tensor(std, device=gen_video.device).view(1, -1, 1, 1, 1)  # ncfhw
    text_images = gen_text_image(captions, text_size)  # (B, text_size, text_size, 3)
    text_images = text_images.unsqueeze(1)  # (B, 1, text_size, text_size, 3)
    # Repeat the static caption panel once per video frame.
    text_images = text_images.repeat_interleave(repeats=gen_video.size(2), dim=1)

    ref_frame = ref_frame.unsqueeze(2)
    ref_frame = ref_frame.mul_(vid_std).add_(vid_mean)  # un-normalize back to [0, 1]
    # Repeat the static reference frame once per video frame.
    ref_frame = ref_frame.repeat_interleave(repeats=gen_video.size(2), dim=2)
    ref_frame.clamp_(0, 1)
    ref_frame = ref_frame * 255.0
    ref_frame = rearrange(ref_frame, 'b c f h w -> b f h w c')

    gen_video = gen_video.mul_(vid_std).add_(vid_mean)  # un-normalize back to [0, 1]
    gen_video.clamp_(0, 1)
    gen_video = gen_video * 255.0
    images = rearrange(gen_video, 'b c f h w -> b f h w c')

    # Concatenate ref | video | caption along width, then tile the batch.
    images = torch.cat([ref_frame, images, text_images], dim=3)
    images = rearrange(images, '(r j) f h w c -> f (r h) (j w) c', r=nrow)
    images = [img.numpy().astype('uint8') for img in images]

    # Pick the output name ONCE, outside the retry loop; the original code
    # appended the suffix inside the loop, so each retry grew the path
    # ('x.mp4.mp4', ...).
    local_path = local_path + ('.png' if len(images) == 1 else '.mp4')
    exception = None
    for _ in range(retry):
        try:
            if len(images) == 1:
                # BGR order for OpenCV; the JPEG-quality flag the original
                # passed is meaningless for PNG output and was dropped.
                cv2.imwrite(local_path, images[0][:, :, ::-1])
            else:
                frame_dir = os.path.join(os.path.dirname(local_path), '%s_frames' % os.path.basename(local_path))
                # NOTE(review): paths are interpolated into shell commands
                # unescaped -- unsafe if local_path is untrusted.
                os.system(f'rm -rf {frame_dir}')
                os.makedirs(frame_dir, exist_ok=True)
                for fid, frame in enumerate(images):
                    tpth = os.path.join(frame_dir, '%04d.png' % (fid + 1))
                    cv2.imwrite(tpth, frame[:, :, ::-1])
                cmd = (f'ffmpeg -y -f image2 -loglevel quiet -framerate {save_fps} '
                       f'-i {frame_dir}/%04d.png -vcodec libx264 -crf 17 -pix_fmt yuv420p {local_path}')
                os.system(cmd)
                os.system(f'rm -rf {frame_dir}')
            exception = None
            break
        except Exception as e:
            exception = e
            continue
    if exception is not None:
        # The original silently discarded the error; report it in the same
        # style as the other save_* helpers in this file.
        print('save video to {} failed, error: {}'.format(local_path, exception), flush=True)
def save_i2vgen_video(
        local_path,
        image_id,
        gen_video,
        captions,
        mean=[0.5, 0.5, 0.5],
        std=[0.5, 0.5, 0.5],
        text_size=256,
        retry=5,
        save_fps=8
):
    '''
    Save both the generated video and the input conditions.

    Writes image_id | gen_video | caption panel, concatenated along width,
    for the FIRST sample of the batch only, as an H.264 mp4 at `local_path`.

    NOTE: `image_id` is un-normalized IN PLACE (mul_/add_), so the caller's
    tensor is modified.  Raises the last recorded exception if all `retry`
    attempts fail.
    '''
    vid_mean = torch.tensor(mean, device=gen_video.device).view(1, -1, 1, 1, 1)  # ncfhw
    vid_std = torch.tensor(std, device=gen_video.device).view(1, -1, 1, 1, 1)  # ncfhw
    text_images = gen_text_image(captions, text_size)  # (B, text_size, text_size, 3)
    text_images = text_images.unsqueeze(1)
    # Repeat the static caption panel once per video frame.
    text_images = text_images.repeat_interleave(repeats=gen_video.size(2), dim=1)

    image_id = image_id.unsqueeze(2)  # B, C, F, H, W
    # Repeat the static condition image once per video frame.
    image_id = image_id.repeat_interleave(repeats=gen_video.size(2), dim=2)
    image_id = image_id.mul_(vid_std).add_(vid_mean)  # un-normalize back to [0, 1]
    image_id.clamp_(0, 1)
    image_id = image_id * 255.0
    image_id = rearrange(image_id, 'b c f h w -> b f h w c')

    gen_video = gen_video.mul_(vid_std).add_(vid_mean)  # un-normalize back to [0, 1]
    gen_video.clamp_(0, 1)
    gen_video = gen_video * 255.0
    images = rearrange(gen_video, 'b c f h w -> b f h w c')

    images = torch.cat([image_id, images, text_images], dim=3)
    images = images[0]  # only the first sample of the batch is saved
    images = [img.numpy().astype('uint8') for img in images]

    exception = None
    for _ in range(retry):
        try:
            frame_dir = os.path.join(os.path.dirname(local_path), '%s_frames' % os.path.basename(local_path))
            # NOTE(review): paths are interpolated into shell commands
            # unescaped -- unsafe if local_path is untrusted; os.system does
            # not raise on ffmpeg failure either.
            os.system(f'rm -rf {frame_dir}')
            os.makedirs(frame_dir, exist_ok=True)
            for fid, frame in enumerate(images):
                tpth = os.path.join(frame_dir, '%04d.png' % (fid + 1))
                cv2.imwrite(tpth, frame[:, :, ::-1])
            cmd = (f'ffmpeg -y -f image2 -loglevel quiet -framerate {save_fps} '
                   f'-i {frame_dir}/%04d.png -vcodec libx264 -crf 17 -pix_fmt yuv420p {local_path}')
            os.system(cmd)
            os.system(f'rm -rf {frame_dir}')
            # Clear any error recorded by a failed earlier attempt; the
            # original kept the stale exception and raised it after success.
            exception = None
            break
        except Exception as e:
            exception = e
            continue
    if exception is not None:
        raise exception
def save_i2vgen_video_safe(
        local_path,
        gen_video,
        captions,
        mean=[0.5, 0.5, 0.5],
        std=[0.5, 0.5, 0.5],
        text_size=256,
        retry=5,
        save_fps=8
):
    '''
    Save only the generated video, do not save the related reference
    conditions, and at the same time perform anomaly detection on the
    last frame.

    Only the FIRST sample of the batch is saved.  A single frame goes to
    `local_path + '.png'`; multiple frames are written with imageio to
    `local_path` as-is (the caller presumably supplies the .mp4 suffix --
    TODO confirm).  Raises the last recorded exception if all `retry`
    attempts fail.
    '''
    vid_mean = torch.tensor(mean, device=gen_video.device).view(1, -1, 1, 1, 1)  # ncfhw
    vid_std = torch.tensor(std, device=gen_video.device).view(1, -1, 1, 1, 1)  # ncfhw
    gen_video = gen_video.mul_(vid_std).add_(vid_mean)  # un-normalize back to [0, 1]
    gen_video.clamp_(0, 1)
    gen_video = gen_video * 255.0
    images = rearrange(gen_video, 'b c f h w -> b f h w c')
    images = images[0]  # only the first sample of the batch is saved
    images = [img.numpy().astype('uint8') for img in images]
    num_image = len(images)

    if num_image == 1:
        # Decide the output name ONCE, outside the retry loop; the original
        # appended '.png' inside the loop, so each retry grew the filename.
        local_path = local_path + '.png'

    exception = None
    for _ in range(retry):
        try:
            if num_image == 1:
                # BGR order for OpenCV; the JPEG-quality flag the original
                # passed is meaningless for PNG output and was dropped.
                cv2.imwrite(local_path, images[0][:, :, ::-1])
            else:
                writer = imageio.get_writer(local_path, fps=save_fps, codec='libx264', quality=8)
                for fid, frame in enumerate(images):
                    if fid == num_image - 1:  # Fix known bugs.
                        # Presumably detects a corrupted, mostly-gray final
                        # frame: the 117-137 band brackets mid-gray 127.
                        ratio = np.sum((frame >= 117) & (frame <= 137)) / frame.size
                        if ratio > 0.4:
                            continue
                    writer.append_data(frame)
                writer.close()
            # Clear any error recorded by a failed earlier attempt; the
            # original kept the stale exception and raised it after success.
            exception = None
            break
        except Exception as e:
            exception = e
            continue
    if exception is not None:
        raise exception
def save_t2vhigen_video_safe(
        local_path,
        gen_video,
        captions,
        mean=[0.5, 0.5, 0.5],
        std=[0.5, 0.5, 0.5],
        text_size=256,
        retry=5,
        save_fps=8
):
    '''
    Save only the generated video, do not save the related reference
    conditions, and at the same time perform anomaly detection on the
    last frame.

    Only the FIRST sample of the batch is saved.  A single frame goes to
    `local_path + '.png'`; multiple frames are dumped as PNGs and encoded
    to `local_path` with ffmpeg (the caller presumably supplies the .mp4
    suffix -- TODO confirm).  Raises the last recorded exception if all
    `retry` attempts fail.
    '''
    vid_mean = torch.tensor(mean, device=gen_video.device).view(1, -1, 1, 1, 1)  # ncfhw
    vid_std = torch.tensor(std, device=gen_video.device).view(1, -1, 1, 1, 1)  # ncfhw
    gen_video = gen_video.mul_(vid_std).add_(vid_mean)  # un-normalize back to [0, 1]
    gen_video.clamp_(0, 1)
    gen_video = gen_video * 255.0
    images = rearrange(gen_video, 'b c f h w -> b f h w c')
    images = images[0]  # only the first sample of the batch is saved
    images = [img.numpy().astype('uint8') for img in images]
    num_image = len(images)

    if num_image == 1:
        # Decide the output name ONCE, outside the retry loop; the original
        # appended '.png' inside the loop, so each retry grew the filename.
        local_path = local_path + '.png'

    exception = None
    for _ in range(retry):
        try:
            if num_image == 1:
                # BGR order for OpenCV; the JPEG-quality flag the original
                # passed is meaningless for PNG output and was dropped.
                cv2.imwrite(local_path, images[0][:, :, ::-1])
            else:
                frame_dir = os.path.join(os.path.dirname(local_path), '%s_frames' % os.path.basename(local_path))
                # NOTE(review): paths are interpolated into shell commands
                # unescaped -- unsafe if local_path is untrusted.
                os.system(f'rm -rf {frame_dir}')
                os.makedirs(frame_dir, exist_ok=True)
                for fid, frame in enumerate(images):
                    if fid == num_image - 1:  # Fix known bugs.
                        # Presumably detects a corrupted, mostly-gray final
                        # frame: the 117-137 band brackets mid-gray 127.
                        ratio = np.sum((frame >= 117) & (frame <= 137)) / frame.size
                        if ratio > 0.4:
                            continue
                    tpth = os.path.join(frame_dir, '%04d.png' % (fid + 1))
                    cv2.imwrite(tpth, frame[:, :, ::-1])
                cmd = (f'ffmpeg -y -f image2 -loglevel quiet -framerate {save_fps} '
                       f'-i {frame_dir}/%04d.png -vcodec libx264 -crf 17 -pix_fmt yuv420p {local_path}')
                os.system(cmd)
                os.system(f'rm -rf {frame_dir}')
            # Clear any error recorded by a failed earlier attempt; the
            # original kept the stale exception and raised it after success.
            exception = None
            break
        except Exception as e:
            exception = e
            continue
    if exception is not None:
        raise exception
def save_video_multiple_conditions_not_gif_horizontal_3col(local_path, video_tensor, model_kwargs, source_imgs,
                                                           mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], nrow=8, retry=5, save_fps=8):
    '''
    Save the generated video tiled horizontally with every conditioning
    signal found in model_kwargs[0].

    Args:
        local_path: output path; the suffix selects the container (.mp4 or .gif).
        video_tensor: generated video, (B, C, F, H, W), normalized with mean/std.
            NOTE: un-normalized IN PLACE, so the caller's tensor is modified.
        model_kwargs: sequence whose first element is a dict mapping condition
            name -> tensor (B, C', F', H', W') -- assumed from the indexing
            here, TODO confirm against callers.
        source_imgs: reference images, pooled to the video's (F, H, W); pooled
            but otherwise unused below.
        mean, std: normalization statistics used to map back to [0, 1].
        nrow: number of grid rows used when tiling the batch.
        retry: number of attempts before giving up (failure is printed, not raised).
        save_fps: frame rate of the saved video.
    '''
    from pathlib import Path  # hoisted from the write branch below

    mean = torch.tensor(mean, device=video_tensor.device).view(1, -1, 1, 1, 1)  # ncfhw
    std = torch.tensor(std, device=video_tensor.device).view(1, -1, 1, 1, 1)  # ncfhw
    video_tensor = video_tensor.mul_(std).add_(mean)  # unnormalize back to [0,1]
    video_tensor.clamp_(0, 1)
    b, c, n, h, w = video_tensor.shape
    source_imgs = F.adaptive_avg_pool3d(source_imgs, (n, h, w))
    source_imgs = source_imgs.cpu()

    model_kwargs_channel3 = {}
    for key, conditions in model_kwargs[0].items():
        # Bring every condition to 3 channels, then pool it once to the
        # video's (frames, height, width).  The original used a broken
        # if/if-elif chain that pooled the 1-channel case a second time
        # (a no-op only because the size already matched); a single elif
        # chain with one pooling call is equivalent.
        if conditions.size(1) == 1:
            conditions = torch.cat([conditions, conditions, conditions], dim=1)
        elif conditions.size(1) == 2:
            conditions = torch.cat([conditions, conditions[:, :1, ]], dim=1)
        elif conditions.size(1) == 4:  # 4 channels means it is a mask (RGBA).
            # Composite the color channels over a white background using the
            # alpha channel; color is assumed to be in [-1, 1].
            color = (conditions[:, 0:3] + 1.) / 2.
            alpha = conditions[:, 3:4]
            conditions = color * alpha + 1.0 * (1.0 - alpha)
        conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
        model_kwargs_channel3[key] = conditions.cpu() if conditions.is_cuda else conditions

    exception = None
    for _ in range(retry):
        try:
            # Tile the batch into an (i x j) grid, then concatenate all
            # condition grids and the video grid along width.
            vid_gif = rearrange(video_tensor, '(i j) c f h w -> c f (i h) (j w)', i=nrow)
            cons_list = [rearrange(con, '(i j) c f h w -> c f (i h) (j w)', i=nrow)
                         for _, con in model_kwargs_channel3.items()]
            vid_gif = torch.cat(cons_list + [vid_gif], dim=3)
            vid_gif = vid_gif.permute(1, 2, 3, 0)  # c f h w -> f h w c
            images = vid_gif * 255.0
            # NOTE(review): .numpy() requires a CPU tensor; video_tensor is
            # never moved to .cpu() here, so callers presumably pass CPU input.
            images = [img.numpy().astype('uint8') for img in images]
            if len(images) == 1:
                local_path = local_path.replace('.mp4', '.png')
                cv2.imwrite(local_path, images[0][:, :, ::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100])
            else:
                outputs = [Image.fromarray(frame) for frame in images]
                save_fmt = Path(local_path).suffix
                if save_fmt == ".mp4":
                    with imageio.get_writer(local_path, fps=save_fps) as writer:
                        for img in outputs:
                            writer.append_data(np.array(img))
                elif save_fmt == ".gif":
                    outputs[0].save(
                        fp=local_path,
                        format="GIF",
                        append_images=outputs[1:],
                        save_all=True,
                        duration=(1 / save_fps * 1000),
                        loop=0,
                    )
                else:
                    raise ValueError("Unsupported file type. Use .mp4 or .gif.")
            exception = None
            break
        except Exception as e:
            exception = e
            continue
    if exception is not None:
        print('save video to {} failed, error: {}'.format(local_path, exception), flush=True)