import gradio as gr
import spaces
import numpy as np
import torch
import cv2
import os
import imageio
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from controlnet_aux import LineartDetector
from functools import partial
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, ToTensor, Normalize, Resize

from NaRCan_model import Homography, Siren
from util import get_mgrid, apply_homography, jacobian, VideoFitting, TestVideoFitting

device = 'cuda' if torch.cuda.is_available() else 'cpu'


def get_example():
    case = [
        ['examples/bear.mp4'],
        ['examples/boat.mp4'],
        ['examples/woman-drink.mp4'],
        ['examples/corgi.mp4'],
        ['examples/yacht.mp4'],
        ['examples/koolshooters.mp4'],
        ['examples/overlook-the-ocean.mp4'],
        ['examples/rotate.mp4'],
        ['examples/shark-ocean.mp4'],
        ['examples/surf.mp4'],
        ['examples/cactus.mp4'],
        ['examples/gold-fish.mp4'],
    ]
    return case


def set_default_prompt(video_name):
    video_to_prompt = {
        'bear.mp4': 'bear, Van Gogh Style',
        'boat.mp4': 'a burning boat sails on lava',
        'cactus.mp4': 'cactus, made of paper',
        'corgi.mp4': 'a hellhound',
        'gold-fish.mp4': 'Goldfish in the Milky Way',
        'koolshooters.mp4': 'Avatar',
        'overlook-the-ocean.mp4': 'ocean, pixel style',
        'rotate.mp4': 'turbine engine',
        'shark-ocean.mp4': 'A sleek shark, cartoon style',
        'surf.mp4': 'Sailing, The background is a large white cloud, sketch style',
        'woman-drink.mp4': 'a drinking zombie',
        'yacht.mp4': 'yacht, cyberpunk style',
    }
    return video_to_prompt.get(video_name, '')


def update_prompt(input_video):
    video_name = input_video.split('/')[-1]
    return set_default_prompt(video_name)


# Map each example video to its canonical image, NaRCan checkpoint folder, and extracted frames
video_to_image = {
    'bear.mp4': ['canonical/bear.png', 'pth_file/bear', 'examples_frames/bear'],
    'boat.mp4': ['canonical/boat.png', 'pth_file/boat', 'examples_frames/boat'],
    'cactus.mp4': ['canonical/cactus.png', 'pth_file/cactus', 'examples_frames/cactus'],
    'corgi.mp4': ['canonical/corgi.png', 'pth_file/corgi', 'examples_frames/corgi'],
    'gold-fish.mp4': ['canonical/gold-fish.png', 'pth_file/gold-fish', 'examples_frames/gold-fish'],
    'koolshooters.mp4': ['canonical/koolshooters.png', 'pth_file/koolshooters', 'examples_frames/koolshooters'],
    'overlook-the-ocean.mp4': ['canonical/overlook-the-ocean.png', 'pth_file/overlook-the-ocean', 'examples_frames/overlook-the-ocean'],
    'rotate.mp4': ['canonical/rotate.png', 'pth_file/rotate', 'examples_frames/rotate'],
    'shark-ocean.mp4': ['canonical/shark-ocean.png', 'pth_file/shark-ocean', 'examples_frames/shark-ocean'],
    'surf.mp4': ['canonical/surf.png', 'pth_file/surf', 'examples_frames/surf'],
    'woman-drink.mp4': ['canonical/woman-drink.png', 'pth_file/woman-drink', 'examples_frames/woman-drink'],
    'yacht.mp4': ['canonical/yacht.png', 'pth_file/yacht', 'examples_frames/yacht'],
}


def images_to_video(image_list, output_path, fps=10):
    # Convert frames to uint8 numpy arrays (only the first 20 frames are written)
    frames = [np.array(img).astype(np.uint8) for img in image_list]
    frames = frames[:20]

    # Create video writer
    writer = imageio.get_writer(output_path, fps=fps, codec='libx264')
    for frame in frames:
        writer.append_data(frame)
    writer.close()


