# Rerender / app.py
# Commit 1239b39 by patgpt4: "Duplicate from Anonymous-sub/Rerender"
import os
import shutil
from enum import Enum
import cv2
import einops
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from blendmodes.blend import BlendType, blendLayers
from PIL import Image
from pytorch_lightning import seed_everything
from safetensors.torch import load_file
from skimage import exposure
import src.import_util # noqa: F401
from ControlNet.annotator.canny import CannyDetector
from ControlNet.annotator.hed import HEDdetector
from ControlNet.annotator.util import HWC3
from ControlNet.cldm.model import create_model, load_state_dict
from gmflow_module.gmflow.gmflow import GMFlow
from flow.flow_utils import get_warped_and_mask
from sd_model_cfg import model_dict
from src.config import RerenderConfig
from src.controller import AttentionControl
from src.ddim_v_hacked import DDIMVSampler
from src.img_util import find_flat_region, numpy2tensor
from src.video_util import (frame_to_video, get_fps, get_frame_count,
prepare_frames)
import huggingface_hub
REPO_NAME = 'Anonymous-sub/Rerender'

# Fetch the sample clips from the Space repository into ./videos as soon as
# the module is imported.
for _sample_video in ('pexels-koolshooters-7322716.mp4',
                      'pexels-antoni-shkraba-8048492-540x960-25fps.mp4',
                      'pexels-cottonbro-studio-6649832-960x506-25fps.mp4'):
    huggingface_hub.hf_hub_download(REPO_NAME,
                                    _sample_video,
                                    local_dir='videos')
del _sample_video
# Reverse lookup of model_dict: model path -> display name (used to map a
# config's sd_model back to the dropdown entry).
inversed_model_dict = {path: name for name, path in model_dict.items()}

# PIL image -> uint8 tensor (C, H, W).
to_tensor = T.PILToTensor()
# Heavy Gaussian blur used to soften occlusion masks before blending.
blur = T.GaussianBlur(kernel_size=(9, 9), sigma=(18, 18))
device = 'cuda' if torch.cuda.is_available() else 'cpu'
class ProcessingState(Enum):
    """How far the two-stage pipeline has progressed.

    NULL: nothing usable yet; FIRST_IMG: the first frame has been rendered;
    KEY_IMGS: the key-frame pass has been started.
    """

    NULL = 0
    FIRST_IMG = 1
    KEY_IMGS = 2


