"""Gradio Space for Rerender-A-Video key-frame translation.

The app translates the first frame of an uploaded video with a
ControlNet-conditioned Stable Diffusion model, then translates the remaining
key frames with cross-frame constraints (shape-aware fusion via optical-flow
warping, pixel-aware fusion, color-aware AdaIN) and writes them to a video.
"""
import os
import shutil
from enum import Enum

import cv2
import einops
import gradio as gr
import huggingface_hub
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from blendmodes.blend import BlendType, blendLayers
from PIL import Image
from pytorch_lightning import seed_everything
from safetensors.torch import load_file
from skimage import exposure

import src.import_util  # noqa: F401
from ControlNet.annotator.canny import CannyDetector
from ControlNet.annotator.hed import HEDdetector
from ControlNet.annotator.midas import MidasDetector
from ControlNet.annotator.util import HWC3
from ControlNet.cldm.model import create_model, load_state_dict
from gmflow_module.gmflow.gmflow import GMFlow
from flow.flow_utils import get_warped_and_mask
from sd_model_cfg import model_dict
from src.config import RerenderConfig
from src.controller import AttentionControl
from src.ddim_v_hacked import DDIMVSampler
from src.img_util import find_flat_region, numpy2tensor
from src.video_util import (frame_to_video, get_fps, get_frame_count,
                            prepare_frames)

REPO_NAME = 'Anonymous-sub/Rerender'

# Example videos referenced by the bundled configs; fetched into ./videos.
EXAMPLE_VIDEOS = (
    'pexels-koolshooters-7322716.mp4',
    'pexels-antoni-shkraba-8048492-540x960-25fps.mp4',
    'pexels-cottonbro-studio-6649832-960x506-25fps.mp4',
)
for _example_video in EXAMPLE_VIDEOS:
    huggingface_hub.hf_hub_download(REPO_NAME, _example_video,
                                    local_dir='videos')

# Checkpoint path -> display name, used to map loaded configs back to the UI.
inversed_model_dict = {path: name for name, path in model_dict.items()}

to_tensor = T.PILToTensor()
blur = T.GaussianBlur(kernel_size=(9, 9), sigma=(18, 18))
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# ControlNet checkpoint (in the lllyasviel/ControlNet repo) per control type.
CONTROLNET_CHECKPOINTS = {
    'HED': 'models/control_sd15_hed.pth',
    'canny': 'models/control_sd15_canny.pth',
    'depth': 'models/control_sd15_depth.pth',
}


class ProcessingState(Enum):
    """Progress of the two-stage pipeline (first frame, then key frames)."""
    NULL = 0
    FIRST_IMG = 1
    KEY_IMGS = 2


MAX_KEYFRAME = float(os.environ.get('MAX_KEYFRAME', 8))


