# NaRCan_demo / app.py
# Author: Koi953215
# Commit 1902ec1: "update app.py"
# (Hugging Face page residue: raw · history · blame · "No virus" · 13.1 kB)
import gradio as gr
import spaces
import numpy as np
import torch
# HACK: replace torch.jit.script with an identity function. Any call to
# torch.jit.script made after this line returns its argument unchanged. That
# includes calls executed at import time by the modules imported below
# (diffusers, controlnet_aux, NaRCan_model, util, ...).
# NOTE(review): presumably done to avoid TorchScript compilation problems in
# this Space's runtime — confirm the reason. Because this is a global
# monkey-patch, it must stay above those imports.
torch.jit.script = lambda f: f
import cv2
import os
import imageio
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from controlnet_aux import LineartDetector
from functools import partial
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, ToTensor, Normalize, Resize
from NaRCan_model import Homography, Siren
from util import get_mgrid, apply_homography, jacobian, VideoFitting, TestVideoFitting
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
def get_example():
    """Return the rows for the demo's ``gr.Examples`` gallery.

    Each row is a fresh one-element list holding the path of a bundled
    example video. The order here is the order shown in the gallery.
    """
    example_videos = (
        'examples/bear.mp4',
        'examples/boat.mp4',
        'examples/woman-drink.mp4',
        'examples/corgi.mp4',
        'examples/yacht.mp4',
        'examples/koolshooters.mp4',
        'examples/overlook-the-ocean.mp4',
        'examples/rotate.mp4',
        'examples/shark-ocean.mp4',
        'examples/surf.mp4',
        'examples/cactus.mp4',
        'examples/gold-fish.mp4',
    )
    return [[video_path] for video_path in example_videos]
# Example video file name -> default edit prompt shown in the prompt box.
_DEFAULT_PROMPTS = {
    'bear.mp4': 'bear, Van Gogh Style',
    'boat.mp4': 'a burning boat sails on lava',
    'cactus.mp4': 'cactus, made of paper',
    'corgi.mp4': 'a hellhound',
    'gold-fish.mp4': 'Goldfish in the Milky Way',
    'koolshooters.mp4': 'Avatar',
    'overlook-the-ocean.mp4': 'ocean, pixel style',
    'rotate.mp4': 'turbine engine',
    'shark-ocean.mp4': 'A sleek shark, cartoon style',
    'surf.mp4': 'Sailing, The background is a large white cloud, sketch style',
    'woman-drink.mp4': 'a drinking zombie',
    'yacht.mp4': 'yacht, cyberpunk style',
}


def set_default_prompt(video_name):
    """Look up the default edit prompt for an example video.

    Args:
        video_name: bare file name of the video, e.g. ``'bear.mp4'``.

    Returns:
        The suggested prompt, or an empty string for unknown videos.
    """
    try:
        return _DEFAULT_PROMPTS[video_name]
    except KeyError:
        return ''
def update_prompt(input_video):
    """Return the default edit prompt for the video currently in the input player.

    Args:
        input_video: file path of the selected video as passed by the
            ``gr.Video`` component. Presumably ``None`` when the player is
            cleared — handled defensively.

    Returns:
        The default prompt for a known example video, otherwise ``''``.
    """
    # A cleared player yields no path; there is nothing to look up.
    if not input_video:
        return ''
    # basename() also understands OS-specific separators, unlike split('/').
    return set_default_prompt(os.path.basename(input_video))
