import os
import math
from typing import Union, Optional

import imageio
import numpy as np
import torch
import torchvision
import torch.distributed as dist
from tqdm import tqdm
from einops import rearrange
import cv2
import moviepy.editor as mpy
from PIL import Image

# We recommend using the following affinity scores (motion magnitude).
# You are also encouraged to construct different scores of your own.
# RANGE_LIST = [
#     [1.0, 0.9, 0.85, 0.85, 0.85, 0.8],  # 0 Small Motion
#     [1.0, 0.8, 0.8, 0.8, 0.79, 0.78, 0.75],  # Moderate Motion
#     [1.0, 0.8, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.6, 0.5, 0.5],  # Large Motion
#     # [1.0, 0.8, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.6],  # Large Motion
#     # [1.0, 0.65, 0.6],  # candidate moderate
#     # [1.0, 0.65, 0.6, 0.6, 0.6, 0.5, 0.5, 0.5, 0.5, 0.4],  # candidate large
#     [1.0, 0.9, 0.85, 0.85, 0.85, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.85, 0.85, 0.9, 1.0],  # Loop
#     [1.0, 0.8, 0.8, 0.8, 0.79, 0.78, 0.75, 0.75, 0.75, 0.75, 0.75, 0.78, 0.79, 0.8, 0.8, 1.0],  # Loop
#     [1.0, 0.8, 0.7, 0.7, 0.7, 0.7, 0.6, 0.5, 0.5, 0.6, 0.7, 0.7, 0.7, 0.7, 0.8, 1.0],  # Loop
#     # [1.0],  # Static
#     # [0],
#     # [0.6, 0.5, 0.5, 0.45, 0.45, 0.4],  # Style Transfer Test
#     # [0.4, 0.3, 0.3, 0.25, 0.25, 0.2],  # Style Transfer
#     [0.5, 0.2],  # Style Transfer Large Motion
#     [0.5, 0.4, 0.4, 0.4, 0.35, 0.35, 0.3, 0.25, 0.2],  # Style Transfer Moderate Motion
#     [0.5, 0.4, 0.4, 0.4, 0.35, 0.3],  # Style Transfer Candidate Small Motion
# ]
RANGE_LIST = [
    [0.5, 0.4, 0.4, 0.4, 0.35, 0.3],                   # 0: Style Transfer Candidate Small Motion
    [0.5, 0.4, 0.4, 0.4, 0.35, 0.35, 0.3, 0.25, 0.2],  # 1: Style Transfer Moderate Motion
    [0.5, 0.2],                                        # 2: Style Transfer Large Motion
]
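# The `sim_range` argument of `prepare_mask_coef_by_statistics` (defined below) indexes
# into RANGE_LIST, e.g. sim_range=2 selects [0.5, 0.2], the Style Transfer Large Motion
# preset.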


def zero_rank_print(s):
    # Only print from rank 0 (or when torch.distributed is not initialized).
    if (not dist.is_initialized()) or dist.get_rank() == 0:
        print("### " + s)


def save_videos_mp4(video: torch.Tensor, path: str, fps: int = 8):
    video = rearrange(video, "b c t h w -> t b c h w")
    num_frames, batch_size, channels, height, width = video.shape
    assert batch_size == 1, 'Only support batch size == 1'
    video = video.squeeze(1)
    video = rearrange(video, "t c h w -> t h w c")

    def make_frame(t):
        # Map the clip time t (seconds) to a frame index; clamp so floating-point
        # rounding at the end of the clip cannot index past the last frame.
        frame_idx = min(int(t * fps), num_frames - 1)
        frame_tensor = video[frame_idx]
        frame_np = (frame_tensor * 255).numpy().astype('uint8')
        return frame_np

    clip = mpy.VideoClip(make_frame, duration=num_frames / fps)
    clip.write_videofile(path, fps=fps, codec='libx264')


def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
    videos = rearrange(videos, "b c t h w -> t b c h w")
    outputs = []
    for x in videos:
        x = torchvision.utils.make_grid(x, nrow=n_rows)
        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)  # c h w -> h w c
        if rescale:
            x = (x + 1.0) / 2.0  # [-1, 1] -> [0, 1]
        x = torch.clamp(x * 255, 0, 255).numpy().astype(np.uint8)
        outputs.append(x)

    os.makedirs(os.path.dirname(path), exist_ok=True)
    imageio.mimsave(path, outputs, fps=fps)
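# Illustrative usage (the shapes are assumptions, not enforced by this module): given a
# tensor `videos` of decoded frames with shape (b, c, t, h, w) and values in [0, 1],
#   save_videos_grid(videos, "samples/out.gif", fps=8)
# arranges each frame's batch into a grid with `n_rows` images per row and writes an
# animated file.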


# DDIM Inversion
def init_prompt(prompt, pipeline):
    # Encode the empty (unconditional) prompt and the conditional prompt, and stack
    # them as [uncond, cond]; ddim_loop below splits them again with chunk(2).
    uncond_input = pipeline.tokenizer(
        [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
        return_tensors="pt"
    )
    uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
    text_input = pipeline.tokenizer(
        [prompt],
        padding="max_length",
        max_length=pipeline.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
    context = torch.cat([uncond_embeddings, text_embeddings])
    return context


def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
              sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
    timestep, next_timestep = min(
        timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
    alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
    alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
    beta_prod_t = 1 - alpha_prod_t
    next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
    next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
    next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
    return next_sample
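# The step above is the deterministic DDIM update run in reverse: with the epsilon
# prediction eps = model_output,
#   x0_pred  = (x_t - sqrt(1 - alpha_t) * eps) / sqrt(alpha_t)
#   x_{t+1}  = sqrt(alpha_{t+1}) * x0_pred + sqrt(1 - alpha_{t+1}) * eps
# where alpha_t denotes the cumulative product alphas_cumprod[t].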


def get_noise_pred_single(latents, t, context, unet):
    noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
    return noise_pred


def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
    context = init_prompt(prompt, pipeline)
    uncond_embeddings, cond_embeddings = context.chunk(2)
    all_latent = [latent]
    latent = latent.clone().detach()
    for i in tqdm(range(num_inv_steps)):
        # Walk the scheduler timesteps in reverse order (from low noise to high noise).
        t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
        noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
        latent = next_step(noise_pred, t, latent, ddim_scheduler)
        all_latent.append(latent)
    return all_latent


def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
    ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
    return ddim_latents
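# Illustrative usage (assumes a diffusers-style pipeline whose scheduler timesteps have
# already been set, e.g. via ddim_scheduler.set_timesteps(num_inv_steps)):
#   inv_latents = ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt="")
#   start_latent = inv_latents[-1]  # the most noised latent; inv_latents[0] is the input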


def prepare_mask_coef(video_length: int, cond_frame: int, sim_range: list = [0.2, 1.0]):
    assert len(sim_range) == 2, \
        'sim_range should have a length of 2, including the min and max similarity'
    assert video_length > 1, \
        'video_length should be greater than 1'
    assert video_length > cond_frame, \
        'video_length should be greater than cond_frame'

    diff = abs(sim_range[0] - sim_range[1]) / (video_length - 1)
    coef = [1.0] * video_length
    for f in range(video_length):
        f_diff = diff * abs(cond_frame - f)
        f_diff = 1 - f_diff
        coef[f] *= f_diff

    return coef
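# Worked example: video_length=6, cond_frame=0, sim_range=[0.2, 1.0] gives
# diff = 0.8 / 5 = 0.16, so coef = [1.0, 0.84, 0.68, 0.52, 0.36, 0.2]
# (floating-point rounding aside): the coefficient decays linearly with the
# distance from the conditioning frame.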


def prepare_mask_coef_by_score(video_shape: list, cond_frame_idx: list, sim_range: list = [0.2, 1.0],
                               statistic: list = [1, 100], coef_max: int = 0.98, score: Optional[torch.Tensor] = None):
    '''
    video_shape is a list [batch_size, frame_num] for video data of full shape (b f c h w)
    cond_frame_idx is a list with length batch_size
    the shape of statistic is (f 2)
    the shape of score is (b f)
    the shape of the returned coef is (b f)
    '''
    assert len(video_shape) == 2, \
        f'video_shape should be [batch_size, frame_num], but got {len(video_shape)} elements'

    batch_size, frame_num = video_shape[0], video_shape[1]

    score = score.permute(0, 2, 1).squeeze(0)

    # list -> (b 1)
    cond_frame_mat = torch.tensor(cond_frame_idx).unsqueeze(-1)

    statistic = torch.tensor(statistic)
    # (f 2) -> (b f 2)
    statistic = statistic.repeat(batch_size, 1, 1)

    # shape of order (b f), shape of cond_mat (b f)
    order = torch.arange(0, frame_num, 1)
    order = order.repeat(batch_size, 1)
    cond_mat = torch.ones((batch_size, frame_num)) * cond_frame_mat
    order = abs(order - cond_mat)

    statistic = statistic[:, order.to(torch.long)][0, :, :, :]

    # score (b f), max_stats / min_stats (b f): clamp the per-frame score into the statistic range
    max_stats = torch.max(statistic, dim=2).values.to(dtype=score.dtype)
    min_stats = torch.min(statistic, dim=2).values.to(dtype=score.dtype)
    score[score > max_stats] = max_stats[score > max_stats] * 0.95
    score[score < min_stats] = min_stats[score < min_stats]

    eps = 1e-10
    coef = 1 - abs((score / (max_stats + eps)) * (max(sim_range) - min(sim_range)))

    # the conditioning frame always keeps a coefficient of 1.0
    indices = torch.arange(coef.shape[0]).unsqueeze(1)
    coef[indices, cond_frame_mat] = 1.0

    return coef


def prepare_mask_coef_by_statistics(video_length: int, cond_frame: int, sim_range: int,
                                    coef: Optional[list] = None):
    """
    coef: user-defined coefficients; if passed, the `sim_range` index is ignored. This is
    useful for defining custom style-transfer coefficients for different models.
    """
    assert video_length > 1, \
        'video_length should be greater than 1'
    assert video_length > cond_frame, \
        'video_length should be greater than cond_frame'

    # Recommend index: 13
    # range_list = [
    #     # [0.8, 0.8, 0.7, 0.6],
    #     [1.0, 0.8, 0.7, 0.6],
    #     [1.0, 0.8, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.6, 0.5, 0.5],
    #     [1.0, 0.9, 0.85, 0.85, 0.85, 0.8],  # 0
    #     [1.0, 0.9, 0.8, 0.7],
    #     [1.0, 0.8, 0.7, 0.6, 0.7, 0.6],
    #     [1.0, 0.9, 0.85],
    #     # [1.0, 0.9, 0.7, 0.5, 0.3, 0.2],
    #     # [1.0, 0.8, 0.6, 0.4],
    #     # [1.0, 0.65, 0.6],  # 1
    #     [1.0, 0.6, 0.4],  # 2
    #     [1.0, 0.2, 0.2],
    #     # [1.0, 0.8, 0.6, 0.6, 0.5, 0.5, 0.4],
    #     # [1.0, 0.9, 0.9, 0.9, 0.9, 0.8],
    #     # [1.0, 0.65, 0.6, 0.6, 0.5, 0.5, 0.4],
    #     # [1.0, 0.9, 0.9, 0.9, 0.7, 0.7, 0.6, 0.5, 0.4],
    #     [1.0, 0.8, 0.8, 0.8, 0.79, 0.78, 0.75],  # 4 style_transfer
    #     [1.0, 0.9, 0.9],
    #     [0.8, 0.7, 0.6],
    #     [0.8, 0.8, 0.8, 0.8, 0.7],
    #     [0.9, 0.6, 0.6, 0.6, 0.5, 0.4, 0.2],
    #     # [1.0, 0.91, 0.9, 0.89, 0.88, 0.87],
    #     # [1.0, 0.7, 0.65, 0.65, 0.65, 0.65, 0.6],
    #     # [1.0, 0.85, 0.9, 0.85, 0.9, 0.85],
    #     # [1.0, 0.8, 0.82, 0.84, 0.86, 0.88, 0.78, 0.82, 0.84],
    #     # [1.0],
    # ]
    range_list = RANGE_LIST

    assert sim_range < len(range_list), \
        f'sim_range {sim_range} is not implemented'

    if coef is None:
        coef = range_list[sim_range]

    # Pad the coefficient list with its last value so it covers every frame,
    # then reorder it by each frame's distance from the conditioning frame.
    coef = coef + ([coef[-1]] * (video_length - len(coef)))

    order = [abs(i - cond_frame) for i in range(video_length)]
    coef = [coef[order[i]] for i in range(video_length)]

    return coef
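# Worked example: prepare_mask_coef_by_statistics(video_length=5, cond_frame=2, sim_range=2)
# selects RANGE_LIST[2] = [0.5, 0.2], pads it to [0.5, 0.2, 0.2, 0.2, 0.2], and reorders it
# by distance from frame 2, giving [0.2, 0.2, 0.5, 0.2, 0.2]: the conditioning frame gets
# the largest coefficient and the others fall off symmetrically around it.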