import os
import shutil
from enum import Enum

import cv2
import einops
import gradio as gr
import huggingface_hub
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from blendmodes.blend import BlendType, blendLayers
from PIL import Image
from pytorch_lightning import seed_everything
from safetensors.torch import load_file
from skimage import exposure

import src.import_util  # noqa: F401
from ControlNet.annotator.canny import CannyDetector
from ControlNet.annotator.hed import HEDdetector
from ControlNet.annotator.util import HWC3
from ControlNet.cldm.model import create_model, load_state_dict
from gmflow_module.gmflow.gmflow import GMFlow
from flow.flow_utils import get_warped_and_mask
from sd_model_cfg import model_dict
from src.config import RerenderConfig
from src.controller import AttentionControl
from src.ddim_v_hacked import DDIMVSampler
from src.img_util import find_flat_region, numpy2tensor
from src.video_util import (frame_to_video, get_fps, get_frame_count,
                            prepare_frames)
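
# Demo input videos, downloaded from the Hugging Face Hub into ./videos.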
REPO_NAME = 'Anonymous-sub/Rerender'

huggingface_hub.hf_hub_download(REPO_NAME,
                                'pexels-koolshooters-7322716.mp4',
                                local_dir='videos')
huggingface_hub.hf_hub_download(
    REPO_NAME,
    'pexels-antoni-shkraba-8048492-540x960-25fps.mp4',
    local_dir='videos')
huggingface_hub.hf_hub_download(
    REPO_NAME,
    'pexels-cottonbro-studio-6649832-960x506-25fps.mp4',
    local_dir='videos')
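
# model_dict (from sd_model_cfg.py) maps a model's display name to its
# checkpoint path. Invert it so a config file's checkpoint path can be mapped
# back to a dropdown entry. Illustrative shape only (the real entries live in
# sd_model_cfg.py):
#     model_dict = {'Stable Diffusion 1.5': '',
#                   'revAnimated_v11': 'revAnimated_v11.safetensors'}
# An empty path means no extra weights are loaded on top of the ControlNet
# SD 1.5 checkpoint (see GlobalState.update_sd_model).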
inversed_model_dict = {v: k for k, v in model_dict.items()}

to_tensor = T.PILToTensor()
blur = T.GaussianBlur(kernel_size=(9, 9), sigma=(18, 18))
device = 'cuda' if torch.cuda.is_available() else 'cpu'


class ProcessingState(Enum):
    NULL = 0
    FIRST_IMG = 1
    KEY_IMGS = 2


# To avoid overload, cap the number of key frames.
MAX_KEYFRAME = 8


class GlobalState:

    def __init__(self):
        self.sd_model = None
        self.ddim_v_sampler = None
        self.detector_type = None
        self.detector = None
        self.controller = None
        self.processing_state = ProcessingState.NULL
        # Populated by process1 and consumed by process2; None until then.
        self.color_corrections = None
        self.first_result = None
        self.first_img = None
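        # GMFlow optical-flow model, used to warp earlier stylized frames
        # onto the current frame for shape- and pixel-aware fusion.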
        flow_model = GMFlow(
            feature_channels=128,
            num_scales=1,
            upsample_factor=8,
            num_head=1,
            attention_type='swin',
            ffn_dim_expansion=4,
            num_transformer_layers=6,
        ).to(device)

        checkpoint = torch.load('models/gmflow_sintel-0c07dcb3.pth',
                                map_location=lambda storage, loc: storage)
        weights = checkpoint['model'] if 'model' in checkpoint else checkpoint
        flow_model.load_state_dict(weights, strict=False)
        flow_model.eval()
        self.flow_model = flow_model

    def update_controller(self, inner_strength, mask_period, cross_period,
                          ada_period, warp_period):
        self.controller = AttentionControl(inner_strength, mask_period,
                                           cross_period, ada_period,
                                           warp_period)

    def update_sd_model(self, sd_model, control_type):
        if sd_model == self.sd_model:
            return
        self.sd_model = sd_model
        model = create_model('./ControlNet/models/cldm_v15.yaml').cpu()
        if control_type == 'HED':
            model.load_state_dict(
                load_state_dict(huggingface_hub.hf_hub_download(
                    'lllyasviel/ControlNet', 'models/control_sd15_hed.pth'),
                                location=device))
        elif control_type == 'canny':
            model.load_state_dict(
                load_state_dict(huggingface_hub.hf_hub_download(
                    'lllyasviel/ControlNet', 'models/control_sd15_canny.pth'),
                                location=device))
        model.to(device)
        sd_model_path = model_dict[sd_model]
        if len(sd_model_path) > 0:
            # A 'user/repo' style model name is treated as its own Hub repo.
            repo_name = REPO_NAME
            if sd_model.count('/') == 1:
                repo_name = sd_model

            model_ext = os.path.splitext(sd_model_path)[1]
            downloaded_model = huggingface_hub.hf_hub_download(
                repo_name, sd_model_path)
            if model_ext == '.safetensors':
                model.load_state_dict(load_file(downloaded_model),
                                      strict=False)
            elif model_ext == '.ckpt' or model_ext == '.pth':
                model.load_state_dict(
                    torch.load(downloaded_model)['state_dict'], strict=False)

        try:
            vae_ckpt = huggingface_hub.hf_hub_download(
                'stabilityai/sd-vae-ft-mse-original',
                'vae-ft-mse-840000-ema-pruned.ckpt')
            model.first_stage_model.load_state_dict(
                torch.load(vae_ckpt)['state_dict'], strict=False)
        except Exception:
            print('Warning: we suggest you download the fine-tuned VAE, '
                  'otherwise the generation quality will be degraded.')

        self.ddim_v_sampler = DDIMVSampler(model)

    def clear_sd_model(self):
        self.sd_model = None
        self.ddim_v_sampler = None
        if device == 'cuda':
            torch.cuda.empty_cache()

    def update_detector(self, control_type, canny_low=100, canny_high=200):
        # Cache on (type, thresholds) so that changing the canny thresholds
        # also rebuilds the detector.
        if self.detector_type == (control_type, canny_low, canny_high):
            return
        self.detector_type = (control_type, canny_low, canny_high)
        if control_type == 'HED':
            self.detector = HEDdetector()
        elif control_type == 'canny':
            canny_detector = CannyDetector()
            low_threshold = canny_low
            high_threshold = canny_high

            def apply_canny(x):
                return canny_detector(x, low_threshold, high_threshold)

            self.detector = apply_canny
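

# Module-level state shared by the Gradio callbacks. The app runs with a
# single-worker queue (see block.queue at the bottom), so concurrent access
# is not expected.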
global_state = GlobalState()
global_video_path = None
video_frame_count = None


def create_cfg(input_path, prompt, image_resolution, control_strength,
               color_preserve, left_crop, right_crop, top_crop, bottom_crop,
               control_type, low_threshold, high_threshold, ddim_steps, scale,
               seed, sd_model, a_prompt, n_prompt, interval, keyframe_count,
               x0_strength, use_constraints, cross_start, cross_end,
               style_update_freq, warp_start, warp_end, mask_start, mask_end,
               ada_start, ada_end, mask_strength, inner_strength,
               smooth_boundary):
    use_warp = 'shape-aware fusion' in use_constraints
    use_mask = 'pixel-aware fusion' in use_constraints
    use_ada = 'color-aware AdaIN' in use_constraints
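
    # A period of (1, 0) is empty (start after end), so the corresponding
    # constraint is never active during sampling.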
    if not use_warp:
        warp_start = 1
        warp_end = 0

    if not use_mask:
        mask_start = 1
        mask_end = 0

    if not use_ada:
        ada_start = 1
        ada_end = 0

    input_name = os.path.split(input_path)[-1].split('.')[0]
    frame_count = 2 + keyframe_count * interval
    cfg = RerenderConfig()
    cfg.create_from_parameters(
        input_path,
        os.path.join('result', input_name, 'blend.mp4'),
        prompt,
        a_prompt=a_prompt,
        n_prompt=n_prompt,
        frame_count=frame_count,
        interval=interval,
        crop=[left_crop, right_crop, top_crop, bottom_crop],
        sd_model=sd_model,
        ddim_steps=ddim_steps,
        scale=scale,
        control_type=control_type,
        control_strength=control_strength,
        canny_low=low_threshold,
        canny_high=high_threshold,
        seed=seed,
        image_resolution=image_resolution,
        x0_strength=x0_strength,
        style_update_freq=style_update_freq,
        cross_period=(cross_start, cross_end),
        warp_period=(warp_start, warp_end),
        mask_period=(mask_start, mask_end),
        ada_period=(ada_start, ada_end),
        mask_strength=mask_strength,
        inner_strength=inner_strength,
        smooth_boundary=smooth_boundary,
        color_preserve=color_preserve)
    return cfg


def cfg_to_input(filename):
    cfg = RerenderConfig()
    cfg.create_from_path(filename)
    # Inverse of create_cfg's frame_count = 2 + keyframe_count * interval.
    keyframe_count = (cfg.frame_count - 2) // cfg.interval
    use_constraints = [
        'shape-aware fusion', 'pixel-aware fusion', 'color-aware AdaIN'
    ]

    sd_model = inversed_model_dict.get(cfg.sd_model, 'Stable Diffusion 1.5')

    args = [
        cfg.input_path, cfg.prompt, cfg.image_resolution, cfg.control_strength,
        cfg.color_preserve, *cfg.crop, cfg.control_type, cfg.canny_low,
        cfg.canny_high, cfg.ddim_steps, cfg.scale, cfg.seed, sd_model,
        cfg.a_prompt, cfg.n_prompt, cfg.interval, keyframe_count,
        cfg.x0_strength, use_constraints, *cfg.cross_period,
        cfg.style_update_freq, *cfg.warp_period, *cfg.mask_period,
        *cfg.ada_period, cfg.mask_strength, cfg.inner_strength,
        cfg.smooth_boundary
    ]
    return args


def setup_color_correction(image):
    correction_target = cv2.cvtColor(np.asarray(image.copy()),
                                     cv2.COLOR_RGB2LAB)
    return correction_target
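

# Histogram-match the frame to the LAB reference captured above, then apply a
# LUMINOSITY blend against the original to limit luminance shifts (the same
# color-correction scheme as in AUTOMATIC1111's stable-diffusion-webui).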
def apply_color_correction(correction, original_image):
    image = Image.fromarray(
        cv2.cvtColor(
            exposure.match_histograms(cv2.cvtColor(np.asarray(original_image),
                                                   cv2.COLOR_RGB2LAB),
                                      correction,
                                      channel_axis=2),
            cv2.COLOR_LAB2RGB).astype('uint8'))

    image = blendLayers(image, original_image, BlendType.LUMINOSITY)

    return image
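

# process/process1/process2 all take the `ips` argument list defined in the
# UI below; process0 additionally receives the video path as its first
# argument. process1 renders the first key frame, process2 the remaining ones.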
@torch.no_grad()
def process(*args):
    first_frame = process1(*args)
    keypath = process2(*args)
    return first_frame, keypath


@torch.no_grad()
def process0(*args):
    global global_video_path
    global_video_path = args[0]
    return process(*args[1:])


@torch.no_grad()
def process1(*args):
    global global_video_path
    cfg = create_cfg(global_video_path, *args)
    global global_state
    global_state.update_sd_model(cfg.sd_model, cfg.control_type)
    global_state.update_controller(cfg.inner_strength, cfg.mask_period,
                                   cfg.cross_period, cfg.ada_period,
                                   cfg.warp_period)
    global_state.update_detector(cfg.control_type, cfg.canny_low,
                                 cfg.canny_high)
    global_state.processing_state = ProcessingState.FIRST_IMG

    prepare_frames(cfg.input_path, cfg.input_dir, cfg.image_resolution,
                   cfg.crop)

    ddim_v_sampler = global_state.ddim_v_sampler
    model = ddim_v_sampler.model
    detector = global_state.detector
    controller = global_state.controller
    model.control_scales = [cfg.control_strength] * 13
    model.to(device)

    num_samples = 1
    eta = 0.0
    imgs = sorted(os.listdir(cfg.input_dir))
    imgs = [os.path.join(cfg.input_dir, img) for img in imgs]

    model.cond_stage_model.device = device

    with torch.no_grad():
        frame = cv2.imread(imgs[0])
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        img = HWC3(frame)
        H, W, C = img.shape

        img_ = numpy2tensor(img)
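
        # Encode the frame into latent x0 and run ControlNet-guided DDIM
        # sampling from it; `strength` sets the SDEdit-style x0 influence
        # (a negative value appears to disable it in DDIMVSampler).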
        def generate_first_img(img_, strength):
            encoder_posterior = model.encode_first_stage(img_.to(device))
            x0 = model.get_first_stage_encoding(encoder_posterior).detach()

            detected_map = detector(img)
            detected_map = HWC3(detected_map)

            control = torch.from_numpy(
                detected_map.copy()).float().to(device) / 255.0
            control = torch.stack([control for _ in range(num_samples)],
                                  dim=0)
            control = einops.rearrange(control, 'b h w c -> b c h w').clone()
            cond = {
                'c_concat': [control],
                'c_crossattn': [
                    model.get_learned_conditioning(
                        [cfg.prompt + ', ' + cfg.a_prompt] * num_samples)
                ]
            }
            un_cond = {
                'c_concat': [control],
                'c_crossattn':
                [model.get_learned_conditioning([cfg.n_prompt] * num_samples)]
            }
            shape = (4, H // 8, W // 8)

            controller.set_task('initfirst')
            seed_everything(cfg.seed)

            samples, _ = ddim_v_sampler.sample(
                cfg.ddim_steps,
                num_samples,
                shape,
                cond,
                verbose=False,
                eta=eta,
                unconditional_guidance_scale=cfg.scale,
                unconditional_conditioning=un_cond,
                controller=controller,
                x0=x0,
                strength=strength)
            x_samples = model.decode_first_stage(samples)
            x_samples_np = (
                einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
                127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
            return x_samples, x_samples_np

        if cfg.color_preserve:
            first_strength = 1 - cfg.x0_strength
        else:
            first_strength = -1

        x_samples, x_samples_np = generate_first_img(img_, first_strength)

        if not cfg.color_preserve:
            # Without color preservation, build a color-correction reference
            # from the first stylized frame, then re-render from the
            # color-corrected input so later frames stay color-consistent.
            color_corrections = setup_color_correction(
                Image.fromarray(x_samples_np[0]))
            global_state.color_corrections = color_corrections
            img_ = apply_color_correction(color_corrections,
                                          Image.fromarray(img))
            img_ = to_tensor(img_).unsqueeze(0)[:, :3] / 127.5 - 1
            x_samples, x_samples_np = generate_first_img(
                img_, 1 - cfg.x0_strength)

        global_state.first_result = x_samples
        global_state.first_img = img

    Image.fromarray(x_samples_np[0]).save(
        os.path.join(cfg.first_dir, 'first.jpg'))

    return x_samples_np[0]


@torch.no_grad()
def process2(*args):
    global global_state
    global global_video_path

    if global_state.processing_state != ProcessingState.FIRST_IMG:
        raise gr.Error('Please generate the first key image before generating'
                       ' all key images')

    cfg = create_cfg(global_video_path, *args)
    global_state.update_sd_model(cfg.sd_model, cfg.control_type)
    global_state.update_detector(cfg.control_type, cfg.canny_low,
                                 cfg.canny_high)
    global_state.processing_state = ProcessingState.KEY_IMGS

    # Reset the key-frame directory.
    shutil.rmtree(cfg.key_dir, ignore_errors=True)
    os.makedirs(cfg.key_dir, exist_ok=True)

    ddim_v_sampler = global_state.ddim_v_sampler
    model = ddim_v_sampler.model
    detector = global_state.detector
    controller = global_state.controller
    flow_model = global_state.flow_model
    model.control_scales = [cfg.control_strength] * 13

    num_samples = 1
    eta = 0.0
    firstx0 = True
    pixelfusion = cfg.use_mask
    imgs = sorted(os.listdir(cfg.input_dir))
    imgs = [os.path.join(cfg.input_dir, img) for img in imgs]

    first_result = global_state.first_result
    first_img = global_state.first_img
    pre_result = first_result
    pre_img = first_img
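
    # Translate one key frame every `interval` source frames; `cid` is the
    # index of the frame being translated.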
    for i in range(0, cfg.frame_count - 1, cfg.interval):
        cid = i + 1
        frame = cv2.imread(imgs[i + 1])
        print(cid)
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        img = HWC3(frame)
        H, W, C = img.shape

        if cfg.color_preserve or global_state.color_corrections is None:
            img_ = numpy2tensor(img)
        else:
            img_ = apply_color_correction(global_state.color_corrections,
                                          Image.fromarray(img))
            img_ = to_tensor(img_).unsqueeze(0)[:, :3] / 127.5 - 1
        encoder_posterior = model.encode_first_stage(img_.to(device))
        x0 = model.get_first_stage_encoding(encoder_posterior).detach()

        detected_map = detector(img)
        detected_map = HWC3(detected_map)

        control = torch.from_numpy(
            detected_map.copy()).float().to(device) / 255.0
        control = torch.stack([control for _ in range(num_samples)], dim=0)
        control = einops.rearrange(control, 'b h w c -> b c h w').clone()
        cond = {
            'c_concat': [control],
            'c_crossattn': [
                model.get_learned_conditioning(
                    [cfg.prompt + ', ' + cfg.a_prompt] * num_samples)
            ]
        }
        un_cond = {
            'c_concat': [control],
            'c_crossattn':
            [model.get_learned_conditioning([cfg.n_prompt] * num_samples)]
        }
        shape = (4, H // 8, W // 8)
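
        # Estimate optical flow from the previous / first frame to the current
        # one, warp their stylized results accordingly, and derive soft
        # occlusion masks (dilated with max_pool2d, softened by Gaussian blur).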
        image1 = torch.from_numpy(pre_img).permute(2, 0, 1).float()
        image2 = torch.from_numpy(img).permute(2, 0, 1).float()
        warped_pre, bwd_occ_pre, bwd_flow_pre = get_warped_and_mask(
            flow_model, image1, image2, pre_result, False)
        blend_mask_pre = blur(
            F.max_pool2d(bwd_occ_pre, kernel_size=9, stride=1, padding=4))
        blend_mask_pre = torch.clamp(blend_mask_pre + bwd_occ_pre, 0, 1)

        image1 = torch.from_numpy(first_img).permute(2, 0, 1).float()
        warped_0, bwd_occ_0, bwd_flow_0 = get_warped_and_mask(
            flow_model, image1, image2, first_result, False)
        blend_mask_0 = blur(
            F.max_pool2d(bwd_occ_0, kernel_size=9, stride=1, padding=4))
        blend_mask_0 = torch.clamp(blend_mask_0 + bwd_occ_0, 0, 1)

        # Downscale the flow and mask by 8 to match the latent resolution.
        if firstx0:
            mask = 1 - F.max_pool2d(blend_mask_0, kernel_size=8)
            controller.set_warp(
                F.interpolate(bwd_flow_0 / 8.0,
                              scale_factor=1. / 8,
                              mode='bilinear'), mask)
        else:
            mask = 1 - F.max_pool2d(blend_mask_pre, kernel_size=8)
            controller.set_warp(
                F.interpolate(bwd_flow_pre / 8.0,
                              scale_factor=1. / 8,
                              mode='bilinear'), mask)

        controller.set_task('keepx0, keepstyle')
        seed_everything(cfg.seed)
        samples, _ = ddim_v_sampler.sample(
            cfg.ddim_steps,
            num_samples,
            shape,
            cond,
            verbose=False,
            eta=eta,
            unconditional_guidance_scale=cfg.scale,
            unconditional_conditioning=un_cond,
            controller=controller,
            x0=x0,
            strength=1 - cfg.x0_strength)
        direct_result = model.decode_first_stage(samples)
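
        # Without pixel-aware fusion, the directly generated frame is the
        # result; otherwise it is fused with the flow-warped earlier results.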
        if not pixelfusion:
            pre_result = direct_result
            pre_img = img
            viz = (
                einops.rearrange(direct_result, 'b c h w -> b h w c') * 127.5 +
                127.5).cpu().numpy().clip(0, 255).astype(np.uint8)

        else:
            # Pixel-aware fusion: keep the warped previous / first results
            # where the flow is reliable and fall back to the direct
            # generation in occluded regions.
            blend_results = (1 - blend_mask_pre
                             ) * warped_pre + blend_mask_pre * direct_result
            blend_results = (
                1 - blend_mask_0) * warped_0 + blend_mask_0 * blend_results

            bwd_occ = 1 - torch.clamp(1 - bwd_occ_pre + 1 - bwd_occ_0, 0, 1)
            blend_mask = blur(
                F.max_pool2d(bwd_occ, kernel_size=9, stride=1, padding=4))
            blend_mask = 1 - torch.clamp(blend_mask + bwd_occ, 0, 1)

            # Compensate the autoencoder's reconstruction error: re-encode the
            # decoded blend and amplify the residual; mask_x marks latent
            # regions where this compensation is unreliable.
            encoder_posterior = model.encode_first_stage(blend_results)
            xtrg = model.get_first_stage_encoding(encoder_posterior).detach()
            blend_results_rec = model.decode_first_stage(xtrg)
            encoder_posterior = model.encode_first_stage(blend_results_rec)
            xtrg_rec = model.get_first_stage_encoding(
                encoder_posterior).detach()
            xtrg_ = xtrg + (xtrg - xtrg_rec)
            blend_results_rec_new = model.decode_first_stage(xtrg_)
            tmp = (abs(blend_results_rec_new - blend_results).mean(
                dim=1, keepdims=True) > 0.25).float()
            mask_x = F.max_pool2d((F.interpolate(tmp,
                                                 scale_factor=1 / 8.,
                                                 mode='bilinear') > 0).float(),
                                  kernel_size=3,
                                  stride=1,
                                  padding=1)

            mask = 1 - F.max_pool2d(1 - blend_mask, kernel_size=8)

            if cfg.smooth_boundary:
                noise_rescale = find_flat_region(mask)
            else:
                noise_rescale = torch.ones_like(mask)

            # The latent mask is active only within the pixel-aware fusion
            # period, e.g. with ddim_steps=20 and mask_period=(0.5, 0.8) it
            # is applied at steps 11-15.
            masks = []
            for step in range(cfg.ddim_steps):
                if step <= cfg.ddim_steps * cfg.mask_period[
                        0] or step >= cfg.ddim_steps * cfg.mask_period[1]:
                    masks += [None]
                else:
                    masks += [mask * cfg.mask_strength]

            xtrg = (xtrg + (1 - mask_x) * (xtrg - xtrg_rec)) * mask

            tasks = 'keepstyle, keepx0'
            if not firstx0:
                tasks += ', updatex0'
            if i % cfg.style_update_freq == 0:
                tasks += ', updatestyle'
            controller.set_task(tasks, 1.0)

            seed_everything(cfg.seed)
            samples, _ = ddim_v_sampler.sample(
                cfg.ddim_steps,
                num_samples,
                shape,
                cond,
                verbose=False,
                eta=eta,
                unconditional_guidance_scale=cfg.scale,
                unconditional_conditioning=un_cond,
                controller=controller,
                x0=x0,
                strength=1 - cfg.x0_strength,
                xtrg=xtrg,
                mask=masks,
                noise_rescale=noise_rescale)
            x_samples = model.decode_first_stage(samples)
            pre_result = x_samples
            pre_img = img

            viz = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
                   127.5).cpu().numpy().clip(0, 255).astype(np.uint8)

        Image.fromarray(viz[0]).save(
            os.path.join(cfg.key_dir, f'{cid:04d}.png'))

    key_video_path = os.path.join(cfg.work_dir, 'key.mp4')
    fps = get_fps(cfg.input_path)
    # The key video contains one frame per `interval` source frames,
    # so scale the fps down to match.
    fps //= cfg.interval
    frame_to_video(key_video_path, cfg.key_dir, fps, False)

    return key_video_path


DESCRIPTION = '''
## Rerender A Video
### This space performs key frame translation only. Full code for full video translation will be released upon the publication of the paper.
### To avoid overload, we limit the **maximum number of key frames** to 8 and the maximum frame resolution to 512x768.
### Processing a 512x640 video takes about 1 minute per key frame on a T4 GPU.
### How to use:
1. **Run 1st Key Frame**: translate only the first frame, so you can adjust the prompts/models/parameters to find your ideal output appearance before running the whole video.
2. **Run Key Frames**: translate all the key frames based on the settings of the first frame.
3. **Run All**: **Run 1st Key Frame** and **Run Key Frames**.
4. **Run Propagation**: propagate the key frames to the other frames for full video translation. This part will be released upon the publication of the paper.
### Tips:
1. This method cannot handle large or fast motions for which optical flow is hard to estimate. **Videos with stable motions are preferred**.
2. Pixel-aware fusion may not work for large or fast motions.
3. Try different color-aware AdaIN settings, or disable it altogether, to avoid color jittering.
4. Use the `revAnimated_v11` model for non-photorealistic styles and the `realisticVisionV20_v20` model for photorealistic styles.
5. To use your own SD/LoRA model, you may clone the space and specify your model in [sd_model_cfg.py](https://huggingface.co/spaces/Anonymous-sub/Rerender/blob/main/sd_model_cfg.py).
6. This method is based on the original SD model. You may need to [convert](https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py) Diffusers/Automatic1111 models to the original format.

**This code is for research purposes and non-commercial use only.**

[![Duplicate this Space](https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-sm-dark.svg)](https://huggingface.co/spaces/Anonymous-sub/Rerender?duplicate=true) for no queue on your own hardware.
'''


block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            input_path = gr.Video(label='Input Video',
                                  source='upload',
                                  format='mp4',
                                  visible=True)
            prompt = gr.Textbox(label='Prompt')
            seed = gr.Slider(label='Seed',
                             minimum=0,
                             maximum=2147483647,
                             step=1,
                             value=0,
                             randomize=True)
            run_button = gr.Button(value='Run All')
            with gr.Row():
                run_button1 = gr.Button(value='Run 1st Key Frame')
                run_button2 = gr.Button(value='Run Key Frames')
                run_button3 = gr.Button(value='Run Propagation')
            with gr.Accordion('Advanced options for the 1st frame translation',
                              open=False):
                image_resolution = gr.Slider(
                    label='Frame resolution',
                    minimum=256,
                    maximum=512,
                    value=512,
                    step=64,
                    info='To avoid overload, maximum 512')
                control_strength = gr.Slider(label='ControlNet strength',
                                             minimum=0.0,
                                             maximum=2.0,
                                             value=1.0,
                                             step=0.01)
                x0_strength = gr.Slider(
                    label='Denoising strength',
                    minimum=0.00,
                    maximum=1.05,
                    value=0.75,
                    step=0.05,
                    info=('0: fully recover the input. '
                          '1.05: fully rerender the input.'))
                color_preserve = gr.Checkbox(
                    label='Preserve color',
                    value=True,
                    info='Keep the color of the input video')
                with gr.Row():
                    left_crop = gr.Slider(label='Left crop length',
                                          minimum=0,
                                          maximum=512,
                                          value=0,
                                          step=1)
                    right_crop = gr.Slider(label='Right crop length',
                                           minimum=0,
                                           maximum=512,
                                           value=0,
                                           step=1)
                with gr.Row():
                    top_crop = gr.Slider(label='Top crop length',
                                         minimum=0,
                                         maximum=512,
                                         value=0,
                                         step=1)
                    bottom_crop = gr.Slider(label='Bottom crop length',
                                            minimum=0,
                                            maximum=512,
                                            value=0,
                                            step=1)
                with gr.Row():
                    control_type = gr.Dropdown(['HED', 'canny'],
                                               label='Control type',
                                               value='HED')
                    low_threshold = gr.Slider(label='Canny low threshold',
                                              minimum=1,
                                              maximum=255,
                                              value=100,
                                              step=1)
                    high_threshold = gr.Slider(label='Canny high threshold',
                                               minimum=1,
                                               maximum=255,
                                               value=200,
                                               step=1)
                ddim_steps = gr.Slider(label='Steps',
                                       minimum=1,
                                       maximum=20,
                                       value=20,
                                       step=1,
                                       info='To avoid overload, maximum 20')
                scale = gr.Slider(label='CFG scale',
                                  minimum=0.1,
                                  maximum=30.0,
                                  value=7.5,
                                  step=0.1)
                sd_model_list = list(model_dict.keys())
                sd_model = gr.Dropdown(sd_model_list,
                                       label='Base model',
                                       value='Stable Diffusion 1.5')
                a_prompt = gr.Textbox(label='Added prompt',
                                      value='best quality, extremely detailed')
                n_prompt = gr.Textbox(
                    label='Negative prompt',
                    value=('longbody, lowres, bad anatomy, bad hands, '
                           'missing fingers, extra digit, fewer digits, '
                           'cropped, worst quality, low quality'))
            with gr.Accordion('Advanced options for the key frame translation',
                              open=False):
                # The slider maximums below are placeholders; they are updated
                # once a video is uploaded (see input_uploaded below).
                interval = gr.Slider(
                    label='Key frame frequency (K)',
                    minimum=1,
                    maximum=1,
                    value=1,
                    step=1,
                    info='Uniformly sample the key frames every K frames')
                keyframe_count = gr.Slider(
                    label='Number of key frames',
                    minimum=1,
                    maximum=1,
                    value=1,
                    step=1,
                    info='To avoid overload, maximum 8 key frames')

                use_constraints = gr.CheckboxGroup(
                    [
                        'shape-aware fusion', 'pixel-aware fusion',
                        'color-aware AdaIN'
                    ],
                    label='Select the cross-frame constraints to be used',
                    value=[
                        'shape-aware fusion', 'pixel-aware fusion',
                        'color-aware AdaIN'
                    ])
                with gr.Row():
                    cross_start = gr.Slider(
                        label='Cross-frame attention start',
                        minimum=0,
                        maximum=1,
                        value=0,
                        step=0.05)
                    cross_end = gr.Slider(label='Cross-frame attention end',
                                          minimum=0,
                                          maximum=1,
                                          value=1,
                                          step=0.05)
                style_update_freq = gr.Slider(
                    label='Cross-frame attention update frequency',
                    minimum=1,
                    maximum=100,
                    value=1,
                    step=1,
                    info=('Update the key and value for cross-frame '
                          'attention every N key frames (recommend '
                          'N*K >= 10)'))
                with gr.Row():
                    warp_start = gr.Slider(label='Shape-aware fusion start',
                                           minimum=0,
                                           maximum=1,
                                           value=0,
                                           step=0.05)
                    warp_end = gr.Slider(label='Shape-aware fusion end',
                                         minimum=0,
                                         maximum=1,
                                         value=0.1,
                                         step=0.05)
                with gr.Row():
                    mask_start = gr.Slider(label='Pixel-aware fusion start',
                                           minimum=0,
                                           maximum=1,
                                           value=0.5,
                                           step=0.05)
                    mask_end = gr.Slider(label='Pixel-aware fusion end',
                                         minimum=0,
                                         maximum=1,
                                         value=0.8,
                                         step=0.05)
                with gr.Row():
                    ada_start = gr.Slider(label='Color-aware AdaIN start',
                                          minimum=0,
                                          maximum=1,
                                          value=0.8,
                                          step=0.05)
                    ada_end = gr.Slider(label='Color-aware AdaIN end',
                                        minimum=0,
                                        maximum=1,
                                        value=1,
                                        step=0.05)
                mask_strength = gr.Slider(label='Pixel-aware fusion strength',
                                          minimum=0,
                                          maximum=1,
                                          value=0.5,
                                          step=0.01)
                inner_strength = gr.Slider(
                    label='Pixel-aware fusion detail level',
                    minimum=0.5,
                    maximum=1,
                    value=0.9,
                    step=0.01,
                    info='Use a low value to prevent artifacts')
                smooth_boundary = gr.Checkbox(
                    label='Smooth fusion boundary',
                    value=True,
                    info='Select to prevent artifacts at the boundary')

            with gr.Accordion('Example configs', open=True):
                config_dir = 'config'
                config_list = os.listdir(config_dir)
                args_list = []
                for config in config_list:
                    try:
                        config_path = os.path.join(config_dir, config)
                        args = cfg_to_input(config_path)
                        args_list.append(args)
                    except FileNotFoundError:
                        # Skip configs that reference missing files.
                        pass

            ips = [
                prompt, image_resolution, control_strength, color_preserve,
                left_crop, right_crop, top_crop, bottom_crop, control_type,
                low_threshold, high_threshold, ddim_steps, scale, seed,
                sd_model, a_prompt, n_prompt, interval, keyframe_count,
                x0_strength, use_constraints, cross_start, cross_end,
                style_update_freq, warp_start, warp_end, mask_start,
                mask_end, ada_start, ada_end, mask_strength,
                inner_strength, smooth_boundary
            ]

        with gr.Column():
            result_image = gr.Image(label='Output first frame',
                                    type='numpy',
                                    interactive=False)
            result_keyframe = gr.Video(label='Output key frame video',
                                       format='mp4',
                                       interactive=False)
    with gr.Row():
        gr.Examples(examples=args_list,
                    inputs=[input_path, *ips],
                    fn=process0,
                    outputs=[result_image, result_keyframe],
                    cache_examples=True)
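
    # Keep the interval / key frame sliders consistent with the uploaded
    # video's frame count (capped at MAX_KEYFRAME).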
    def input_uploaded(path):
        frame_count = get_frame_count(path)
        if frame_count <= 2:
            raise gr.Error('The input video is too short! '
                           'Please input another video.')

        default_interval = min(10, frame_count - 2)
        max_keyframe = min((frame_count - 2) // default_interval, MAX_KEYFRAME)

        global video_frame_count
        video_frame_count = frame_count
        global global_video_path
        global_video_path = path

        return gr.Slider.update(value=default_interval,
                                maximum=frame_count - 2), gr.Slider.update(
                                    value=max_keyframe, maximum=max_keyframe)

    def input_changed(path):
        frame_count = get_frame_count(path)
        if frame_count <= 2:
            return gr.Slider.update(maximum=1), gr.Slider.update(maximum=1)

        default_interval = min(10, frame_count - 2)
        max_keyframe = min((frame_count - 2) // default_interval, MAX_KEYFRAME)

        global video_frame_count
        video_frame_count = frame_count
        global global_video_path
        global_video_path = path

        return gr.Slider.update(value=default_interval,
                                maximum=frame_count - 2), \
            gr.Slider.update(maximum=max_keyframe)

    def interval_changed(interval):
        global video_frame_count
        if video_frame_count is None:
            return gr.Slider.update()

        max_keyframe = min((video_frame_count - 2) // interval, MAX_KEYFRAME)
        return gr.Slider.update(value=max_keyframe, maximum=max_keyframe)

    input_path.change(input_changed, input_path, [interval, keyframe_count])
    input_path.upload(input_uploaded, input_path, [interval, keyframe_count])
    interval.change(interval_changed, interval, keyframe_count)

    run_button.click(fn=process,
                     inputs=ips,
                     outputs=[result_image, result_keyframe])
    run_button1.click(fn=process1, inputs=ips, outputs=[result_image])
    run_button2.click(fn=process2, inputs=ips, outputs=[result_keyframe])

    def process3():
        raise gr.Error('Coming soon. Full code for full video translation '
                       'will be released upon the publication of the paper.')

    run_button3.click(fn=process3, outputs=[result_keyframe])

block.queue(concurrency_count=1, max_size=20)
block.launch(server_name='0.0.0.0')