# Example video file name -> its three NaRCan assets, in this order:
#   [canonical image that gets edited by ControlNet,
#    directory holding the NaRCan checkpoints (homography_g.pth / mlp_g.pth),
#    directory of the video's extracted frames].
_EXAMPLE_ASSETS = (
    ('bear.mp4', 'canonical/bear.png', 'pth_file/bear', 'examples_frames/bear'),
    ('boat.mp4', 'canonical/boat.png', 'pth_file/boat', 'examples_frames/boat'),
    ('cactus.mp4', 'canonical/cactus.png', 'pth_file/cactus', 'examples_frames/cactus'),
    ('corgi.mp4', 'canonical/corgi.png', 'pth_file/corgi', 'examples_frames/corgi'),
    ('gold-fish.mp4', 'canonical/gold-fish.png', 'pth_file/gold-fish', 'examples_frames/gold-fish'),
    ('koolshooters.mp4', 'canonical/koolshooters.png', 'pth_file/koolshooters', 'examples_frames/koolshooters'),
    ('overlook-the-ocean.mp4', 'canonical/overlook-the-ocean.png', 'pth_file/overlook-the-ocean', 'examples_frames/overlook-the-ocean'),
    ('rotate.mp4', 'canonical/rotate.png', 'pth_file/rotate', 'examples_frames/rotate'),
    ('shark-ocean.mp4', 'canonical/shark-ocean.png', 'pth_file/shark-ocean', 'examples_frames/shark-ocean'),
    ('surf.mp4', 'canonical/surf.png', 'pth_file/surf', 'examples_frames/surf'),
    ('woman-drink.mp4', 'canonical/woman-drink.png', 'pth_file/woman-drink', 'examples_frames/woman-drink'),
    ('yacht.mp4', 'canonical/yacht.png', 'pth_file/yacht', 'examples_frames/yacht'),
)
video_to_image = {
    video: [canonical, checkpoints, frames_dir]
    for video, canonical, checkpoints, frames_dir in _EXAMPLE_ASSETS
}
def images_to_video(image_list, output_path, fps=10, max_frames=20):
    """Encode a sequence of frames into an H.264 (libx264) video file.

    Args:
        image_list: iterable of frames. Each item is converted with
            ``np.array`` (e.g. a PIL image or an HxWx3 array) and cast to uint8.
        output_path: destination file path.
        fps: frames per second of the written video.
        max_frames: write only the first ``max_frames`` frames (20 by default,
            the previous hard-coded limit); ``None`` writes every frame.
    """
    import itertools

    # islice stops early, so frames past the limit are never converted; it
    # also accepts generators, which plain slicing would not.
    selected = itertools.islice(image_list, max_frames)
    # The context manager closes and finalises the file even if encoding fails.
    with imageio.get_writer(output_path, fps=fps, codec='libx264') as writer:
        for img in selected:
            writer.append_data(np.array(img).astype(np.uint8))
def _load_narcan_models(pth_path):
    """Load the two NaRCan deformation networks stored in ``pth_path``.

    Returns ``(g_old, g)``:
      - ``g_old`` is the ``Homography`` network restored from ``homography_g.pth``.
      - ``g`` is the ``Siren`` network restored from ``mlp_g.pth``; its output is
        added to the homography-warped coordinates.

    Both networks are on ``device`` and in eval mode.
    """
    # map_location lets checkpoints saved on CUDA load when only a CPU exists.
    checkpoint_g_old = torch.load(os.path.join(pth_path, "homography_g.pth"), map_location=device)
    checkpoint_g = torch.load(os.path.join(pth_path, "mlp_g.pth"), map_location=device)
    g_old = Homography(hidden_features=256, hidden_layers=2).to(device)
    g = Siren(in_features=3, out_features=2, hidden_features=256,
              hidden_layers=5, outermost_linear=True).to(device)
    g_old.load_state_dict(checkpoint_g_old)
    g.load_state_dict(checkpoint_g)
    g_old.eval()
    g.eval()
    return g_old, g


@spaces.GPU(duration=120)
def NaRCan_make_video(edit_canonical, pth_path, frames_path):
    """Propagate an edited canonical image to every frame and write a video.

    The pixel coordinates produced by ``TestVideoFitting`` are mapped into
    canonical space by the homography network plus the residual MLP. The
    edited canonical image is then bilinearly sampled at those coordinates.

    Args:
        edit_canonical: PIL image of the edited canonical image.
        pth_path: directory containing ``homography_g.pth`` and ``mlp_g.pth``.
        frames_path: frames directory handed to ``TestVideoFitting``; its
            number of entries is used as the number of frames to render.

    Returns:
        Path of the written video file (``NaRCan_fps_10.mp4``).

    Raises:
        ValueError: if ``frames_path`` is empty.
    """
    g_old, g = _load_narcan_models(pth_path)

    transform = Compose([
        Resize(512),
        ToTensor(),
        Normalize(torch.Tensor([0.5, 0.5, 0.5]), torch.Tensor([0.5, 0.5, 0.5])),
    ])
    v = TestVideoFitting(frames_path, transform)
    # Pinned memory only helps host-to-GPU copies.
    videoloader = DataLoader(v, batch_size=1, pin_memory=(device == 'cuda'), num_workers=0)
    # Only the coordinate grid is used; the ground-truth pixels are not needed,
    # so they are no longer copied to the device.
    model_input, _ground_truth = next(iter(videoloader))
    model_input = model_input[0].to(device)

    # NOTE(review): counts every directory entry (hidden files included) as a frame.
    data_len = len(os.listdir(frames_path))
    if data_len == 0:
        raise ValueError(f"No frames found in {frames_path!r}")

    # The edited canonical image does not change across steps, so convert it
    # once into a (1, 3, H_c, W_c) float tensor with values in [0, 255].
    canonical_img = torch.from_numpy(np.array(edit_canonical.convert('RGB'))).float().to(device)
    canonical_img = canonical_img.unsqueeze(0).permute(0, 3, 1, 2)

    chunks = []
    with torch.no_grad():
        batch_size = v.H * v.W
        for step in range(data_len):
            start = (step * batch_size) % len(model_input)
            end = min(start + batch_size, len(model_input))
            # Deformation into canonical space: homography + residual MLP.
            xy, t = model_input[start:end, :-1], model_input[start:end, [-1]]
            xyt = model_input[start:end]
            xy_ = apply_homography(xy, g_old(t)) + g(xyt)
            # Swap the two coordinate channels and rescale them for grid_sample.
            # NOTE(review): 1.5 / 2.0 presumably match the canonical extent used
            # during training — confirm against the training code.
            grid_new = xy_.clone()
            grid_new[..., 1] = xy_[..., 0] / 1.5
            grid_new[..., 0] = xy_[..., 1] / 2.0
            results = torch.nn.functional.grid_sample(
                canonical_img,
                grid_new.unsqueeze(1).unsqueeze(0),
                mode='bilinear',
                padding_mode='border')
            # results is (1, 3, 1, N). Indexing (rather than squeeze()) keeps the
            # N axis even when N == 1; the transpose gives (N, 3).
            chunks.append(results[0, :, 0].permute(1, 0))
        # A single concatenation instead of re-copying a growing tensor each step.
        myoutput = torch.cat(chunks)

    # NOTE(review): assumes 512x512 frames and a coordinate grid ordered
    # (H, W, T), both determined by util/TestVideoFitting — confirm.
    video_frames = myoutput.reshape(512, 512, data_len, 3).permute(2, 0, 1, 3).cpu().numpy().astype(np.float32)
    # Bilinear samples of a [0, 255] image stay in that range, so the uint8
    # conversion truncates without wrapping.
    frame_images = [Image.fromarray(np.uint8(frame)).resize((512, 512)) for frame in video_frames]
    edit_video_path = 'NaRCan_fps_10.mp4'
    images_to_video(frame_images, edit_video_path)
    return edit_video_path
