Spaces:
Build error
Build error
import gradio as gr | |
import subprocess | |
import yaml | |
from tqdm import tqdm | |
import imageio | |
import numpy as np | |
from skimage.transform import resize | |
from skimage import img_as_ubyte | |
import torch | |
from sync_batchnorm import DataParallelWithCallback | |
from modules.generator import OcclusionAwareGenerator | |
from modules.keypoint_detector import KPDetector | |
from animate import normalize_kp | |
def load_checkpoints(config_path, checkpoint_path, cpu=False):
    """Build the generator and keypoint detector and load pretrained weights.

    Args:
        config_path: Path to a YAML config whose ``model_params`` section holds
            ``generator_params``, ``kp_detector_params`` and ``common_params``
            (``common_params`` is passed to both constructors).
        checkpoint_path: Path to a checkpoint readable by ``torch.load`` that
            contains ``"generator"`` and ``"kp_detector"`` state dicts.
        cpu: If True, checkpoint tensors are mapped to the CPU and the models
            stay there; otherwise both models are moved to CUDA and wrapped in
            ``DataParallelWithCallback``.

    Returns:
        A ``(generator, kp_detector)`` tuple, both switched to eval mode.
    """
    with open(config_path, encoding="utf-8") as f:
        # yaml.load() without an explicit Loader is deprecated (PyYAML 5.1) and a
        # TypeError since PyYAML 6.0; the config is plain data, so safe_load suffices.
        config = yaml.safe_load(f)

    model_params = config["model_params"]
    common_params = model_params["common_params"]
    generator = OcclusionAwareGenerator(**model_params["generator_params"], **common_params)
    kp_detector = KPDetector(**model_params["kp_detector_params"], **common_params)
    if not cpu:
        generator.cuda()
        kp_detector.cuda()

    # map_location=None is torch.load's default; on the CPU path, tensors saved
    # on a GPU are remapped so the checkpoint loads on a CPU-only host.
    map_location = torch.device("cpu") if cpu else None
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    generator.load_state_dict(checkpoint["generator"])
    kp_detector.load_state_dict(checkpoint["kp_detector"])

    if not cpu:
        generator = DataParallelWithCallback(generator)
        kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()
    return generator, kp_detector
def make_animation(
    source_image, driving_video, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=False
):
    """Animate ``source_image`` with the motion of ``driving_video``.

    Args:
        source_image: Channels-last image array; it gains a batch axis and is
            permuted ``(0, 3, 1, 2)`` before being fed to the networks.
        driving_video: Non-empty sequence of channels-last frames; stacked and
            permuted ``(0, 4, 1, 2, 3)`` so frames are indexed on axis 2.
        generator: Generator network (as returned by ``load_checkpoints``).
        kp_detector: Keypoint detector network (as returned by ``load_checkpoints``).
        relative: Forwarded to ``normalize_kp`` as both ``use_relative_movement``
            and ``use_relative_jacobian``.
        adapt_movement_scale: Forwarded to ``normalize_kp``.
        cpu: If False, the source image and each driving frame are moved to CUDA.

    Returns:
        A list with one channels-last numpy array per driving frame.

    Raises:
        ValueError: If ``driving_video`` contains no frames.
    """
    # Fail early with a clear message; otherwise indexing frame 0 below fails
    # with an opaque shape error (e.g. when no frame could be decoded upstream).
    if len(driving_video) == 0:
        raise ValueError("driving_video must contain at least one frame")

    with torch.no_grad():
        predictions = []
        source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
        driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3)
        initial_frame = driving[:, :, 0]
        if not cpu:
            # Move inputs explicitly, like the per-frame handling below, instead of
            # relying on the DataParallel wrapper to scatter CPU tensors.
            source = source.cuda()
            initial_frame = initial_frame.cuda()

        kp_source = kp_detector(source)
        # Keypoints of the first driving frame, passed to normalize_kp as the
        # initial reference for every later frame.
        kp_driving_initial = kp_detector(initial_frame)

        for frame_idx in tqdm(range(driving.shape[2])):
            driving_frame = driving[:, :, frame_idx]
            if not cpu:
                driving_frame = driving_frame.cuda()
            kp_driving = kp_detector(driving_frame)
            kp_norm = normalize_kp(
                kp_source=kp_source,
                kp_driving=kp_driving,
                kp_driving_initial=kp_driving_initial,
                use_relative_movement=relative,
                use_relative_jacobian=relative,
                adapt_movement_scale=adapt_movement_scale,
            )
            out = generator(source, kp_source=kp_source, kp_driving=kp_norm)
            # Back to channels-last numpy and drop the batch axis.
            predictions.append(np.transpose(out["prediction"].data.cpu().numpy(), [0, 2, 3, 1])[0])
    return predictions