@spaces.GPU
def NaRCan_make_video(edit_canonical, pth_path, frames_path):
    # load the pretrained NaRCan deformation models (homography network + Siren MLP)
    checkpoint_g_old = torch.load(os.path.join(pth_path, "homography_g.pth"))
    checkpoint_g = torch.load(os.path.join(pth_path, "mlp_g.pth"))
    g_old = Homography(hidden_features=256, hidden_layers=2).to(device)
    g = Siren(in_features=3, out_features=2, hidden_features=256,
              hidden_layers=5, outermost_linear=True).to(device)
    g_old.load_state_dict(checkpoint_g_old)
    g.load_state_dict(checkpoint_g)
    g_old.eval()
    g.eval()

    transform = Compose([
        Resize(512),
        ToTensor(),
        Normalize(torch.Tensor([0.5, 0.5, 0.5]), torch.Tensor([0.5, 0.5, 0.5]))
    ])
    v = TestVideoFitting(frames_path, transform)
    videoloader = DataLoader(v, batch_size=1, pin_memory=True, num_workers=0)

    model_input, ground_truth = next(iter(videoloader))
    model_input, ground_truth = model_input[0].to(device), ground_truth[0].to(device)

    myoutput = None
    data_len = len(os.listdir(frames_path))
    with torch.no_grad():
        batch_size = v.H * v.W
        for step in range(data_len):
            start = (step * batch_size) % len(model_input)
            end = min(start + batch_size, len(model_input))

            # get the deformation for this frame's (x, y, t) coordinates
            xy, t = model_input[start:end, :-1], model_input[start:end, [-1]]
            xyt = model_input[start:end]
            h_old = apply_homography(xy, g_old(t))
            h = g(xyt)
            xy_ = h_old + h

            # use the edited canonical image to reconstruct the frame
            w, h = v.W, v.H
            canonical_img = np.array(edit_canonical.convert('RGB'))
            canonical_img = torch.from_numpy(canonical_img).float().to(device)
            h_c, w_c = canonical_img.shape[:2]
            # swap the coordinate order and rescale the deformed coordinates for grid_sample
            grid_new = xy_.clone()
            grid_new[..., 1] = xy_[..., 0] / 1.5
            grid_new[..., 0] = xy_[..., 1] / 2.0

            if len(canonical_img.shape) == 3:
                canonical_img = canonical_img.unsqueeze(0)
            results = torch.nn.functional.grid_sample(
                canonical_img.permute(0, 3, 1, 2),
                grid_new.unsqueeze(1).unsqueeze(0),
                mode='bilinear', padding_mode='border')
            o = results.squeeze().permute(1, 0)

            if step == 0:
                myoutput = o
            else:
                myoutput = torch.cat([myoutput, o])

    myoutput = myoutput.reshape(512, 512, data_len, 3).permute(2, 0, 1, 3).clone().detach().cpu().numpy().astype(np.float32)
    # myoutput = np.clip(myoutput, -1, 1) * 0.5 + 0.5
    for i in range(len(myoutput)):
        myoutput[i] = Image.fromarray(np.uint8(myoutput[i])).resize((512, 512))  # 854, 480
    edit_video_path = 'NaRCan_fps_10.mp4'
    images_to_video(myoutput, edit_video_path)
    return edit_video_path