def _load_controlnet_pipeline(controlnet_id):
    """Build a Stable Diffusion v1.5 + ControlNet pipeline and move it to ``device``.

    float16 is used on CUDA. On CPU the pipeline uses float32, because
    half-precision inference on CPU is poorly supported in PyTorch.
    """
    dtype = torch.float16 if device == 'cuda' else torch.float32
    controlnet = ControlNetModel.from_pretrained(controlnet_id, torch_dtype=dtype)
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=dtype
    )
    pipe.to(device)
    return pipe


def _lineart_condition(image_path):
    """Return the lineart control image for the canonical image at ``image_path``.

    Lineart is detected on a 768x768 resize of the image. The result is
    resized back to the canonical image's original size.
    """
    processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
    size_ = 768
    canonical_image = Image.open(image_path)
    ori_size = canonical_image.size
    image = processor(canonical_image.resize((size_, size_)), coarse=False,
                      detect_resolution=size_, image_resolution=size_)
    return image.resize(ori_size, resample=Image.BILINEAR)


def _canny_condition(image_path):
    """Return a 3-channel Canny edge map (thresholds 100/200) of ``image_path``."""
    bgr = cv2.imread(image_path)
    if bgr is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read canonical image: {image_path}")
    gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
    edges = cv2.Canny(gray, 100, 200)
    return Image.fromarray(np.stack([edges, edges, edges], axis=2))


@spaces.GPU(duration=120)
def edit_with_pnp(input_video, prompt, num_steps, guidance_scale, seed, n_prompt, control_type="Lineart"):
    """Edit an example video by editing its canonical image and re-rendering it.

    Args:
        input_video: path of the selected example video.
        prompt: text prompt for the edit.
        num_steps: number of diffusion inference steps.
        guidance_scale: classifier-free guidance scale.
        seed: RNG seed; ``-1`` leaves generation unseeded.
        n_prompt: negative prompt.
        control_type: ``"Lineart"`` for lineart conditioning; any other value
            uses Canny edges.

    Returns:
        Path of the edited video, or ``None`` if ``input_video`` is not one of
        the bundled examples in ``video_to_image``.
    """
    video_name = os.path.basename(input_video) if input_video else ''
    if video_name not in video_to_image:
        return None
    image_path, pth_path, frames_path = video_to_image[video_name]

    # NOTE(review): both pipelines are downloaded/loaded on every call; caching
    # them would help, but its interaction with `spaces.GPU` needs checking.
    if control_type == "Lineart":
        pipe = _load_controlnet_pipeline("lllyasviel/control_v11p_sd15_lineart")
        control_image = _lineart_condition(image_path)
    else:
        pipe = _load_controlnet_pipeline("lllyasviel/control_v11p_sd15_canny")
        control_image = _canny_condition(image_path)

    # A dedicated CPU generator seeded like torch.manual_seed(seed) should give
    # the same noise without reseeding the global RNG. Slider values may
    # arrive as floats, hence the int() casts.
    generator = torch.Generator().manual_seed(int(seed)) if seed != -1 else None
    output_images = pipe(
        prompt=prompt,
        image=control_image,
        num_inference_steps=int(num_steps),
        guidance_scale=guidance_scale,
        negative_prompt=n_prompt,
        generator=generator,
    ).images

    # Propagate the first edited canonical image to all frames; the result is
    # a video file path, not an image.
    return NaRCan_make_video(output_images[0], pth_path, frames_path)
