# fd_fomm / app.py
# Farazquraishi's picture
# Update app.py
# b84eea2
import gradio as gr
import subprocess
import yaml
from tqdm import tqdm
import imageio
import numpy as np
from skimage.transform import resize
from skimage import img_as_ubyte
import torch
from sync_batchnorm import DataParallelWithCallback
from modules.generator import OcclusionAwareGenerator
from modules.keypoint_detector import KPDetector
from animate import normalize_kp
def load_checkpoints(config_path, checkpoint_path, cpu=False):
    """Build the generator and keypoint detector and load their weights.

    Args:
        config_path: Path to a YAML file whose ``model_params`` mapping holds
            ``generator_params``, ``kp_detector_params`` and ``common_params``;
            these are passed as keyword arguments to the model constructors.
        checkpoint_path: Path to a checkpoint readable by ``torch.load`` that
            contains ``"generator"`` and ``"kp_detector"`` state dicts.
        cpu: When true, the checkpoint is mapped onto the CPU and the models
            stay there. Otherwise both models are moved to CUDA and wrapped in
            ``DataParallelWithCallback``.

    Returns:
        A ``(generator, kp_detector)`` tuple with both modules in eval mode.
    """
    with open(config_path) as f:
        # yaml.load() without an explicit Loader is deprecated since PyYAML 5.1
        # and raises TypeError on PyYAML >= 6. safe_load parses plain YAML
        # (scalars, lists, mappings) without constructing arbitrary objects.
        config = yaml.safe_load(f)

    model_params = config["model_params"]
    common_params = model_params["common_params"]

    generator = OcclusionAwareGenerator(**model_params["generator_params"], **common_params)
    if not cpu:
        generator.cuda()

    kp_detector = KPDetector(**model_params["kp_detector_params"], **common_params)
    if not cpu:
        kp_detector.cuda()

    # map_location=None is torch.load's default, so the non-CPU path behaves
    # exactly like a plain torch.load(checkpoint_path).
    map_location = torch.device("cpu") if cpu else None
    checkpoint = torch.load(checkpoint_path, map_location=map_location)

    # State dicts are loaded before the DataParallel wrapping below, so the
    # checkpoint keys are matched against the unwrapped modules.
    generator.load_state_dict(checkpoint["generator"])
    kp_detector.load_state_dict(checkpoint["kp_detector"])

    if not cpu:
        generator = DataParallelWithCallback(generator)
        kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()
    return generator, kp_detector
def make_animation(
    source_image, driving_video, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=False
):
    """Generate one predicted frame per driving frame.

    Args:
        source_image: 3-D NumPy array for the still image. It gets a leading
            batch axis and its last axis is moved to position 1 (presumably a
            channels-last image -- confirm against the caller).
        driving_video: Sequence of frames, each a 3-D array laid out like
            ``source_image``. May be empty.
        generator: Callable invoked as
            ``generator(source, kp_source=..., kp_driving=...)``; its result
            must expose a ``"prediction"`` tensor.
        kp_detector: Callable returning keypoints for a batch of frames.
        relative: Forwarded to ``normalize_kp`` as both
            ``use_relative_movement`` and ``use_relative_jacobian``.
        adapt_movement_scale: Forwarded to ``normalize_kp`` unchanged.
        cpu: When false, input tensors are moved to CUDA before being fed to
            the models.

    Returns:
        A list with one NumPy array per driving frame, each with the
        prediction's axis 1 moved back to the last position. An empty list is
        returned when ``driving_video`` has no frames.
    """
    # Without any frame there is no initial keypoint reference and the 5-D
    # permute below would fail. len() is used instead of truthiness because
    # driving_video may be a NumPy array.
    if len(driving_video) == 0:
        return []

    with torch.no_grad():
        predictions = []

        source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
        if not cpu:
            source = source.cuda()

        # The full clip stays on the host; frames are moved to the GPU one at
        # a time as they are used.
        driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3)

        kp_source = kp_detector(source)

        # Move the reference frame to the GPU like every frame in the loop
        # below, so the detector never receives a host tensor in CUDA mode.
        initial_frame = driving[:, :, 0]
        if not cpu:
            initial_frame = initial_frame.cuda()
        kp_driving_initial = kp_detector(initial_frame)

        for frame_idx in tqdm(range(driving.shape[2])):
            driving_frame = driving[:, :, frame_idx]
            if not cpu:
                driving_frame = driving_frame.cuda()
            kp_driving = kp_detector(driving_frame)
            kp_norm = normalize_kp(
                kp_source=kp_source,
                kp_driving=kp_driving,
                kp_driving_initial=kp_driving_initial,
                use_relative_movement=relative,
                use_relative_jacobian=relative,
                adapt_movement_scale=adapt_movement_scale,
            )
            out = generator(source, kp_source=kp_source, kp_driving=kp_norm)
            # Back to host memory; axis 1 goes last and the batch axis is dropped.
            predictions.append(np.transpose(out["prediction"].data.cpu().numpy(), [0, 2, 3, 1])[0])

    return predictions