class GlobalState:
    """Models and intermediate results shared between Gradio callbacks.

    Holds the loaded diffusion sampler, the edge/depth detector, the attention
    controller, the GMFlow optical-flow model, and the results of the last
    successful first-frame pass consumed by ``process2``.
    """

    def __init__(self):
        # Cache keys for the currently loaded diffusion model.
        self.sd_model = None
        self.control_type = None
        self.ddim_v_sampler = None
        # Cache key and callable for the current control detector.
        self.detector_type = None
        self.detector = None
        self.controller = None
        self.processing_state = ProcessingState.NULL
        # Written by process1, read by process2.
        self.first_result = None
        self.first_img = None
        # LAB target for color correction, or None when colors are preserved.
        self.color_corrections = None

        flow_model = GMFlow(
            feature_channels=128,
            num_scales=1,
            upsample_factor=8,
            num_head=1,
            attention_type='swin',
            ffn_dim_expansion=4,
            num_transformer_layers=6,
        ).to(device)

        checkpoint = torch.load('models/gmflow_sintel-0c07dcb3.pth',
                                map_location=lambda storage, loc: storage)
        weights = checkpoint['model'] if 'model' in checkpoint else checkpoint
        flow_model.load_state_dict(weights, strict=False)
        flow_model.eval()
        self.flow_model = flow_model

    def update_controller(self, inner_strength, mask_period, cross_period,
                          ada_period, warp_period):
        """Replace the attention controller with one built from these settings."""
        self.controller = AttentionControl(inner_strength, mask_period,
                                           cross_period, ada_period,
                                           warp_period)

    def update_sd_model(self, sd_model, control_type):
        """Load the base SD model together with the ControlNet weights.

        Loading is skipped when the same ``(sd_model, control_type)`` pair is
        already loaded. Keying on the control type as well ensures that a
        control-type change reloads the matching ControlNet weights.

        Args:
            sd_model: Key into ``model_dict``. A value containing exactly one
                ``/`` is also used as the Hugging Face repo id.
            control_type: One of ``CONTROLNET_CHECKPOINTS``.

        Raises:
            ValueError: If ``control_type`` is not supported.
        """
        if sd_model == self.sd_model and control_type == self.control_type:
            return
        control_ckpt = CONTROLNET_CHECKPOINTS.get(control_type)
        if control_ckpt is None:
            raise ValueError(f'Unsupported control type: {control_type!r}')

        # Drop the previous model first so two full models are not resident
        # at the same time.
        self.clear_sd_model()

        model = create_model('./ControlNet/models/cldm_v15.yaml').cpu()
        model.load_state_dict(
            load_state_dict(
                huggingface_hub.hf_hub_download('lllyasviel/ControlNet',
                                                control_ckpt),
                location=device))
        model.to(device)

        sd_model_path = model_dict[sd_model]
        if sd_model_path:
            # A "user/repo" style key names its own repo; otherwise the
            # checkpoint lives in this Space's repo.
            repo_name = sd_model if sd_model.count('/') == 1 else REPO_NAME
            model_ext = os.path.splitext(sd_model_path)[1]
            downloaded_model = huggingface_hub.hf_hub_download(
                repo_name, sd_model_path)
            if model_ext == '.safetensors':
                model.load_state_dict(load_file(downloaded_model),
                                      strict=False)
            elif model_ext in ('.ckpt', '.pth'):
                # map_location keeps CPU-only hosts able to read GPU checkpoints.
                ckpt = torch.load(downloaded_model, map_location='cpu')
                model.load_state_dict(ckpt.get('state_dict', ckpt),
                                      strict=False)
            else:
                print(f'Warning: unsupported checkpoint format {model_ext!r};'
                      ' using the base model weights')

        try:
            vae_ckpt = torch.load(
                huggingface_hub.hf_hub_download(
                    'stabilityai/sd-vae-ft-mse-original',
                    'vae-ft-mse-840000-ema-pruned.ckpt'),
                map_location='cpu')
            model.first_stage_model.load_state_dict(vae_ckpt['state_dict'],
                                                    strict=False)
        except Exception:
            # Best effort: the app still works with the default VAE.
            print('Warning: We suggest you download the fine-tuned VAE',
                  'otherwise the generation quality will be degraded')

        self.ddim_v_sampler = DDIMVSampler(model)
        # Record the cache key only after everything loaded successfully.
        self.sd_model = sd_model
        self.control_type = control_type

    def clear_sd_model(self):
        """Forget the loaded diffusion model and release cached GPU memory."""
        self.sd_model = None
        self.control_type = None
        self.ddim_v_sampler = None
        if device == 'cuda':
            torch.cuda.empty_cache()

    def update_detector(self, control_type, canny_low=100, canny_high=200):
        """Set ``self.detector`` to a callable mapping an image to a control map.

        HED and depth detectors are cached per type. Canny is rebuilt on every
        call because its closure captures the current thresholds.

        Raises:
            ValueError: If ``control_type`` is not supported.
        """
        if self.detector_type == control_type and control_type != 'canny':
            return
        if control_type == 'HED':
            self.detector = HEDdetector()
        elif control_type == 'canny':
            canny_detector = CannyDetector()

            def apply_canny(x):
                return canny_detector(x, canny_low, canny_high)

            self.detector = apply_canny
        elif control_type == 'depth':
            midas = MidasDetector()

            def apply_midas(x):
                detected_map, _ = midas(x)
                return detected_map

            self.detector = apply_midas
        else:
            raise ValueError(f'Unsupported control type: {control_type!r}')
        self.detector_type = control_type


global_state = GlobalState()
global_video_path = None
video_frame_count = None


def create_cfg(input_path, prompt, image_resolution, control_strength,
               color_preserve, left_crop, right_crop, top_crop, bottom_crop,
               control_type, low_threshold, high_threshold, ddim_steps, scale,
               seed, sd_model, a_prompt, n_prompt, interval, keyframe_count,
               x0_strength, use_constraints, cross_start, cross_end,
               style_update_freq, warp_start, warp_end, mask_start, mask_end,
               ada_start, ada_end, mask_strength, inner_strength,
               smooth_boundary):
    """Build a ``RerenderConfig`` from the UI values.

    A constraint that is not listed in ``use_constraints`` is disabled by
    setting its period to ``(1, 0)``. The number of frames processed is
    ``2 + keyframe_count * interval``. Output goes to
    ``result/<video name>/blend.mp4``.
    """
    use_warp = 'shape-aware fusion' in use_constraints
    use_mask = 'pixel-aware fusion' in use_constraints
    use_ada = 'color-aware AdaIN' in use_constraints

    if not use_warp:
        warp_start = 1
        warp_end = 0

    if not use_mask:
        mask_start = 1
        mask_end = 0

    if not use_ada:
        ada_start = 1
        ada_end = 0

    # splitext keeps dotted names ("a.b.mp4" -> "a.b") from colliding.
    input_name = os.path.splitext(os.path.basename(input_path))[0]
    frame_count = 2 + keyframe_count * interval
    cfg = RerenderConfig()
    cfg.create_from_parameters(
        input_path,
        os.path.join('result', input_name, 'blend.mp4'),
        prompt,
        a_prompt=a_prompt,
        n_prompt=n_prompt,
        frame_count=frame_count,
        interval=interval,
        crop=[left_crop, right_crop, top_crop, bottom_crop],
        sd_model=sd_model,
        ddim_steps=ddim_steps,
        scale=scale,
        control_type=control_type,
        control_strength=control_strength,
        canny_low=low_threshold,
        canny_high=high_threshold,
        seed=seed,
        image_resolution=image_resolution,
        x0_strength=x0_strength,
        style_update_freq=style_update_freq,
        cross_period=(cross_start, cross_end),
        warp_period=(warp_start, warp_end),
        mask_period=(mask_start, mask_end),
        ada_period=(ada_start, ada_end),
        mask_strength=mask_strength,
        inner_strength=inner_strength,
        smooth_boundary=smooth_boundary,
        color_preserve=color_preserve)
    return cfg


