import gradio as gr
import torch

import imageio
import imageio_ffmpeg  # noqa: F401 -- ensures imageio's ffmpeg backend is installed
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from skimage.transform import resize
import warnings
import os
from model import load_checkpoints, make_animation
from skimage import img_as_ubyte
import time

warnings.filterwarnings("ignore")

# Prefer the GPU when available; fall back to CPU otherwise.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


dataset_name = 'vox'  # one of ['vox', 'taichi', 'ted', 'mgif']
source_image_path = './assets/source.png'
driving_video_path = './assets/driving.mp4'
output_video_path = './generated.mp4'
config_path = './config/vox-256.yaml'

checkpoint_path = './checkpoints/vox.pth.tar'

predict_mode = 'relative'  # one of ['standard', 'relative', 'avd']
find_best_frame = False  # in 'relative' mode, find_best_frame=True usually gives better results when animating faces

pixel = 256  # the vox, taichi and mgif checkpoints run at 256x256
if dataset_name == 'ted':  # the ted checkpoint runs at 384x384
    pixel = 384
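
# An equivalent lookup-table form (a sketch; only the resolutions above are
# confirmed by the configs this script ships with):
#   RESOLUTIONS = {'vox': 256, 'taichi': 256, 'mgif': 256, 'ted': 384}
#   pixel = RESOLUTIONS[dataset_name]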

if find_best_frame:
    # find_best_frame depends on the optional face_alignment package
    # (pip install face_alignment); fail early with a clear message.
    try:
        import face_alignment  # noqa: F401
    except ImportError as exc:
        raise ImportError("find_best_frame=True requires: pip install face_alignment") from exc


def create_video(tt):
    """Animate the source image tagged `tt` with its driving video and write
    the result to ./assets/output_{tt}.mp4."""
    source_image = imageio.imread(f"assets/img_{tt}.jpg")
    reader = imageio.get_reader(f"assets/ref_{tt}.mp4")

    # The model expects square float frames in [0, 1] with exactly 3 channels.
    source_image = resize(source_image, (pixel, pixel))[..., :3]

    fps = reader.get_meta_data()['fps']
    driving_video = []
    try:
        for im in reader:
            driving_video.append(im)
    except RuntimeError:
        # Some truncated streams raise RuntimeError; keep the frames read so far.
        pass
    reader.close()

    driving_video = [resize(frame, (pixel, pixel))[..., :3] for frame in driving_video]

    def display(source, driving, generated=None):
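        # Build a side-by-side preview animation: source | driving | generated.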
        fig = plt.figure(figsize=(8 + 4 * (generated is not None), 6))

        ims = []
        for i in range(len(driving)):
            cols = [source]
            cols.append(driving[i])
            if generated is not None:
                cols.append(generated[i])
            im = plt.imshow(np.concatenate(cols, axis=1), animated=True)
            plt.axis('off')
            ims.append([im])

        ani = animation.ArtistAnimation(fig, ims, interval=50, repeat_delay=1000)
        plt.close()
        return ani
        

    # Offline preview (e.g. in a notebook): display(source_image, driving_video).to_html5_video()
    # Note: checkpoints are reloaded on every request; hoist this call to module
    # level if per-request latency matters.
    inpainting, kp_detector, dense_motion_network, avd_network = load_checkpoints(
        config_path=config_path, checkpoint_path=checkpoint_path, device=device)



    if predict_mode == 'relative' and find_best_frame:
        from model import find_best_frame as _find
        # Find the driving frame whose pose best matches the source image, then
        # animate forward and backward from it and stitch the two halves.
        i = _find(source_image, driving_video, device.type == 'cpu')
        print("Best frame:", i)
        driving_forward = driving_video[i:]
        driving_backward = driving_video[:(i + 1)][::-1]
        predictions_forward = make_animation(source_image, driving_forward, inpainting, kp_detector, dense_motion_network, avd_network, device=device, mode=predict_mode)
        predictions_backward = make_animation(source_image, driving_backward, inpainting, kp_detector, dense_motion_network, avd_network, device=device, mode=predict_mode)
        # Both halves contain the best frame itself; drop one copy when joining.
        predictions = predictions_backward[::-1] + predictions_forward[1:]
    else:
        predictions = make_animation(source_image, driving_video, inpainting, kp_detector, dense_motion_network, avd_network, device=device, mode=predict_mode)

    # Save the resulting video at the driving video's frame rate.
    imageio.mimsave(f"./assets/output_{tt}.mp4", [img_as_ubyte(frame) for frame in predictions], fps=fps)
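
# Standalone usage example (hypothetical tag 'demo'; assumes assets/img_demo.jpg
# and assets/ref_demo.mp4 already exist):
#   create_video("demo")   # writes ./assets/output_demo.mp4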
    
    
def greet(img, video):
    # Tag each request with a timestamp so concurrent users do not overwrite
    # one another's intermediate files.
    tt = str(time.time())
    os.replace(video, f"assets/ref_{tt}.mp4")
    img.save(f"assets/img_{tt}.jpg")
    create_video(tt)
    return f"./assets/output_{tt}.mp4"


iface = gr.Interface(
    fn=greet,
    inputs=[gr.Image(type="pil", label="Upload a photo"),
            gr.Video(label="Upload a video")],
    outputs=gr.Video(label="New video"),
)
iface.launch()
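
# When running on a remote machine, Gradio can expose a temporary public URL
# via its built-in tunnel instead:
#   iface.launch(share=True)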