########
# demo #
########
# Header HTML rendered at the top of the page by gr.HTML(intro).
# font-weight was 1400, which is outside the valid CSS range (1-1000), so
# browsers dropped the declaration. `bold` is valid and keeps the default h1
# look.
intro = """
<div style="text-align:center">
<h1 style="font-weight: bold; text-align: center; margin-bottom: 7px;">
NaRCan - <small>Natural Refined Canonical Image</small>
</h1>
<span>[<a target="_blank" href="https://koi953215.github.io/NaRCan_page/">Project page</a>], [<a target="_blank" href="https://huggingface.co/papers/2406.06523">Paper</a>]</span>
<div style="display:flex; justify-content: center;margin-top: 0.5em">Each edit takes 30 sec ~ 2 min </div>
</div>
"""
# Gradio UI. Components render in the order they are created inside this
# context, so the statement order below defines the page layout.
with gr.Blocks(css="style.css") as demo:
    gr.HTML(intro)
    # NOTE(review): none of these State holders is read or written by any
    # callback visible in this file; they look like leftovers — confirm before removing.
    frames = gr.State()
    inverted_latents = gr.State()
    latents = gr.State()
    zs = gr.State()
    do_inversion = gr.State(value=True)
    # Side-by-side players: the selected source video and the edited result.
    with gr.Row():
        input_video = gr.Video(label="Input Video", interactive=False, elem_id="input_video", value='examples/bear.mp4', height=365, width=365)
        output_video = gr.Video(label="Edited Video", interactive=False, elem_id="output_video", height=365, width=365)
        # input_video.style(height=365, width=365)
        # output_video.style(height=365, width=365)
    with gr.Row():
        # Edit prompt. The initial value equals set_default_prompt('bear.mp4'),
        # matching the initial input video; input_video.change below keeps it in sync.
        prompt = gr.Textbox(
            label="Describe your edited video",
            max_lines=1,
            value="bear, Van Gogh Style"
            # placeholder="bear, Van Gogh Style"
            )
    with gr.Row():
        run_button = gr.Button("Edit your video!", visible=True)
    # NOTE(review): max_images / default_num_images are not used in this file.
    max_images = 12
    default_num_images = 3
    with gr.Accordion('Advanced options', open=False):
        # ControlNet conditioning. Any value other than "Lineart" takes the
        # Canny branch in edit_with_pnp.
        control_type = gr.Dropdown(
            ["Canny", "Lineart"],
            label="Control Type",
            info="Canny or Lineart",
            value="Lineart"
            )
        num_steps = gr.Slider(label='Steps',
                              minimum=1,
                              maximum=100,
                              value=20,
                              step=1)
        guidance_scale = gr.Slider(label='Guidance Scale',
                                   minimum=0.1,
                                   maximum=30.0,
                                   value=9.0,
                                   step=0.1)
        # -1 means "no fixed seed": edit_with_pnp then passes generator=None.
        seed = gr.Slider(label='Seed',
                         minimum=-1,
                         maximum=2147483647,
                         step=1,
                         randomize=True)
        n_prompt = gr.Textbox(
            label='Negative Prompt',
            value=""
            )
    # Whenever the input video changes (e.g. via the Examples gallery), refill
    # the prompt box with that video's default prompt.
    input_video.change(
        fn = update_prompt,
        inputs = [input_video],
        outputs = [prompt],
        queue = False)
    # Inputs are passed positionally, so their order must follow
    # edit_with_pnp's parameter order.
    run_button.click(fn = edit_with_pnp,
                     inputs = [input_video,
                               prompt,
                               num_steps,
                               guidance_scale,
                               seed,
                               n_prompt,
                               control_type,
                               ],
                     outputs = [output_video]
                     )
    # Example gallery that loads a bundled video into input_video. No fn is
    # given, so `outputs` presumably has no effect here — confirm.
    gr.Examples(
        examples=get_example(),
        label='Examples',
        inputs=[input_video],
        outputs=[output_video],
        examples_per_page=8
        )
# Start the server only when this file is executed as a script
# (e.g. `python app.py`), so importing the module has no side effects.
if __name__ == "__main__":
    # Enable Gradio's request queue before launching, so long-running edits
    # are processed as queued jobs.
    demo.queue()
    demo.launch()