File size: 4,782 Bytes
7062b81
 
 
 
981d2df
7062b81
 
 
 
 
2bbc3ee
 
 
 
 
506e597
7062b81
 
 
 
 
 
 
 
77c21de
 
7062b81
 
 
 
2bbc3ee
4fa279e
 
 
 
 
 
b5a9347
2bbc3ee
 
 
 
 
7062b81
 
2bbc3ee
 
 
 
 
 
 
 
 
 
 
7062b81
 
2bbc3ee
7062b81
2bbc3ee
 
7062b81
2bbc3ee
 
7062b81
2bbc3ee
7062b81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import functools
import glob
import os

import cv2
import torch
import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download

from networks.amts import Model as AMTS
from networks.amtl import Model as AMTL
from networks.amtg import Model as AMTG
from utils import (
    img2tensor, tensor2img, 
    InputPadder, 
    check_dim_and_resize
)
# Use CUDA when PyTorch reports it is available; otherwise run everything on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# UI model name -> network class; the name also selects the checkpoint file to download.
model_dict = {
    'AMT-S': AMTS,
    'AMT-L': AMTL,
    'AMT-G': AMTG,
}

@functools.lru_cache(maxsize=1)
def _load_model(model_type):
    """Build ``model_type``, load its Hub checkpoint, and return it on ``device`` in eval mode.

    Only the most recently used model is cached, so repeated runs with the same
    model skip re-reading the checkpoint while at most one model stays resident.
    """
    model = model_dict[model_type]()
    model.to(device)
    # NOTE: torch.load unpickles the file; it comes from a fixed Hub repo, not from user input.
    ckpt_path = hf_hub_download(repo_id='lalala125/AMT', filename=f'{model_type.lower()}.pth')
    ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
    model.load_state_dict(ckpt['state_dict'])
    model.eval()
    return model


def _compute_scale(h, w):
    """Return the ``scale_factor`` passed to the model for an ``h`` x ``w`` input.

    On CUDA the factor shrinks as the input area grows relative to the GPU's
    total memory. On CPU it stays 1 (no resizing) unless the input exceeds
    8192*8192 pixels.
    """
    # torch.device never compares equal to a plain string, so check .type instead.
    if device.type == 'cuda':
        anchor_resolution = 1024 * 512
        anchor_memory = 1500 * 1024**2
        anchor_memory_bias = 2500 * 1024**2
        vram_avail = torch.cuda.get_device_properties(device).total_memory
    else:
        # Do not resize in cpu mode
        anchor_resolution = 8192 * 8192
        anchor_memory = 1
        anchor_memory_bias = 0
        vram_avail = 1
    # NOTE(review): if total VRAM is below anchor_memory_bias, the sqrt argument is
    # negative and scale becomes NaN; confirm the minimum supported GPU size.
    scale = anchor_resolution / (h * w) * np.sqrt((vram_avail - anchor_memory_bias) / anchor_memory)
    scale = 1 if scale > 1 else scale
    # Convert the area ratio to a per-side ratio and snap it so that 16 / scale is
    # an integer (up to float error); the caller uses that as the padding multiple.
    return 1 / np.floor(1 / np.sqrt(scale) * 16) * 16


def _write_video(frames, fps, out_path='results'):
    """Encode ``frames`` (tensors accepted by ``tensor2img``) as an mp4 and return its path."""
    # cv2.VideoWriter fails silently when the target directory does not exist.
    os.makedirs(out_path, exist_ok=True)
    video_path = f'{out_path}/demo.mp4'
    h, w = frames[0].shape[-2:]
    writer = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (int(w), int(h)))
    if not writer.isOpened():
        raise RuntimeError(f'Could not open video writer for {video_path}')
    try:
        for frame in frames:
            # tensor2img output is written as RGB; OpenCV expects BGR.
            writer.write(cv2.cvtColor(tensor2img(frame), cv2.COLOR_RGB2BGR))
    finally:
        writer.release()
    return video_path