# Upper bound on the number of key frames offered by the UI sliders.
MAX_KEYFRAME = 8
class GlobalState:
    """Process-wide state shared by the Gradio callbacks.

    Holds the loaded ControlNet/Stable Diffusion sampler, the attention
    controller, the edge detector, the GMFlow optical-flow model and the
    first-frame results that the key-frame pass reuses.
    """

    def __init__(self):
        # Display name of the base model currently loaded in ddim_v_sampler.
        self.sd_model = None
        # Control type whose ControlNet weights are loaded in ddim_v_sampler.
        self.control_type = None
        self.ddim_v_sampler = None
        # Cache key of self.detector: ('HED',) or ('canny', low, high).
        self.detector_type = None
        self.detector = None
        self.controller = None
        self.processing_state = ProcessingState.NULL
        # Written by the first-frame pass and read by the key-frame pass.
        self.color_corrections = None
        self.first_result = None
        self.first_img = None

        flow_model = GMFlow(
            feature_channels=128,
            num_scales=1,
            upsample_factor=8,
            num_head=1,
            attention_type='swin',
            ffn_dim_expansion=4,
            num_transformer_layers=6,
        ).to(device)
        # Deserialize onto CPU; load_state_dict then copies into the model's
        # parameters on whatever device they live on.
        checkpoint = torch.load('models/gmflow_sintel-0c07dcb3.pth',
                                map_location=lambda storage, loc: storage)
        weights = checkpoint['model'] if 'model' in checkpoint else checkpoint
        flow_model.load_state_dict(weights, strict=False)
        flow_model.eval()
        self.flow_model = flow_model

    def update_controller(self, inner_strength, mask_period, cross_period,
                          ada_period, warp_period):
        """Replace the attention controller with a freshly configured one."""
        self.controller = AttentionControl(inner_strength, mask_period,
                                           cross_period, ada_period,
                                           warp_period)

    def update_sd_model(self, sd_model, control_type):
        """Build the ControlNet + Stable Diffusion sampler if needed.

        The sampler is rebuilt whenever the base model OR the control type
        changes, because both decide which weights are loaded.

        Args:
            sd_model: display name (a key of ``model_dict``). If it contains
                exactly one '/', it is used as the Hub repo id to download
                the weights from; otherwise ``REPO_NAME`` is used. An empty
                path in ``model_dict`` loads no extra base weights.
            control_type: 'HED' or 'canny' selects the ControlNet checkpoint.
        """
        if sd_model == self.sd_model and control_type == self.control_type:
            return
        # Free the previous model first so two full ControlNet models are
        # never resident at the same time.
        self.clear_sd_model()

        model = create_model('./ControlNet/models/cldm_v15.yaml').cpu()
        if control_type == 'HED':
            model.load_state_dict(
                load_state_dict(huggingface_hub.hf_hub_download(
                    'lllyasviel/ControlNet', './models/control_sd15_hed.pth'),
                                location=device))
        elif control_type == 'canny':
            model.load_state_dict(
                load_state_dict(huggingface_hub.hf_hub_download(
                    'lllyasviel/ControlNet', 'models/control_sd15_canny.pth'),
                                location=device))
        model.to(device)

        sd_model_path = model_dict[sd_model]
        if len(sd_model_path) > 0:
            # A 'user/repo' style name is its own Hub repo id.
            repo_name = sd_model if sd_model.count('/') == 1 else REPO_NAME
            model_ext = os.path.splitext(sd_model_path)[1]
            downloaded_model = huggingface_hub.hf_hub_download(
                repo_name, sd_model_path)
            if model_ext == '.safetensors':
                model.load_state_dict(load_file(downloaded_model),
                                      strict=False)
            elif model_ext in ('.ckpt', '.pth'):
                # map_location='cpu' lets CUDA-saved checkpoints load on
                # CPU-only hosts; some checkpoints are a bare state dict.
                ckpt = torch.load(downloaded_model, map_location='cpu')
                model.load_state_dict(ckpt.get('state_dict', ckpt),
                                      strict=False)
        try:
            vae_ckpt = torch.load(huggingface_hub.hf_hub_download(
                'stabilityai/sd-vae-ft-mse-original',
                'vae-ft-mse-840000-ema-pruned.ckpt'),
                                  map_location='cpu')
            model.first_stage_model.load_state_dict(
                vae_ckpt.get('state_dict', vae_ckpt), strict=False)
        except Exception:
            # Best effort: the app still works with the default VAE.
            print('Warning: We suggest you download the fine-tuned VAE',
                  'otherwise the generation quality will be degraded')
        self.ddim_v_sampler = DDIMVSampler(model)
        # Record what is loaded only after everything succeeded, so a failed
        # download is retried on the next call instead of being "cached".
        self.sd_model = sd_model
        self.control_type = control_type

    def clear_sd_model(self):
        """Drop the sampler and forget which model/control type is loaded."""
        self.sd_model = None
        self.control_type = None
        self.ddim_v_sampler = None
        if device == 'cuda':
            torch.cuda.empty_cache()

    def update_detector(self, control_type, canny_low=100, canny_high=200):
        """Create the edge detector for ``control_type`` unless cached.

        The HED detector is cached per type. The canny detector is also keyed
        on its thresholds so that changed thresholds take effect. Unknown
        control types leave the current detector untouched.
        """
        if control_type == 'HED':
            key = ('HED', )
        elif control_type == 'canny':
            key = ('canny', canny_low, canny_high)
        else:
            return
        if self.detector_type == key:
            return
        if control_type == 'HED':
            self.detector = HEDdetector()
        else:
            canny_detector = CannyDetector()

            def apply_canny(x):
                return canny_detector(x, canny_low, canny_high)

            self.detector = apply_canny
        self.detector_type = key
# Single shared instance holding models and intermediate results.
global_state = GlobalState()
# Path of the selected input video and its frame count; set by the input
# video handlers (and, for the path, by process0).
global_video_path = video_frame_count = None
def create_cfg(input_path, prompt, image_resolution, control_strength,
               color_preserve, left_crop, right_crop, top_crop, bottom_crop,
               control_type, low_threshold, high_threshold, ddim_steps, scale,
               seed, sd_model, a_prompt, n_prompt, interval, keyframe_count,
               x0_strength, use_constraints, cross_start, cross_end,
               style_update_freq, warp_start, warp_end, mask_start, mask_end,
               ada_start, ada_end, mask_strength, inner_strength,
               smooth_boundary):
    """Build a ``RerenderConfig`` from the flat list of UI values.

    The positional order is the input video path followed by the UI inputs;
    it matches the list produced by ``cfg_to_input``.

    ``use_constraints`` lists the checked cross-frame constraints. An
    unchecked constraint gets the period ``(1, 0)`` (start after end), which
    presumably disables it downstream — TODO confirm for warp/AdaIN.
    The output video path is ``result/<video name>/blend.mp4``.

    Raises:
        gr.Error: if no input video has been selected.
    """
    if input_path is None:
        raise gr.Error('Please upload an input video first.')
    constraints = use_constraints or []
    use_warp = 'shape-aware fusion' in constraints
    use_mask = 'pixel-aware fusion' in constraints
    use_ada = 'color-aware AdaIN' in constraints

    if not use_warp:
        warp_start = 1
        warp_end = 0
    if not use_mask:
        mask_start = 1
        mask_end = 0
    if not use_ada:
        ada_start = 1
        ada_end = 0

    # splitext keeps inner dots ('a.b.mp4' -> 'a.b'), so differently named
    # videos do not collide in the same result directory.
    input_name = os.path.splitext(os.path.basename(input_path))[0]
    # The first frame plus `keyframe_count` key frames `interval` apart.
    frame_count = 2 + keyframe_count * interval
    cfg = RerenderConfig()
    cfg.create_from_parameters(
        input_path,
        os.path.join('result', input_name, 'blend.mp4'),
        prompt,
        a_prompt=a_prompt,
        n_prompt=n_prompt,
        frame_count=frame_count,
        interval=interval,
        crop=[left_crop, right_crop, top_crop, bottom_crop],
        sd_model=sd_model,
        ddim_steps=ddim_steps,
        scale=scale,
        control_type=control_type,
        control_strength=control_strength,
        canny_low=low_threshold,
        canny_high=high_threshold,
        seed=seed,
        image_resolution=image_resolution,
        x0_strength=x0_strength,
        style_update_freq=style_update_freq,
        cross_period=(cross_start, cross_end),
        warp_period=(warp_start, warp_end),
        mask_period=(mask_start, mask_end),
        ada_period=(ada_start, ada_end),
        mask_strength=mask_strength,
        inner_strength=inner_strength,
        smooth_boundary=smooth_boundary,
        color_preserve=color_preserve)
    return cfg