def cfg_to_input(filename):
    """Load a config file and return it as UI input values.

    The returned list is ordered like ``[input_path, *ips]``, i.e. the
    parameters of ``create_cfg``.
    """
    cfg = RerenderConfig()
    cfg.create_from_path(filename)
    keyframe_count = (cfg.frame_count - 2) // cfg.interval
    # NOTE(review): all three constraints are always enabled here, regardless
    # of what the config file specifies.
    use_constraints = [
        'shape-aware fusion', 'pixel-aware fusion', 'color-aware AdaIN'
    ]

    sd_model = inversed_model_dict.get(cfg.sd_model, 'Stable Diffusion 1.5')

    args = [
        cfg.input_path, cfg.prompt, cfg.image_resolution, cfg.control_strength,
        cfg.color_preserve, *cfg.crop, cfg.control_type, cfg.canny_low,
        cfg.canny_high, cfg.ddim_steps, cfg.scale, cfg.seed, sd_model,
        cfg.a_prompt, cfg.n_prompt, cfg.interval, keyframe_count,
        cfg.x0_strength, use_constraints, *cfg.cross_period,
        cfg.style_update_freq, *cfg.warp_period, *cfg.mask_period,
        *cfg.ada_period, cfg.mask_strength, cfg.inner_strength,
        cfg.smooth_boundary
    ]
    return args


def setup_color_correction(image):
    """Return ``image`` (PIL, RGB) converted to a LAB array.

    The result is used as the histogram-matching target.
    """
    correction_target = cv2.cvtColor(np.asarray(image.copy()),
                                     cv2.COLOR_RGB2LAB)
    return correction_target


def apply_color_correction(correction, original_image):
    """Recolor ``original_image`` (PIL, RGB) toward the ``correction`` LAB target.

    The LAB histograms are matched per channel and the result is converted
    back to RGB. It is then combined with the original image using
    ``BlendType.LUMINOSITY``.
    """
    image = Image.fromarray(
        cv2.cvtColor(
            exposure.match_histograms(cv2.cvtColor(np.asarray(original_image),
                                                   cv2.COLOR_RGB2LAB),
                                      correction,
                                      channel_axis=2),
            cv2.COLOR_LAB2RGB).astype('uint8'))

    image = blendLayers(image, original_image, BlendType.LUMINOSITY)

    return image


def _require_video_path():
    """Return the current input video path, or raise a user-facing error."""
    if global_video_path is None:
        raise gr.Error('Please upload an input video first')
    return global_video_path


def _list_frames(input_dir):
    """Return the sorted full paths of the extracted frames in ``input_dir``."""
    return [os.path.join(input_dir, name) for name in sorted(os.listdir(input_dir))]


def _read_rgb_frame(path):
    """Read an image file and return it as an RGB HWC array (via ``HWC3``)."""
    frame = cv2.imread(path)
    if frame is None:
        raise gr.Error(f'Failed to read frame: {path}')
    return HWC3(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))


def _prepare_input_tensor(img, color_corrections=None):
    """Turn an RGB frame into the model's image tensor.

    If ``color_corrections`` is given, the frame is recolored toward it first.
    """
    if color_corrections is None:
        return numpy2tensor(img)
    corrected = apply_color_correction(color_corrections, Image.fromarray(img))
    # PILToTensor gives CHW; keep RGB channels and map [0, 255] to [-1, 1].
    return to_tensor(corrected).unsqueeze(0)[:, :3] / 127.5 - 1


def _encode_latent(model, img_tensor):
    """Encode an image tensor into the model's first-stage (VAE) latent."""
    encoder_posterior = model.encode_first_stage(img_tensor.to(device))
    return model.get_first_stage_encoding(encoder_posterior).detach()


def _to_uint8_images(x):
    """Convert decoded BCHW samples (about [-1, 1]) to uint8 BHWC numpy arrays."""
    return (einops.rearrange(x, 'b c h w -> b h w c') * 127.5 +
            127.5).cpu().numpy().clip(0, 255).astype(np.uint8)