def img2vid(model_type, img0, img1, frame_ratio, iters):
    """Interpolate frames between two images and write them as an mp4 video.

    Args:
        model_type: key of ``model_dict`` ('AMT-S', 'AMT-L' or 'AMT-G').
        img0, img1: the two input images, in whatever form ``img2tensor`` accepts
            (presumably HxWx3 RGB arrays from ``gr.Image`` -- TODO confirm).
        frame_ratio: frames per second of the output video.
        iters: number of interpolation passes. Each pass inserts one predicted
            frame between every adjacent pair, so N frames become 2N - 1.

    Returns:
        Path of the written video ('results/demo.mp4').

    Raises:
        ValueError: if either input image is missing.
    """
    if img0 is None or img1 is None:
        raise ValueError('Please provide both input images.')

    model = _load_model(model_type)
    inputs = check_dim_and_resize([img2tensor(img0).to(device), img2tensor(img1).to(device)])

    h, w = inputs[0].shape[-2:]
    scale = _compute_scale(h, w)
    if scale < 1:
        print(f"Due to the limited VRAM, the video will be scaled by {scale:.2f}")
    padding = int(16 / scale)
    padder = InputPadder(inputs[0].shape, padding)
    inputs = padder.pad(*inputs)

    # Time step 0.5: predict the frame midway between each input pair.
    embt = torch.tensor(1/2).float().view(1, 1, 1, 1).to(device)

    # Defined up front so iters == 0 yields the padded inputs instead of an unbound name.
    outputs = inputs
    with torch.no_grad():
        for i in range(int(iters)):
            print(f'Iter {i+1}. input_frames={len(inputs)} output_frames={2*len(inputs)-1}')
            outputs = [inputs[0]]
            for in_0, in_1 in zip(inputs[:-1], inputs[1:]):
                in_0 = in_0.to(device)
                in_1 = in_1.to(device)
                imgt_pred = model(in_0, in_1, embt, scale_factor=scale, eval=True)['imgt_pred']
                # Park results on the CPU so GPU memory does not grow with the frame count.
                outputs += [imgt_pred.cpu(), in_1.cpu()]
            inputs = outputs

    outputs = padder.unpad(*outputs)
    return _write_video(outputs, frame_ratio)

    
def demo_img():
    """Build the "Image Demo" Gradio UI.

    Two image inputs and an "Advanced options" accordion (interpolation passes,
    output frame rate, model choice) are wired to ``img2vid``. The generated
    video is shown in the output pane.

    Returns:
        The constructed ``gr.Blocks`` app (not launched).
    """
    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown('## Image Demo')
        with gr.Row():
            gr.HTML(
                """
                <div style="text-align: left;">
                <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
                    Description: With 2 input images, you can generate a short video from them.
                </h2>
                </div>
                """)

        with gr.Row():
            with gr.Column():
                img0 = gr.Image(label='Image0')
                img1 = gr.Image(label='Image1')
            with gr.Column():
                result = gr.Video(label="Generated Video")
                with gr.Accordion('Advanced options', open=False):
                    # Passed to img2vid as `iters` (number of interpolation passes).
                    ratio = gr.Slider(label='Multiple Ratio',
                                      minimum=4,
                                      maximum=7,
                                      value=6,
                                      step=1)
                    # Passed to img2vid as the output video's frames per second.
                    frame_ratio = gr.Slider(label='Frame Ratio',
                                            minimum=8,
                                            maximum=64,
                                            value=16,
                                            step=1)
                    model_type = gr.Radio(['AMT-S', 'AMT-L', 'AMT-G'],
                                          label='Model Select',
                                          value='AMT-S')
                # A Button's caption is its `value`; it has no `label` parameter.
                run_button = gr.Button(value='Run')
        # Order must match img2vid(model_type, img0, img1, frame_ratio, iters).
        inputs = [
            model_type,
            img0,
            img1,
            frame_ratio,
            ratio,
        ]

        # Sorted so the example gallery order does not depend on the filesystem.
        gr.Examples(examples=sorted(glob.glob("examples/*.png")),
                    inputs=img0,
                    label='Example images (drag them to input windows)',
                    run_on_click=False,
                    )

        run_button.click(fn=img2vid,
                         inputs=inputs,
                         outputs=result)
    return demo