def cfg_to_input(filename):
    """Load a config file and flatten it into ``create_cfg``'s argument list.

    The list starts with the input video path, as the example gallery
    expects. All three cross-frame constraints are always reported as
    enabled, and an unknown ``sd_model`` falls back to
    'Stable Diffusion 1.5'.
    """
    config = RerenderConfig()
    config.create_from_path(filename)
    n_keyframes = (config.frame_count - 2) // config.interval
    all_constraints = [
        'shape-aware fusion', 'pixel-aware fusion', 'color-aware AdaIN'
    ]
    model_name = inversed_model_dict.get(config.sd_model,
                                         'Stable Diffusion 1.5')

    flat = [
        config.input_path, config.prompt, config.image_resolution,
        config.control_strength, config.color_preserve
    ]
    flat.extend(config.crop)
    flat += [
        config.control_type, config.canny_low, config.canny_high,
        config.ddim_steps, config.scale, config.seed, model_name,
        config.a_prompt, config.n_prompt, config.interval, n_keyframes,
        config.x0_strength, all_constraints
    ]
    flat.extend(config.cross_period)
    flat.append(config.style_update_freq)
    for period in (config.warp_period, config.mask_period,
                   config.ada_period):
        flat.extend(period)
    flat += [
        config.mask_strength, config.inner_strength, config.smooth_boundary
    ]
    return flat
def setup_color_correction(image):
    """Return a LAB copy of an RGB reference image for histogram matching."""
    rgb = np.asarray(image.copy())
    return cv2.cvtColor(rgb, cv2.COLOR_RGB2LAB)
def apply_color_correction(correction, original_image):
    """Recolor ``original_image`` toward the LAB histogram ``correction``.

    The LAB channels are histogram-matched to ``correction``, converted back
    to RGB, and then luminosity-blended with the original image.
    """
    lab = cv2.cvtColor(np.asarray(original_image), cv2.COLOR_RGB2LAB)
    matched = exposure.match_histograms(lab, correction, channel_axis=2)
    rgb = cv2.cvtColor(matched, cv2.COLOR_LAB2RGB).astype('uint8')
    recolored = Image.fromarray(rgb)
    return blendLayers(recolored, original_image, BlendType.LUMINOSITY)
@torch.no_grad()
def process(*args):
    """Run both stages ("Run All"): the first frame, then all key frames.

    Returns:
        ``(first_frame, key_video_path)`` from ``process1`` / ``process2``.
    """
    first_frame = process1(*args)
    return first_frame, process2(*args)
@torch.no_grad()
def process0(*args):
    """Example-gallery entry point whose ``args[0]`` is the input video path.

    Stores the path in ``global_video_path`` and forwards the remaining UI
    values to ``process``.
    """
    global global_video_path
    video_path, forwarded = args[0], args[1:]
    global_video_path = video_path
    return process(*forwarded)