def inference(video, image):
    """Animate ``image`` with the motion from ``video`` and return the output path.

    The driving video is trimmed with ffmpeg into ``video_input.mp4``, both
    inputs are resized to 256x256 and cut to their first three channels, and
    the frames are run through ``make_animation`` using the module-level
    ``generator`` and ``kp_detector`` on the CPU. The predicted frames are
    written to ``result.mp4`` and the resized driving frames to
    ``driving.mp4``; ffmpeg then combines ``result.mp4`` with the trimmed
    input into ``final.mp4``.

    Args:
        video: Path of the driving video file.
        image: Path of the source image file.

    Returns:
        The string ``"final.mp4"``.

    Raises:
        subprocess.CalledProcessError: If either ffmpeg invocation exits with
            a non-zero status.
    """
    # trim video to 8 seconds
    # The command is an argument list rather than a split string so that an
    # upload path containing whitespace reaches ffmpeg as a single argument.
    trim_cmd = [
        "ffmpeg", "-y", "-ss", "00:00:00", "-i", str(video),
        "-to", "00:00:08", "-c", "copy", "video_input.mp4",
    ]
    # check=True: a failed trim would otherwise leave a stale or missing
    # video_input.mp4 that the code below would read.
    subprocess.run(trim_cmd, check=True)
    video = "video_input.mp4"

    source_image = imageio.imread(image)

    reader = imageio.get_reader(video)
    try:
        fps = reader.get_meta_data()["fps"]
        driving_video = []
        try:
            for im in reader:
                driving_video.append(im)
        except RuntimeError:
            # Kept from the original: a decoding error ends the read and the
            # frames collected so far are used (presumably meant for
            # truncated streams -- confirm).
            pass
    finally:
        # Close the reader even if metadata lookup or decoding fails.
        reader.close()

    source_image = resize(source_image, (256, 256))[..., :3]
    driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]

    predictions = make_animation(
        source_image,
        driving_video,
        generator,
        kp_detector,
        relative=True,
        adapt_movement_scale=True,
        cpu=True,
    )

    imageio.mimsave("result.mp4", [img_as_ubyte(frame) for frame in predictions], fps=fps)
    # driving.mp4 is written but not returned from this function.
    imageio.mimsave("driving.mp4", [img_as_ubyte(frame) for frame in driving_video], fps=fps)

    # NOTE(review): "-map 1:0" selects stream 0 of the trimmed input, which is
    # usually its video stream; if the intent is to carry over the audio
    # track, this mapping probably needs to change -- confirm.
    mux_cmd = [
        "ffmpeg", "-y", "-i", "result.mp4", "-i", video,
        "-c", "copy", "-map", "0:0", "-map", "1:0", "-shortest", "final.mp4",
    ]
    subprocess.run(mux_cmd, check=True)
    return "final.mp4"
# Load both networks once at import time; inference() reads these two
# module-level names on every request.
generator, kp_detector = load_checkpoints(
    checkpoint_path="weights/vox-adv-cpk.pth.tar",
    config_path="config/vox-256.yaml",
    cpu=True,
)

# Strings and example inputs handed to gr.Interface below.
title = "FD_FOMM"
description = "NOT Working"
article = "<p style='text-align: center'><a href='https://papers.nips.cc/paper/2019/file/31c0b36aef265d9221af80872ceb62f9-Paper.pdf' target='_blank'>First Order Motion Model for Image Animation</a> | <a href='https://github.com/AliaksandrSiarohin/first-order-model' target='_blank'>Github Repo</a></p>"
# One example row, in the same order as the inputs: video first, then image.
examples = [["vid.mp4", "julien.png"]]

iface = gr.Interface(
    fn=inference,
    inputs=[gr.inputs.Video(type="mp4"), gr.inputs.Image(type="filepath")],
    outputs=gr.outputs.Video(label="Output Video"),
    title=title,
    description=description,
    article=article,
    examples=examples,
    enable_queue=True,
)

# Start the app as soon as the module is executed.
iface.launch(debug=True)