def _build_conditions(model, detector, img, cfg, num_samples=1):
    """Build the ControlNet inputs for one RGB frame.

    Returns ``(cond, un_cond, shape)``. ``shape`` is the latent shape
    ``(4, H // 8, W // 8)``.
    """
    H, W = img.shape[:2]
    detected_map = HWC3(detector(img))
    control = torch.from_numpy(detected_map.copy()).float().to(device) / 255.0
    control = torch.stack([control for _ in range(num_samples)], dim=0)
    control = einops.rearrange(control, 'b h w c -> b c h w').clone()
    cond = {
        'c_concat': [control],
        'c_crossattn': [
            model.get_learned_conditioning(
                [cfg.prompt + ', ' + cfg.a_prompt] * num_samples)
        ]
    }
    un_cond = {
        'c_concat': [control],
        'c_crossattn':
        [model.get_learned_conditioning([cfg.n_prompt] * num_samples)]
    }
    return cond, un_cond, (4, H // 8, W // 8)


def _dilated_occlusion_mask(bwd_occ):
    """Dilate (9x9 max-pool) and blur ``bwd_occ``, re-add it, clamp to [0, 1]."""
    dilated = blur(F.max_pool2d(bwd_occ, kernel_size=9, stride=1, padding=4))
    return torch.clamp(dilated + bwd_occ, 0, 1)


def _warp_to_current(flow_model, src_img, cur_tensor, src_result):
    """Warp ``src_result`` (rendered from ``src_img``) onto the current frame.

    Returns ``(warped, bwd_occ, bwd_flow, blend_mask)``. The first three come
    from ``get_warped_and_mask``. ``blend_mask`` is the dilated occlusion mask.
    """
    src_tensor = torch.from_numpy(src_img).permute(2, 0, 1).float()
    warped, bwd_occ, bwd_flow = get_warped_and_mask(flow_model, src_tensor,
                                                    cur_tensor, src_result,
                                                    False)
    return warped, bwd_occ, bwd_flow, _dilated_occlusion_mask(bwd_occ)


def _pixel_aware_fusion_target(model, cfg, direct_result, warped_pre,
                               bwd_occ_pre, blend_mask_pre, warped_0,
                               bwd_occ_0, blend_mask_0):
    """Compute the inputs of the pixel-aware fusion resampling pass.

    Returns ``(xtrg, masks, noise_rescale)`` for ``ddim_v_sampler.sample``.
    """
    # Where a blend mask is 0, take the warped result; where it is 1, keep
    # the direct result. Previous frame first, then the first frame.
    blend_results = (1 - blend_mask_pre) * warped_pre \
        + blend_mask_pre * direct_result
    blend_results = (1 - blend_mask_0) * warped_0 \
        + blend_mask_0 * blend_results

    bwd_occ = 1 - torch.clamp(1 - bwd_occ_pre + 1 - bwd_occ_0, 0, 1)
    blend_mask = 1 - _dilated_occlusion_mask(bwd_occ)

    # Compensate VAE round-trip loss: push the latent away from its
    # re-encoded reconstruction. Then flag pixels where the compensated
    # decode still differs by more than 0.25 (mean abs over channels).
    xtrg = _encode_latent(model, blend_results)
    blend_results_rec = model.decode_first_stage(xtrg)
    xtrg_rec = _encode_latent(model, blend_results_rec)
    blend_results_rec_new = model.decode_first_stage(xtrg + (xtrg - xtrg_rec))
    tmp = (abs(blend_results_rec_new - blend_results).mean(
        dim=1, keepdim=True) > 0.25).float()
    mask_x = F.max_pool2d(
        (F.interpolate(tmp, scale_factor=1 / 8., mode='bilinear') >
         0).float(),
        kernel_size=3,
        stride=1,
        padding=1)

    mask = 1 - F.max_pool2d(1 - blend_mask, kernel_size=8)

    if cfg.smooth_boundary:
        noise_rescale = find_flat_region(mask)
    else:
        noise_rescale = torch.ones_like(mask)

    # Apply the mask only on DDIM steps strictly inside mask_period.
    masks = []
    for step in range(cfg.ddim_steps):
        if (step <= cfg.ddim_steps * cfg.mask_period[0]
                or step >= cfg.ddim_steps * cfg.mask_period[1]):
            masks.append(None)
        else:
            masks.append(mask * cfg.mask_strength)

    # Apply the round-trip compensation only where it did not create artifacts.
    xtrg = (xtrg + (1 - mask_x) * (xtrg - xtrg_rec)) * mask
    return xtrg, masks, noise_rescale


@torch.no_grad()
def process(*args):
    """Run both stages: first frame, then all key frames."""
    first_frame = process1(*args)

    keypath = process2(*args)

    return first_frame, keypath


@torch.no_grad()
def process0(*args):
    """Entry point for examples.

    ``args[0]`` is the video path; the remaining values follow ``ips``.
    """
    global global_video_path
    global_video_path = args[0]
    return process(*args[1:])


@torch.no_grad()
def process1(*args):
    """Translate the first frame of the current video.

    ``args`` are the UI values in ``ips`` order. The result is stored in
    ``global_state`` for ``process2`` and returned as a uint8 HWC array.
    """
    cfg = create_cfg(_require_video_path(), *args)
    global_state.update_sd_model(cfg.sd_model, cfg.control_type)
    global_state.update_controller(cfg.inner_strength, cfg.mask_period,
                                   cfg.cross_period, cfg.ada_period,
                                   cfg.warp_period)
    global_state.update_detector(cfg.control_type, cfg.canny_low,
                                 cfg.canny_high)
    # Invalidate earlier first-frame state until this run succeeds, so a
    # failed run cannot leave process2 working from stale results.
    global_state.processing_state = ProcessingState.NULL
    global_state.color_corrections = None

    prepare_frames(cfg.input_path, cfg.input_dir, cfg.image_resolution,
                   cfg.crop)

    ddim_v_sampler = global_state.ddim_v_sampler
    model = ddim_v_sampler.model
    detector = global_state.detector
    controller = global_state.controller
    # NOTE(review): presumably one scale per ControlNet output; 13 as upstream.
    model.control_scales = [cfg.control_strength] * 13
    model.to(device)
    model.cond_stage_model.device = device

    num_samples = 1
    eta = 0.0

    img = _read_rgb_frame(_list_frames(cfg.input_dir)[0])

    def generate_first_img(img_, strength):
        """Sample a translation of the first frame starting from ``img_``."""
        x0 = _encode_latent(model, img_)
        cond, un_cond, shape = _build_conditions(model, detector, img, cfg,
                                                 num_samples)

        controller.set_task('initfirst')
        seed_everything(cfg.seed)

        samples, _ = ddim_v_sampler.sample(
            cfg.ddim_steps,
            num_samples,
            shape,
            cond,
            verbose=False,
            eta=eta,
            unconditional_guidance_scale=cfg.scale,
            unconditional_conditioning=un_cond,
            controller=controller,
            x0=x0,
            strength=strength)
        x_samples = model.decode_first_stage(samples)
        return x_samples, _to_uint8_images(x_samples)

    if cfg.color_preserve:
        x_samples, x_samples_np = generate_first_img(
            _prepare_input_tensor(img), 1 - cfg.x0_strength)
    else:
        # When not preserving color, first draw a different frame
        # (strength -1), then use its colors to recolor and redraw the
        # first frame.
        _, x_samples_np = generate_first_img(_prepare_input_tensor(img), -1)
        color_corrections = setup_color_correction(
            Image.fromarray(x_samples_np[0]))
        global_state.color_corrections = color_corrections
        x_samples, x_samples_np = generate_first_img(
            _prepare_input_tensor(img, color_corrections),
            1 - cfg.x0_strength)

    global_state.first_result = x_samples
    global_state.first_img = img

    Image.fromarray(x_samples_np[0]).save(
        os.path.join(cfg.first_dir, 'first.jpg'))

    global_state.processing_state = ProcessingState.FIRST_IMG
    return x_samples_np[0]


@torch.no_grad()
def process2(*args):
    """Translate every key frame, guided by the first-frame result.

    ``args`` are the UI values in ``ips`` order. Key frames are saved as
    ``<key_dir>/NNNN.png`` and assembled into ``<work_dir>/key.mp4``, whose
    path is returned.

    Raises:
        gr.Error: If ``process1`` has not completed since the last key-frame run.
    """
    if global_state.processing_state != ProcessingState.FIRST_IMG:
        raise gr.Error('Please generate the first key image before generating'
                       ' all key images')

    cfg = create_cfg(_require_video_path(), *args)
    global_state.update_sd_model(cfg.sd_model, cfg.control_type)
    global_state.update_detector(cfg.control_type, cfg.canny_low,
                                 cfg.canny_high)
    global_state.processing_state = ProcessingState.KEY_IMGS

    # Reset the key-frame directory.
    if os.path.isdir(cfg.key_dir):
        shutil.rmtree(cfg.key_dir)
    os.makedirs(cfg.key_dir, exist_ok=True)

    ddim_v_sampler = global_state.ddim_v_sampler
    model = ddim_v_sampler.model
    detector = global_state.detector
    controller = global_state.controller
    flow_model = global_state.flow_model
    model.control_scales = [cfg.control_strength] * 13

    num_samples = 1
    eta = 0.0
    # Always warp from the first frame during the initial pass.
    firstx0 = True
    pixelfusion = cfg.use_mask
    color_corrections = (None if cfg.color_preserve else
                         global_state.color_corrections)

    imgs = _list_frames(cfg.input_dir)

    first_result = global_state.first_result
    first_img = global_state.first_img
    pre_result = first_result
    pre_img = first_img

    # Never index past the frames that were actually extracted.
    last_frame = min(len(imgs), cfg.frame_count)
    for i in range(0, last_frame - 1, cfg.interval):
        cid = i + 1
        img = _read_rgb_frame(imgs[cid])
        print(cid)

        x0 = _encode_latent(model,
                            _prepare_input_tensor(img, color_corrections))
        cond, un_cond, shape = _build_conditions(model, detector, img, cfg,
                                                 num_samples)

        # Warp both the previous and the first result onto this frame.
        cur_tensor = torch.from_numpy(img).permute(2, 0, 1).float()
        warped_pre, bwd_occ_pre, bwd_flow_pre, blend_mask_pre = \
            _warp_to_current(flow_model, pre_img, cur_tensor, pre_result)
        warped_0, bwd_occ_0, bwd_flow_0, blend_mask_0 = \
            _warp_to_current(flow_model, first_img, cur_tensor, first_result)

        # Shape-aware fusion: hand the latent-resolution flow and validity
        # mask to the controller.
        if firstx0:
            ref_blend_mask, ref_flow = blend_mask_0, bwd_flow_0
        else:
            ref_blend_mask, ref_flow = blend_mask_pre, bwd_flow_pre
        mask = 1 - F.max_pool2d(ref_blend_mask, kernel_size=8)
        controller.set_warp(
            F.interpolate(ref_flow / 8.0, scale_factor=1. / 8,
                          mode='bilinear'), mask)

        controller.set_task('keepx0, keepstyle')
        seed_everything(cfg.seed)
        samples, _ = ddim_v_sampler.sample(
            cfg.ddim_steps,
            num_samples,
            shape,
            cond,
            verbose=False,
            eta=eta,
            unconditional_guidance_scale=cfg.scale,
            unconditional_conditioning=un_cond,
            controller=controller,
            x0=x0,
            strength=1 - cfg.x0_strength)
        result = model.decode_first_stage(samples)

        if pixelfusion:
            xtrg, masks, noise_rescale = _pixel_aware_fusion_target(
                model, cfg, result, warped_pre, bwd_occ_pre, blend_mask_pre,
                warped_0, bwd_occ_0, blend_mask_0)

            tasks = 'keepstyle, keepx0'
            if not firstx0:
                tasks += ', updatex0'
            # ``i`` is the frame index of this key frame (not a DDIM step).
            if i % cfg.style_update_freq == 0:
                tasks += ', updatestyle'
            controller.set_task(tasks, 1.0)

            seed_everything(cfg.seed)
            samples, _ = ddim_v_sampler.sample(
                cfg.ddim_steps,
                num_samples,
                shape,
                cond,
                verbose=False,
                eta=eta,
                unconditional_guidance_scale=cfg.scale,
                unconditional_conditioning=un_cond,
                controller=controller,
                x0=x0,
                strength=1 - cfg.x0_strength,
                xtrg=xtrg,
                mask=masks,
                noise_rescale=noise_rescale)
            result = model.decode_first_stage(samples)

        pre_result = result
        pre_img = img

        viz = _to_uint8_images(result)
        Image.fromarray(viz[0]).save(
            os.path.join(cfg.key_dir, f'{cid:04d}.png'))

    key_video_path = os.path.join(cfg.work_dir, 'key.mp4')
    # Key frames are every ``interval`` frames; keep the frame rate >= 1.
    fps = max(1, get_fps(cfg.input_path) // cfg.interval)
    frame_to_video(key_video_path, cfg.key_dir, fps, False)
    return key_video_path


DESCRIPTION = f'''
## Rerender A Video
### This space provides the function of key frame translation. Full code for full video translation will be released upon the publication of the paper.
### To avoid overload, we set limitations to the **maximum number of key frames** ({int(MAX_KEYFRAME)}) and the maximum frame resolution (512x768).
### The running time of a video of size 512x640 is about 1 minute per keyframe on a T4 GPU.
### How to use:
1. **Run 1st Key Frame**: only translate the first frame, so you can adjust the prompts/models/parameters to find your ideal output appearance before running the whole video.
2. **Run Key Frames**: translate all the key frames based on the settings of the first frame
3. **Run All**: **Run 1st Key Frame** and **Run Key Frames**
4. **Run Propagation**: propagate the key frames to other frames for full video translation. This part will be released upon the publication of the paper.
### Tips:
1. This method cannot handle large or quick motions where the optical flow is hard to estimate. **Videos with stable motions are preferred**.
2. Pixel-aware fusion may not work for large or quick motions.
3. Try different color-aware AdaIN settings, or even disable it, to avoid color jittering.
4. `revAnimated_v11` model for non-photorealistic style, `realisticVisionV20_v20` model for photorealistic style.
5. To use your own SD/LoRA model, you may clone the space and specify your model with [sd_model_cfg.py](https://huggingface.co/spaces/Anonymous-sub/Rerender/blob/main/sd_model_cfg.py).
6. This method is based on the original SD model. You may need to [convert](https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py) Diffuser/Automatic1111 models to the original one.

**This code is for research purpose and non-commercial use only.**

[![Duplicate this Space](https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-sm-dark.svg)](https://huggingface.co/spaces/Anonymous-sub/Rerender?duplicate=true) for no queue on your own hardware.
'''

block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            input_path = gr.Video(label='Input Video',
                                  source='upload',
                                  format='mp4',
                                  visible=True)
            prompt = gr.Textbox(label='Prompt')
            seed = gr.Slider(label='Seed',
                             minimum=0,
                             maximum=2147483647,
                             step=1,
                             value=0,
                             randomize=True)
            run_button = gr.Button(value='Run All')
            with gr.Row():
                run_button1 = gr.Button(value='Run 1st Key Frame')
                run_button2 = gr.Button(value='Run Key Frames')
                run_button3 = gr.Button(value='Run Propagation')
            with gr.Accordion('Advanced options for the 1st frame translation',
                              open=False):
                image_resolution = gr.Slider(
                    label='Frame resolution',
                    minimum=256,
                    maximum=512,
                    value=512,
                    step=64,
                    info='To avoid overload, maximum 512')
                control_strength = gr.Slider(label='ControlNet strength',
                                             minimum=0.0,
                                             maximum=2.0,
                                             value=1.0,
                                             step=0.01)
                x0_strength = gr.Slider(
                    label='Denoising strength',
                    minimum=0.00,
                    maximum=1.05,
                    value=0.75,
                    step=0.05,
                    info=('0: fully recover the input. '
                          '1.05: fully rerender the input.'))
                color_preserve = gr.Checkbox(
                    label='Preserve color',
                    value=True,
                    info='Keep the color of the input video')
                with gr.Row():
                    left_crop = gr.Slider(label='Left crop length',
                                          minimum=0,
                                          maximum=512,
                                          value=0,
                                          step=1)
                    right_crop = gr.Slider(label='Right crop length',
                                           minimum=0,
                                           maximum=512,
                                           value=0,
                                           step=1)
                with gr.Row():
                    top_crop = gr.Slider(label='Top crop length',
                                         minimum=0,
                                         maximum=512,
                                         value=0,
                                         step=1)
                    bottom_crop = gr.Slider(label='Bottom crop length',
                                            minimum=0,
                                            maximum=512,
                                            value=0,
                                            step=1)
                with gr.Row():
                    control_type = gr.Dropdown(['HED', 'canny', 'depth'],
                                               label='Control type',
                                               value='HED')
                    low_threshold = gr.Slider(label='Canny low threshold',
                                              minimum=1,
                                              maximum=255,
                                              value=100,
                                              step=1)
                    high_threshold = gr.Slider(label='Canny high threshold',
                                               minimum=1,
                                               maximum=255,
                                               value=200,
                                               step=1)
                ddim_steps = gr.Slider(label='Steps',
                                       minimum=1,
                                       maximum=20,
                                       value=20,
                                       step=1,
                                       info='To avoid overload, maximum 20')
                scale = gr.Slider(label='CFG scale',
                                  minimum=0.1,
                                  maximum=30.0,
                                  value=7.5,
                                  step=0.1)
                sd_model_list = list(model_dict.keys())
                sd_model = gr.Dropdown(sd_model_list,
                                       label='Base model',
                                       value='Stable Diffusion 1.5')
                a_prompt = gr.Textbox(label='Added prompt',
                                      value='best quality, extremely detailed')
                n_prompt = gr.Textbox(
                    label='Negative prompt',
                    value=('longbody, lowres, bad anatomy, bad hands, '
                           'missing fingers, extra digit, fewer digits, '
                           'cropped, worst quality, low quality'))
            with gr.Accordion('Advanced options for the key frame translation',
                              open=False):
                interval = gr.Slider(
                    label='Key frame frequency (K)',
                    minimum=1,
                    maximum=MAX_KEYFRAME,
                    value=1,
                    step=1,
                    info='Uniformly sample the key frames every K frames')
                keyframe_count = gr.Slider(
                    label='Number of key frames',
                    minimum=1,
                    maximum=MAX_KEYFRAME,
                    value=1,
                    step=1,
                    info=(f'To avoid overload, maximum {int(MAX_KEYFRAME)} '
                          'key frames'))

                use_constraints = gr.CheckboxGroup(
                    [
                        'shape-aware fusion', 'pixel-aware fusion',
                        'color-aware AdaIN'
                    ],
                    label='Select the cross-frame constraints to be used',
                    value=[
                        'shape-aware fusion', 'pixel-aware fusion',
                        'color-aware AdaIN'
                    ])
                with gr.Row():
                    cross_start = gr.Slider(
                        label='Cross-frame attention start',
                        minimum=0,
                        maximum=1,
                        value=0,
                        step=0.05)
                    cross_end = gr.Slider(label='Cross-frame attention end',
                                          minimum=0,
                                          maximum=1,
                                          value=1,
                                          step=0.05)
                style_update_freq = gr.Slider(
                    label='Cross-frame attention update frequency',
                    minimum=1,
                    maximum=100,
                    value=1,
                    step=1,
                    info=('Update the key and value for '
                          'cross-frame attention every N key frames '
                          '(recommend N*K>=10)'))
                with gr.Row():
                    warp_start = gr.Slider(label='Shape-aware fusion start',
                                           minimum=0,
                                           maximum=1,
                                           value=0,
                                           step=0.05)
                    warp_end = gr.Slider(label='Shape-aware fusion end',
                                         minimum=0,
                                         maximum=1,
                                         value=0.1,
                                         step=0.05)
                with gr.Row():
                    mask_start = gr.Slider(label='Pixel-aware fusion start',
                                           minimum=0,
                                           maximum=1,
                                           value=0.5,
                                           step=0.05)
                    mask_end = gr.Slider(label='Pixel-aware fusion end',
                                         minimum=0,
                                         maximum=1,
                                         value=0.8,
                                         step=0.05)
                with gr.Row():
                    ada_start = gr.Slider(label='Color-aware AdaIN start',
                                          minimum=0,
                                          maximum=1,
                                          value=0.8,
                                          step=0.05)
                    ada_end = gr.Slider(label='Color-aware AdaIN end',
                                        minimum=0,
                                        maximum=1,
                                        value=1,
                                        step=0.05)
                mask_strength = gr.Slider(label='Pixel-aware fusion strength',
                                          minimum=0,
                                          maximum=1,
                                          value=0.5,
                                          step=0.01)
                inner_strength = gr.Slider(
                    label='Pixel-aware fusion detail level',
                    minimum=0.5,
                    maximum=1,
                    value=0.9,
                    step=0.01,
                    info='Use a low value to prevent artifacts')
                smooth_boundary = gr.Checkbox(
                    label='Smooth fusion boundary',
                    value=True,
                    info='Select to prevent artifacts at boundary')

            with gr.Accordion('Example configs', open=True):
                config_dir = 'config'
                # Sorted so the examples appear in a stable order.
                config_list = sorted(os.listdir(config_dir))
                args_list = []
                for config in config_list:
                    try:
                        config_path = os.path.join(config_dir, config)
                        args = cfg_to_input(config_path)
                        args_list.append(args)
                    except FileNotFoundError:
                        # The video file does not exist, skipped
                        pass

            # Must stay in create_cfg's parameter order (after input_path).
            ips = [
                prompt, image_resolution, control_strength, color_preserve,
                left_crop, right_crop, top_crop, bottom_crop, control_type,
                low_threshold, high_threshold, ddim_steps, scale, seed,
                sd_model, a_prompt, n_prompt, interval, keyframe_count,
                x0_strength, use_constraints, cross_start, cross_end,
                style_update_freq, warp_start, warp_end, mask_start,
                mask_end, ada_start, ada_end, mask_strength,
                inner_strength, smooth_boundary
            ]

        with gr.Column():
            result_image = gr.Image(label='Output first frame',
                                    type='numpy',
                                    interactive=False)
            result_keyframe = gr.Video(label='Output key frame video',
                                       format='mp4',
                                       interactive=False)

    with gr.Row():
        gr.Examples(examples=args_list,
                    inputs=[input_path, *ips],
                    fn=process0,
                    outputs=[result_image, result_keyframe],
                    cache_examples=True)

    def _keyframe_limits(frame_count):
        """Return ``(default_interval, max_keyframe)`` for a video length."""
        default_interval = min(10, frame_count - 2)
        max_keyframe = min((frame_count - 2) // default_interval, MAX_KEYFRAME)
        return default_interval, max_keyframe

    def input_uploaded(path):
        """Remember an uploaded video and reset the key-frame sliders."""
        global video_frame_count, global_video_path
        frame_count = get_frame_count(path)
        if frame_count <= 2:
            raise gr.Error('The input video is too short! '
                           'Please input another video.')

        default_interval, max_keyframe = _keyframe_limits(frame_count)
        video_frame_count = frame_count
        global_video_path = path

        return gr.Slider.update(value=default_interval,
                                maximum=frame_count - 2), gr.Slider.update(
                                    value=max_keyframe, maximum=max_keyframe)

    def input_changed(path):
        """Track the selected video and update the key-frame slider ranges."""
        global video_frame_count, global_video_path
        if path is None:
            # The video was cleared: forget it so runs ask for a new upload.
            video_frame_count = None
            global_video_path = None
            return gr.Slider.update(), gr.Slider.update()
        frame_count = get_frame_count(path)
        if frame_count <= 2:
            return gr.Slider.update(maximum=1), gr.Slider.update(maximum=1)

        default_interval, max_keyframe = _keyframe_limits(frame_count)
        video_frame_count = frame_count
        global_video_path = path

        return gr.Slider.update(value=default_interval,
                                maximum=frame_count - 2), \
            gr.Slider.update(maximum=max_keyframe)

    def interval_changed(interval):
        """Clamp the key-frame count to what the chosen interval allows."""
        if video_frame_count is None:
            return gr.Slider.update()

        max_keyframe = min((video_frame_count - 2) // interval, MAX_KEYFRAME)

        return gr.Slider.update(value=max_keyframe, maximum=max_keyframe)

    input_path.change(input_changed, input_path, [interval, keyframe_count])
    input_path.upload(input_uploaded, input_path, [interval, keyframe_count])
    interval.change(interval_changed, interval, keyframe_count)

    run_button.click(fn=process,
                     inputs=ips,
                     outputs=[result_image, result_keyframe])
    run_button1.click(fn=process1, inputs=ips, outputs=[result_image])
    run_button2.click(fn=process2, inputs=ips, outputs=[result_keyframe])

    def process3():
        """Placeholder for full-video propagation (not released yet)."""
        raise gr.Error(
            "Coming Soon. Full code for full video translation will be "
            "released upon the publication of the paper.")

    run_button3.click(fn=process3, outputs=[result_keyframe])

block.queue(concurrency_count=1, max_size=20)
block.launch(server_name='0.0.0.0')