@torch.no_grad()
def process1(*args):
    """Translate the first frame of the input video ("Run 1st Key Frame").

    ``args`` are the UI values in ``create_cfg`` order without the input
    path, which is read from ``global_video_path``. The process runs in order:
    1. Refresh the SD model, attention controller and detector on
       ``global_state``.
    2. Prepare the video frames in ``cfg.input_dir``.
    3. Render the first frame.
    4. Store the result on ``global_state`` for ``process2`` and save it as
       ``first.jpg`` in ``cfg.first_dir``.

    Returns:
        The rendered first frame as an (H, W, C) uint8 numpy array.

    Raises:
        gr.Error: if no frames are found in ``cfg.input_dir``.
    """
    global global_video_path
    global global_state
    cfg = create_cfg(global_video_path, *args)
    global_state.update_sd_model(cfg.sd_model, cfg.control_type)
    global_state.update_controller(cfg.inner_strength, cfg.mask_period,
                                   cfg.cross_period, cfg.ada_period,
                                   cfg.warp_period)
    global_state.update_detector(cfg.control_type, cfg.canny_low,
                                 cfg.canny_high)
    # The controller was just re-created (it presumably carries the
    # first-frame attention state), so an earlier first-frame result is no
    # longer usable. FIRST_IMG is set only once this run has succeeded.
    global_state.processing_state = ProcessingState.NULL

    prepare_frames(cfg.input_path, cfg.input_dir, cfg.image_resolution,
                   cfg.crop)

    ddim_v_sampler = global_state.ddim_v_sampler
    model = ddim_v_sampler.model
    detector = global_state.detector
    controller = global_state.controller
    model.control_scales = [cfg.control_strength] * 13
    model.to(device)

    num_samples = 1
    eta = 0.0
    imgs = sorted(os.listdir(cfg.input_dir))
    if not imgs:
        raise gr.Error('No frames could be extracted from the input video.')
    imgs = [os.path.join(cfg.input_dir, img) for img in imgs]

    model.cond_stage_model.device = device

    # The function is already decorated with torch.no_grad(), which also
    # covers the nested helper below.
    frame = cv2.imread(imgs[0])
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    img = HWC3(frame)
    H, W, _ = img.shape
    img_ = numpy2tensor(img)

    def generate_first_img(img_, strength):
        """Sample one frame conditioned on ``img``'s control map.

        ``img_`` is the (possibly color-corrected) source tensor encoded as
        x0; ``strength`` is forwarded to the sampler.
        """
        encoder_posterior = model.encode_first_stage(img_.to(device))
        x0 = model.get_first_stage_encoding(encoder_posterior).detach()

        detected_map = detector(img)
        detected_map = HWC3(detected_map)

        control = torch.from_numpy(
            detected_map.copy()).float().to(device) / 255.0
        control = torch.stack([control for _ in range(num_samples)], dim=0)
        control = einops.rearrange(control, 'b h w c -> b c h w').clone()
        cond = {
            'c_concat': [control],
            'c_crossattn': [
                model.get_learned_conditioning(
                    [cfg.prompt + ', ' + cfg.a_prompt] * num_samples)
            ]
        }
        un_cond = {
            'c_concat': [control],
            'c_crossattn':
            [model.get_learned_conditioning([cfg.n_prompt] * num_samples)]
        }
        shape = (4, H // 8, W // 8)

        controller.set_task('initfirst')
        seed_everything(cfg.seed)
        samples, _ = ddim_v_sampler.sample(
            cfg.ddim_steps,
            num_samples,
            shape,
            cond,
            verbose=False,
            eta=eta,
            unconditional_guidance_scale=cfg.scale,
            unconditional_conditioning=un_cond,
            controller=controller,
            x0=x0,
            strength=strength)
        x_samples = model.decode_first_stage(samples)
        x_samples_np = (
            einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
            127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
        return x_samples, x_samples_np

    # When not preserving color, draw a different frame at first and use its
    # color to redraw the first frame. The -1 strength is a sampler-specific
    # value (presumably "ignore x0") — TODO confirm in DDIMVSampler.
    if not cfg.color_preserve:
        first_strength = -1
    else:
        first_strength = 1 - cfg.x0_strength

    x_samples, x_samples_np = generate_first_img(img_, first_strength)

    if not cfg.color_preserve:
        color_corrections = setup_color_correction(
            Image.fromarray(x_samples_np[0]))
        global_state.color_corrections = color_corrections
        img_ = apply_color_correction(color_corrections,
                                      Image.fromarray(img))
        img_ = to_tensor(img_).unsqueeze(0)[:, :3] / 127.5 - 1
        x_samples, x_samples_np = generate_first_img(
            img_, 1 - cfg.x0_strength)
    else:
        # Drop corrections from an earlier non-preserving run.
        global_state.color_corrections = None

    global_state.first_result = x_samples
    global_state.first_img = img

    Image.fromarray(x_samples_np[0]).save(
        os.path.join(cfg.first_dir, 'first.jpg'))
    global_state.processing_state = ProcessingState.FIRST_IMG

    return x_samples_np[0]
@torch.no_grad()
def process2(*args):
    """Translate the remaining key frames ("Run Key Frames").

    ``args`` are the UI values in ``create_cfg`` order without the input
    path, which is read from ``global_video_path``. Each key frame is
    sampled with cross-frame constraints. The first frame's result and the
    previous key frame's result are warped onto it via
    ``get_warped_and_mask`` (GMFlow). With pixel-aware fusion enabled, a
    second, fused sampling pass is run. Frames are saved as ``<cid>.png`` in
    ``cfg.key_dir`` and assembled into ``key.mp4``.

    Returns:
        Path of the key-frame video inside ``cfg.work_dir``.

    Raises:
        gr.Error: if the first key frame has not been generated yet.
    """
    global global_state
    global global_video_path
    # Re-running the key frames after an earlier run is allowed; only a
    # missing first-frame pass is an error.
    if global_state.processing_state == ProcessingState.NULL:
        raise gr.Error('Please generate the first key image before generating'
                       ' all key images')
    cfg = create_cfg(global_video_path, *args)
    global_state.update_sd_model(cfg.sd_model, cfg.control_type)
    global_state.update_detector(cfg.control_type, cfg.canny_low,
                                 cfg.canny_high)
    global_state.processing_state = ProcessingState.KEY_IMGS

    # Reset the key-frame directory; it may not exist on the first run.
    shutil.rmtree(cfg.key_dir, ignore_errors=True)
    os.makedirs(cfg.key_dir, exist_ok=True)

    ddim_v_sampler = global_state.ddim_v_sampler
    model = ddim_v_sampler.model
    detector = global_state.detector
    controller = global_state.controller
    flow_model = global_state.flow_model
    model.control_scales = [cfg.control_strength] * 13

    num_samples = 1
    eta = 0.0
    firstx0 = True
    pixelfusion = cfg.use_mask
    # Only present when the first frame was rendered without color
    # preservation; getattr guards against an uninitialised attribute.
    color_corrections = getattr(global_state, 'color_corrections', None)

    imgs = sorted(os.listdir(cfg.input_dir))
    imgs = [os.path.join(cfg.input_dir, img) for img in imgs]

    first_result = global_state.first_result
    first_img = global_state.first_img
    pre_result = first_result
    pre_img = first_img

    # The key frame at offset i reads imgs[i + 1]; clamp to the frames that
    # were actually extracted in case the video is shorter than frame_count.
    stop = min(len(imgs), cfg.frame_count) - 1
    for key_idx, i in enumerate(range(0, stop, cfg.interval)):
        cid = i + 1
        frame = cv2.imread(imgs[cid])
        print(cid)
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        img = HWC3(frame)
        H, W, _ = img.shape

        if cfg.color_preserve or color_corrections is None:
            img_ = numpy2tensor(img)
        else:
            img_ = apply_color_correction(color_corrections,
                                          Image.fromarray(img))
            img_ = to_tensor(img_).unsqueeze(0)[:, :3] / 127.5 - 1
        encoder_posterior = model.encode_first_stage(img_.to(device))
        x0 = model.get_first_stage_encoding(encoder_posterior).detach()

        detected_map = detector(img)
        detected_map = HWC3(detected_map)

        control = torch.from_numpy(
            detected_map.copy()).float().to(device) / 255.0
        control = torch.stack([control for _ in range(num_samples)], dim=0)
        control = einops.rearrange(control, 'b h w c -> b c h w').clone()
        cond = {
            'c_concat': [control],
            'c_crossattn': [
                model.get_learned_conditioning(
                    [cfg.prompt + ', ' + cfg.a_prompt] * num_samples)
            ]
        }
        un_cond = {
            'c_concat': [control],
            'c_crossattn':
            [model.get_learned_conditioning([cfg.n_prompt] * num_samples)]
        }
        shape = (4, H // 8, W // 8)

        # Warp the previous key frame's result onto the current frame.
        image1 = torch.from_numpy(pre_img).permute(2, 0, 1).float()
        image2 = torch.from_numpy(img).permute(2, 0, 1).float()
        warped_pre, bwd_occ_pre, bwd_flow_pre = get_warped_and_mask(
            flow_model, image1, image2, pre_result, False)
        blend_mask_pre = blur(
            F.max_pool2d(bwd_occ_pre, kernel_size=9, stride=1, padding=4))
        blend_mask_pre = torch.clamp(blend_mask_pre + bwd_occ_pre, 0, 1)

        # Same for the first frame's result.
        image1 = torch.from_numpy(first_img).permute(2, 0, 1).float()
        warped_0, bwd_occ_0, bwd_flow_0 = get_warped_and_mask(
            flow_model, image1, image2, first_result, False)
        blend_mask_0 = blur(
            F.max_pool2d(bwd_occ_0, kernel_size=9, stride=1, padding=4))
        blend_mask_0 = torch.clamp(blend_mask_0 + bwd_occ_0, 0, 1)

        # Flow and masks are downscaled by 8 to latent resolution.
        if firstx0:
            mask = 1 - F.max_pool2d(blend_mask_0, kernel_size=8)
            controller.set_warp(
                F.interpolate(bwd_flow_0 / 8.0,
                              scale_factor=1. / 8,
                              mode='bilinear'), mask)
        else:
            mask = 1 - F.max_pool2d(blend_mask_pre, kernel_size=8)
            controller.set_warp(
                F.interpolate(bwd_flow_pre / 8.0,
                              scale_factor=1. / 8,
                              mode='bilinear'), mask)

        controller.set_task('keepx0, keepstyle')
        seed_everything(cfg.seed)
        samples, _ = ddim_v_sampler.sample(
            cfg.ddim_steps,
            num_samples,
            shape,
            cond,
            verbose=False,
            eta=eta,
            unconditional_guidance_scale=cfg.scale,
            unconditional_conditioning=un_cond,
            controller=controller,
            x0=x0,
            strength=1 - cfg.x0_strength)
        direct_result = model.decode_first_stage(samples)

        if not pixelfusion:
            pre_result = direct_result
            pre_img = img
            viz = (
                einops.rearrange(direct_result, 'b c h w -> b h w c') * 127.5 +
                127.5).cpu().numpy().clip(0, 255).astype(np.uint8)

        else:
            blend_results = (1 - blend_mask_pre
                             ) * warped_pre + blend_mask_pre * direct_result
            blend_results = (
                1 - blend_mask_0) * warped_0 + blend_mask_0 * blend_results

            bwd_occ = 1 - torch.clamp(1 - bwd_occ_pre + 1 - bwd_occ_0, 0, 1)
            blend_mask = blur(
                F.max_pool2d(bwd_occ, kernel_size=9, stride=1, padding=4))
            blend_mask = 1 - torch.clamp(blend_mask + bwd_occ, 0, 1)

            encoder_posterior = model.encode_first_stage(blend_results)
            xtrg = model.get_first_stage_encoding(
                encoder_posterior).detach()  # * mask
            blend_results_rec = model.decode_first_stage(xtrg)
            encoder_posterior = model.encode_first_stage(blend_results_rec)
            xtrg_rec = model.get_first_stage_encoding(
                encoder_posterior).detach()
            xtrg_ = (xtrg + 1 * (xtrg - xtrg_rec))  # * mask
            blend_results_rec_new = model.decode_first_stage(xtrg_)
            tmp = (abs(blend_results_rec_new - blend_results).mean(
                dim=1, keepdims=True) > 0.25).float()
            mask_x = F.max_pool2d((F.interpolate(tmp,
                                                 scale_factor=1 / 8.,
                                                 mode='bilinear') > 0).float(),
                                  kernel_size=3,
                                  stride=1,
                                  padding=1)

            mask = (1 - F.max_pool2d(1 - blend_mask, kernel_size=8)
                    )  # * (1-mask_x)

            if cfg.smooth_boundary:
                noise_rescale = find_flat_region(mask)
            else:
                noise_rescale = torch.ones_like(mask)
            # Per-DDIM-step fusion masks: active only inside mask_period.
            # Uses its own loop variable so the key-frame offset `i` is not
            # clobbered.
            masks = []
            for step in range(cfg.ddim_steps):
                if step <= cfg.ddim_steps * cfg.mask_period[
                        0] or step >= cfg.ddim_steps * cfg.mask_period[1]:
                    masks += [None]
                else:
                    masks += [mask * cfg.mask_strength]

            # mask 3
            # xtrg = ((1-mask_x) *
            #         (xtrg + xtrg - xtrg_rec) + mask_x * samples) * mask
            # mask 2
            # xtrg = (xtrg + 1 * (xtrg - xtrg_rec)) * mask
            xtrg = (xtrg + (1 - mask_x) * (xtrg - xtrg_rec)) * mask  # mask 1

            tasks = 'keepstyle, keepx0'
            if not firstx0:
                tasks += ', updatex0'
            # Refresh the cross-frame attention style every
            # `style_update_freq` key frames, as described in the UI.
            if key_idx % cfg.style_update_freq == 0:
                tasks += ', updatestyle'
            controller.set_task(tasks, 1.0)

            seed_everything(cfg.seed)
            samples, _ = ddim_v_sampler.sample(
                cfg.ddim_steps,
                num_samples,
                shape,
                cond,
                verbose=False,
                eta=eta,
                unconditional_guidance_scale=cfg.scale,
                unconditional_conditioning=un_cond,
                controller=controller,
                x0=x0,
                strength=1 - cfg.x0_strength,
                xtrg=xtrg,
                mask=masks,
                noise_rescale=noise_rescale)
            x_samples = model.decode_first_stage(samples)
            pre_result = x_samples
            pre_img = img

            viz = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
                   127.5).cpu().numpy().clip(0, 255).astype(np.uint8)

        Image.fromarray(viz[0]).save(
            os.path.join(cfg.key_dir, f'{cid:04d}.png'))

    key_video_path = os.path.join(cfg.work_dir, 'key.mp4')
    # Key frames are `interval` frames apart; never hand the writer 0 fps.
    fps = max(get_fps(cfg.input_path) // cfg.interval, 1)
    frame_to_video(key_video_path, cfg.key_dir, fps, False)
    return key_video_path
# Markdown shown at the top of the Gradio page. User-facing typos fixed.
DESCRIPTION = '''
## Rerender A Video
### This space provides the function of key frame translation. Full code for full video translation will be released upon the publication of the paper.
### To avoid overload, we limit the **maximum number of key frames** (8) and the maximum frame resolution (512x768).
### The running time of a video of size 512x640 is about 1 minute per key frame on a T4 GPU.
### How to use:
1. **Run 1st Key Frame**: only translate the first frame, so you can adjust the prompts/models/parameters to find your ideal output appearance before running the whole video.
2. **Run Key Frames**: translate all the key frames based on the settings of the first frame
3. **Run All**: **Run 1st Key Frame** and **Run Key Frames**
4. **Run Propagation**: propagate the key frames to other frames for full video translation. This part will be released upon the publication of the paper.
### Tips:
1. This method cannot handle large or quick motions where the optical flow is hard to estimate. **Videos with stable motions are preferred**.
2. Pixel-aware fusion may not work for large or quick motions.
3. Try different color-aware AdaIN settings, or even disable it, to avoid color jittering.
4. `revAnimated_v11` model for non-photorealistic style, `realisticVisionV20_v20` model for photorealistic style.
5. To use your own SD/LoRA model, you may clone the space and specify your model with [sd_model_cfg.py](https://huggingface.co/spaces/Anonymous-sub/Rerender/blob/main/sd_model_cfg.py).
6. This method is based on the original SD model. You may need to [convert](https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py) Diffuser/Automatic1111 models to the original one.
**This code is for research purposes and non-commercial use only.**
[![Duplicate this Space](https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-sm-dark.svg)](https://huggingface.co/spaces/Anonymous-sub/Rerender?duplicate=true) for no queue on your own hardware.
'''
block = gr.Blocks().queue()

with block:
    with gr.Row():
        gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            input_path = gr.Video(label='Input Video',
                                  source='upload',
                                  format='mp4',
                                  visible=True)
            prompt = gr.Textbox(label='Prompt')
            seed = gr.Slider(label='Seed',
                             minimum=0,
                             maximum=2147483647,
                             step=1,
                             value=0,
                             randomize=True)
            run_button = gr.Button(value='Run All')
            with gr.Row():
                run_button1 = gr.Button(value='Run 1st Key Frame')
                run_button2 = gr.Button(value='Run Key Frames')
                run_button3 = gr.Button(value='Run Propagation')
            with gr.Accordion('Advanced options for the 1st frame translation',
                              open=False):
                image_resolution = gr.Slider(
                    label='Frame resolution',
                    minimum=256,
                    maximum=512,
                    value=512,
                    step=64,
                    info='To avoid overload, maximum 512')
                control_strength = gr.Slider(label='ControlNet strength',
                                             minimum=0.0,
                                             maximum=2.0,
                                             value=1.0,
                                             step=0.01)
                x0_strength = gr.Slider(
                    label='Denoising strength',
                    minimum=0.00,
                    maximum=1.05,
                    value=0.75,
                    step=0.05,
                    info=('0: fully recover the input. '
                          '1.05: fully rerender the input.'))
                color_preserve = gr.Checkbox(
                    label='Preserve color',
                    value=True,
                    info='Keep the color of the input video')
                with gr.Row():
                    left_crop = gr.Slider(label='Left crop length',
                                          minimum=0,
                                          maximum=512,
                                          value=0,
                                          step=1)
                    right_crop = gr.Slider(label='Right crop length',
                                           minimum=0,
                                           maximum=512,
                                           value=0,
                                           step=1)
                with gr.Row():
                    top_crop = gr.Slider(label='Top crop length',
                                         minimum=0,
                                         maximum=512,
                                         value=0,
                                         step=1)
                    bottom_crop = gr.Slider(label='Bottom crop length',
                                            minimum=0,
                                            maximum=512,
                                            value=0,
                                            step=1)
                with gr.Row():
                    control_type = gr.Dropdown(['HED', 'canny'],
                                               label='Control type',
                                               value='HED')
                    low_threshold = gr.Slider(label='Canny low threshold',
                                              minimum=1,
                                              maximum=255,
                                              value=100,
                                              step=1)
                    high_threshold = gr.Slider(label='Canny high threshold',
                                               minimum=1,
                                               maximum=255,
                                               value=200,
                                               step=1)
                ddim_steps = gr.Slider(label='Steps',
                                       minimum=1,
                                       maximum=20,
                                       value=20,
                                       step=1,
                                       info='To avoid overload, maximum 20')
                scale = gr.Slider(label='CFG scale',
                                  minimum=0.1,
                                  maximum=30.0,
                                  value=7.5,
                                  step=0.1)
                sd_model_list = list(model_dict.keys())
                sd_model = gr.Dropdown(sd_model_list,
                                       label='Base model',
                                       value='Stable Diffusion 1.5')
                a_prompt = gr.Textbox(label='Added prompt',
                                      value='best quality, extremely detailed')
                n_prompt = gr.Textbox(
                    label='Negative prompt',
                    value=('longbody, lowres, bad anatomy, bad hands, '
                           'missing fingers, extra digit, fewer digits, '
                           'cropped, worst quality, low quality'))
            with gr.Accordion('Advanced options for the key frame translation',
                              open=False):
                # Both limits are widened by the input-video handlers once
                # the frame count of a video is known.
                interval = gr.Slider(
                    label='Key frame frequency (K)',
                    minimum=1,
                    maximum=1,
                    value=1,
                    step=1,
                    info='Uniformly sample the key frames every K frames')
                keyframe_count = gr.Slider(
                    label='Number of key frames',
                    minimum=1,
                    maximum=1,
                    value=1,
                    step=1,
                    info='To avoid overload, maximum 8 key frames')

                # No trailing comma: this must be the component itself, not a
                # one-element tuple.
                use_constraints = gr.CheckboxGroup(
                    [
                        'shape-aware fusion', 'pixel-aware fusion',
                        'color-aware AdaIN'
                    ],
                    label='Select the cross-frame constraints to be used',
                    value=[
                        'shape-aware fusion', 'pixel-aware fusion',
                        'color-aware AdaIN'
                    ])
                with gr.Row():
                    cross_start = gr.Slider(
                        label='Cross-frame attention start',
                        minimum=0,
                        maximum=1,
                        value=0,
                        step=0.05)
                    cross_end = gr.Slider(label='Cross-frame attention end',
                                          minimum=0,
                                          maximum=1,
                                          value=1,
                                          step=0.05)
                style_update_freq = gr.Slider(
                    label='Cross-frame attention update frequency',
                    minimum=1,
                    maximum=100,
                    value=1,
                    step=1,
                    info=('Update the key and value for '
                          'cross-frame attention every N key frames '
                          '(recommend N*K>=10)'))
                with gr.Row():
                    warp_start = gr.Slider(label='Shape-aware fusion start',
                                           minimum=0,
                                           maximum=1,
                                           value=0,
                                           step=0.05)
                    warp_end = gr.Slider(label='Shape-aware fusion end',
                                         minimum=0,
                                         maximum=1,
                                         value=0.1,
                                         step=0.05)
                with gr.Row():
                    mask_start = gr.Slider(label='Pixel-aware fusion start',
                                           minimum=0,
                                           maximum=1,
                                           value=0.5,
                                           step=0.05)
                    mask_end = gr.Slider(label='Pixel-aware fusion end',
                                         minimum=0,
                                         maximum=1,
                                         value=0.8,
                                         step=0.05)
                with gr.Row():
                    ada_start = gr.Slider(label='Color-aware AdaIN start',
                                          minimum=0,
                                          maximum=1,
                                          value=0.8,
                                          step=0.05)
                    ada_end = gr.Slider(label='Color-aware AdaIN end',
                                        minimum=0,
                                        maximum=1,
                                        value=1,
                                        step=0.05)
                mask_strength = gr.Slider(label='Pixel-aware fusion strength',
                                          minimum=0,
                                          maximum=1,
                                          value=0.5,
                                          step=0.01)
                inner_strength = gr.Slider(
                    label='Pixel-aware fusion detail level',
                    minimum=0.5,
                    maximum=1,
                    value=0.9,
                    step=0.01,
                    info='Use a low value to prevent artifacts')
                smooth_boundary = gr.Checkbox(
                    label='Smooth fusion boundary',
                    value=True,
                    info='Select to prevent artifacts at boundary')

            with gr.Accordion('Example configs', open=True):
                config_dir = 'config'
                # Sorted so the example gallery order is stable across runs.
                config_list = sorted(os.listdir(config_dir))
                args_list = []
                for config in config_list:
                    try:
                        config_path = os.path.join(config_dir, config)
                        args = cfg_to_input(config_path)
                        args_list.append(args)
                    except FileNotFoundError:
                        # The video file does not exist, skipped
                        pass

            # UI inputs in create_cfg order (without the input video path).
            ips = [
                prompt, image_resolution, control_strength, color_preserve,
                left_crop, right_crop, top_crop, bottom_crop, control_type,
                low_threshold, high_threshold, ddim_steps, scale, seed,
                sd_model, a_prompt, n_prompt, interval, keyframe_count,
                x0_strength, use_constraints, cross_start, cross_end,
                style_update_freq, warp_start, warp_end, mask_start,
                mask_end, ada_start, ada_end, mask_strength,
                inner_strength, smooth_boundary
            ]

        with gr.Column():
            result_image = gr.Image(label='Output first frame',
                                    type='numpy',
                                    interactive=False)
            result_keyframe = gr.Video(label='Output key frame video',
                                       format='mp4',
                                       interactive=False)

    with gr.Row():
        gr.Examples(examples=args_list,
                    inputs=[input_path, *ips],
                    fn=process0,
                    outputs=[result_image, result_keyframe],
                    cache_examples=True)

    def input_uploaded(path):
        """Validate an uploaded video and reset the key-frame sliders."""
        global video_frame_count
        global global_video_path
        frame_count = get_frame_count(path)
        if frame_count <= 2:
            raise gr.Error('The input video is too short! '
                           'Please input another video.')

        default_interval = min(10, frame_count - 2)
        max_keyframe = min((frame_count - 2) // default_interval, MAX_KEYFRAME)

        video_frame_count = frame_count
        global_video_path = path

        return gr.Slider.update(value=default_interval,
                                maximum=frame_count - 2), gr.Slider.update(
                                    value=max_keyframe, maximum=max_keyframe)

    def input_changed(path):
        """Track the selected video and adapt the key-frame slider limits."""
        global video_frame_count
        global global_video_path
        if path is None:
            # The video was cleared: forget it so stale values are not reused.
            video_frame_count = None
            global_video_path = None
            return gr.Slider.update(), gr.Slider.update()
        frame_count = get_frame_count(path)
        if frame_count <= 2:
            return gr.Slider.update(maximum=1), gr.Slider.update(maximum=1)

        default_interval = min(10, frame_count - 2)
        max_keyframe = min((frame_count - 2) // default_interval, MAX_KEYFRAME)

        video_frame_count = frame_count
        global_video_path = path

        return gr.Slider.update(value=default_interval,
                                maximum=frame_count - 2), \
            gr.Slider.update(maximum=max_keyframe)

    def interval_changed(interval):
        """Recompute the maximum key-frame count for a new interval."""
        global video_frame_count
        if video_frame_count is None:
            return gr.Slider.update()

        max_keyframe = min((video_frame_count - 2) // interval, MAX_KEYFRAME)

        return gr.Slider.update(value=max_keyframe, maximum=max_keyframe)

    input_path.change(input_changed, input_path, [interval, keyframe_count])
    input_path.upload(input_uploaded, input_path, [interval, keyframe_count])
    interval.change(interval_changed, interval, keyframe_count)

    run_button.click(fn=process,
                     inputs=ips,
                     outputs=[result_image, result_keyframe])
    run_button1.click(fn=process1, inputs=ips, outputs=[result_image])
    run_button2.click(fn=process2, inputs=ips, outputs=[result_keyframe])

    def process3():
        """Placeholder for the not-yet-released propagation stage."""
        raise gr.Error(
            "Coming Soon. Full code for full video translation will be "
            "released upon the publication of the paper.")

    run_button3.click(fn=process3, outputs=[result_keyframe])

block.queue(concurrency_count=1, max_size=20)
block.launch(server_name='0.0.0.0')