@spaces.GPU
def edit_with_pnp(input_video, prompt, num_steps, guidance_scale, seed, n_prompt, control_type="Lineart"):
    video_name = input_video.split('/')[-1]
    if video_name in video_to_image:
        image_path = video_to_image[video_name][0]
        pth_path = video_to_image[video_name][1]
        frames_path = video_to_image[video_name][2]
    else:
        return None

    if control_type == "Lineart":
        # Load the ControlNet model for lineart
        controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_lineart", torch_dtype=torch.float16)
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
        )
        pipe.to(device)

        # extract a lineart control image from the canonical image
        processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
        processor_partial = partial(processor, coarse=False)
        size_ = 768
        canonical_image = Image.open(image_path)
        ori_size = canonical_image.size
        image = processor_partial(canonical_image.resize((size_, size_)), detect_resolution=size_, image_resolution=size_)
        image = image.resize(ori_size, resample=Image.BILINEAR)

        generator = torch.manual_seed(seed) if seed != -1 else None
        output_images = pipe(
            prompt=prompt,
            image=image,
            num_inference_steps=num_steps,
            guidance_scale=guidance_scale,
            negative_prompt=n_prompt,
            generator=generator
        ).images
        # output_images[0] = output_images[0].resize(ori_size, resample=Image.BILINEAR)
    else:
        # Load the ControlNet model for canny
        controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16)
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
        )
        pipe.to(device)

        # extract a Canny-edge control image from the canonical image
        canonical_image = cv2.imread(image_path)
        canonical_image = cv2.cvtColor(canonical_image, cv2.COLOR_BGR2RGB)
        image = cv2.cvtColor(canonical_image, cv2.COLOR_RGB2GRAY)
        image = cv2.Canny(image, 100, 200)
        image = image[:, :, None]
        image = np.concatenate([image, image, image], axis=2)
        image = Image.fromarray(image)

        generator = torch.manual_seed(seed) if seed != -1 else None
        output_images = pipe(
            prompt=prompt,
            image=image,
            num_inference_steps=num_steps,
            guidance_scale=guidance_scale,
            negative_prompt=n_prompt,
            generator=generator
        ).images

    # Warp the edited canonical image (first ControlNet output) back into every frame and render the video
    edit_video_path = NaRCan_make_video(output_images[0], pth_path, frames_path)
    return edit_video_path


########
# demo #
########
intro = """

<h1>NaRCan - Natural Refined Canonical Image</h1>
<p>[Project page], [Paper]</p>
<p>Each edit takes ~10 sec</p>
""" with gr.Blocks(css="style.css") as demo: gr.HTML(intro) frames = gr.State() inverted_latents = gr.State() latents = gr.State() zs = gr.State() do_inversion = gr.State(value=True) with gr.Row(): input_video = gr.Video(label="Input Video", interactive=False, elem_id="input_video", value='examples/bear.mp4', height=365, width=365) output_video = gr.Video(label="Edited Video", interactive=False, elem_id="output_video", height=365, width=365) # input_video.style(height=365, width=365) # output_video.style(height=365, width=365) with gr.Row(): prompt = gr.Textbox( label="Describe your edited video", max_lines=1, value="bear, Van Gogh Style" # placeholder="bear, Van Gogh Style" ) with gr.Row(): run_button = gr.Button("Edit your video!", visible=True) max_images = 12 default_num_images = 3 with gr.Accordion('Advanced options', open=False): control_type = gr.Dropdown( ["Canny", "Lineart"], label="Control Type", info="Canny or Lineart", value="Lineart" ) num_steps = gr.Slider(label='Steps', minimum=1, maximum=100, value=20, step=1) guidance_scale = gr.Slider(label='Guidance Scale', minimum=0.1, maximum=30.0, value=9.0, step=0.1) seed = gr.Slider(label='Seed', minimum=-1, maximum=2147483647, step=1, randomize=True) n_prompt = gr.Textbox( label='Negative Prompt', value="" ) input_video.change( fn = update_prompt, inputs = [input_video], outputs = [prompt], queue = False) run_button.click(fn = edit_with_pnp, inputs = [input_video, prompt, num_steps, guidance_scale, seed, n_prompt, control_type, ], outputs = [output_video] ) gr.Examples( examples=get_example(), label='Examples', inputs=[input_video], outputs=[output_video], examples_per_page=8 ) demo.queue() demo.launch()