def _to_rgb_256(frame):
    """Resize an image to 256x256 and return exactly three channels.

    Grayscale input (2-D, or a single channel optionally followed by alpha) has
    its first channel replicated three times; any channel past the third (e.g.
    alpha on RGBA) is dropped.
    """
    resized = resize(frame, (256, 256))
    if resized.ndim == 2:
        resized = resized[..., np.newaxis]
    if resized.shape[-1] < 3:
        resized = np.repeat(resized[..., :1], 3, axis=-1)
    return resized[..., :3]


def inference(video, image):
    """Gradio handler: animate the source ``image`` with the motion in ``video``.

    Args:
        video: Path to the driving video.
        image: Path to the source image.

    Returns:
        The path ``"out.mp4"``: the rendered animation (first written to
        ``result.mp4``) with the driving clip's first audio stream muxed in
        when it has one.

    Raises:
        subprocess.CalledProcessError: If ffmpeg fails to produce ``out.mp4``.
    """
    source_image = _to_rgb_256(imageio.imread(image))

    driving_video = []
    with imageio.get_reader(video) as reader:
        fps = reader.get_meta_data()["fps"]
        try:
            for im in reader:
                driving_video.append(im)
        except RuntimeError:
            # Best-effort decode: keep the frames read before the reader raised
            # rather than failing the whole request.
            pass
    driving_video = [_to_rgb_256(frame) for frame in driving_video]

    predictions = make_animation(
        source_image,
        driving_video,
        generator,
        kp_detector,
        relative=True,
        adapt_movement_scale=True,
        cpu=True,
    )
    # NOTE(review): fixed file names in the working directory are shared by all
    # requests; this is only safe while requests are processed one at a time.
    imageio.mimsave("result.mp4", [img_as_ubyte(frame) for frame in predictions], fps=fps)

    # Argument list (no str.split) so paths containing spaces survive intact.
    cmd = [
        "ffmpeg", "-y",
        "-i", "result.mp4",
        "-i", video,
        "-c", "copy",
        "-map", "0:v:0",   # video from the rendered animation
        "-map", "1:a:0?",  # audio from the driving clip; '?' = skip if it has none
        "-shortest",
        "out.mp4",
    ]
    # check=True: never hand back a stale or missing out.mp4 after an ffmpeg failure.
    subprocess.run(cmd, check=True)
    return "out.mp4"
# Text shown on the Gradio page.
title = "First Order Motion Model"
description = "Gradio demo for First Order Motion Model. Read more at the links below."
# Centered footer linking to the 2019 paper PDF and the GitHub repository.
article = (
    "<p style='text-align: center'>"
    "<a href='https://papers.nips.cc/paper/2019/file/31c0b36aef265d9221af80872ceb62f9-Paper.pdf' target='_blank'>"
    "First Order Motion Model for Image Animation</a>"
    " | "
    "<a href='https://github.com/AliaksandrSiarohin/first-order-model' target='_blank'>"
    "Github Repo</a>"
    "</p>"
)
# One example row per entry: [driving video, source image], in the same order
# as the Interface inputs.
examples = [
    ["bella_porch.mp4", "julien.png"],
]
# Build both networks once at import time; inference() reuses these globals for
# every request. cpu=True keeps them on the CPU (no CUDA move, no DataParallel wrap).
generator, kp_detector = load_checkpoints(
    "config/vox-256.yaml",
    "weights/vox-adv-cpk.pth.tar",
    cpu=True,
)
# Inputs are [driving video, source image], matching inference(video, image).
iface = gr.Interface(
    inference,
    [
        gr.inputs.Video(type="mp4"),
        gr.inputs.Image(type="filepath"),
    ],
    outputs=gr.outputs.Video(label="Output Video"),
    examples=examples,
    enable_queue=True,
    title=title,
    article=article,
    description=description,
)
# server_name is a launch() option (it was only a deprecated Interface kwarg);
# "0.0.0.0" makes the server listen on all network interfaces.
iface.launch(debug=True, server_name="0.0.0.0")