|
import cv2 |
|
import glob |
|
import torch |
|
import gradio as gr |
|
from huggingface_hub import hf_hub_download |
|
|
|
from networks.amts import Model as AMTS |
|
from networks.amtl import Model as AMTL |
|
from networks.amtg import Model as AMTG |
|
from utils import img2tensor, tensor2img, InputPadder |
|
|
|
# Run inference on a CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Maps the model names offered in the UI to their network classes.
model_dict = dict(
    zip(('AMT-S', 'AMT-L', 'AMT-G'), (AMTS, AMTL, AMTG))
)
|
|
|
# Loaded models keyed by model name, so each checkpoint is downloaded and
# deserialized once per process instead of on every request.
_model_cache = {}


def _load_model(model_type):
    """Return the model for ``model_type`` with its pretrained weights, in eval mode.

    On first use the network from ``model_dict`` is built, moved to ``device``
    and loaded with ``{model_type.lower()}.pth`` from the ``lalala125/AMT``
    Hugging Face repo; later calls return the cached instance.

    Raises:
        KeyError: If ``model_type`` is not a key of ``model_dict``.
    """
    model = _model_cache.get(model_type)
    if model is None:
        model = model_dict[model_type]()
        model.to(device)
        ckpt_path = hf_hub_download(repo_id='lalala125/AMT',
                                    filename=f'{model_type.lower()}.pth')
        # NOTE(review): torch.load unpickles the file; acceptable only because
        # the checkpoint comes from a fixed, trusted repository.
        ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
        model.load_state_dict(ckpt['state_dict'])
        model.eval()
        _model_cache[model_type] = model
    return model


def img2vid(model_type, img0, img1, frame_ratio, iters):
    """Interpolate between two images and write all frames to an MP4 video.

    Each round inserts a predicted middle frame (t = 0.5) between every pair
    of neighbouring frames, so ``iters`` rounds turn 2 frames into
    ``2 ** iters + 1`` frames.

    Args:
        model_type: Key of ``model_dict`` ('AMT-S', 'AMT-L' or 'AMT-G').
        img0: First image, in whatever form ``img2tensor`` accepts
            (presumably an RGB numpy array from ``gr.Image`` -- confirm).
        img1: Second image, same form and size as ``img0``.
        frame_ratio: Frames per second of the output video.
        iters: Number of frame-doubling rounds; coerced to ``int`` because
            UI sliders may deliver floats. ``0`` writes just the two inputs.

    Returns:
        Path of the written video, ``'results/demo.mp4'``.

    Raises:
        RuntimeError: If OpenCV cannot open the output video for writing.
    """
    import os

    model = _load_model(model_type)

    img0_t = img2tensor(img0).to(device)
    img1_t = img2tensor(img1).to(device)
    # Pad both inputs so their spatial dims are multiples of 16.
    padder = InputPadder(img0_t.shape, 16)
    img0_t, img1_t = padder.pad(img0_t, img1_t)

    # Time embedding 0.5: predict the frame halfway between each pair.
    embt = torch.tensor(1/2).float().view(1, 1, 1, 1).to(device)

    # All frames stay padded during the rounds so every model call sees inputs
    # of the same shape; padding is removed once, after the last round.
    inputs = [img0_t, img1_t]
    for i in range(int(iters)):
        print(f'Iter {i+1}. input_frames={len(inputs)} output_frames={2*len(inputs)-1}')
        outputs = [img0_t]
        for in_0, in_1 in zip(inputs[:-1], inputs[1:]):
            with torch.no_grad():
                imgt_pred = model(in_0, in_1, embt, eval=True)['imgt_pred']
            outputs += [imgt_pred, in_1]
        inputs = outputs

    frames = [padder.unpad(frame) for frame in inputs]

    out_path = 'results'
    # cv2.VideoWriter fails silently when the target directory is missing.
    os.makedirs(out_path, exist_ok=True)
    video_path = f'{out_path}/demo.mp4'
    height, width = frames[0].shape[2:]
    writer = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'),
                             frame_ratio, (int(width), int(height)))
    if not writer.isOpened():
        raise RuntimeError(f'Could not open video writer for {video_path}')
    try:
        for frame in frames:
            frame = tensor2img(frame)
            writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        writer.release()

    return video_path
|
|
|
|
|
def demo_img():
    """Build the Gradio "Image Demo" interface around ``img2vid``.

    The left column holds the two input images. The right column shows the
    generated video, the advanced options (doubling rounds, output fps,
    model choice) and the Run button. Example images from
    ``examples/*.png`` can be loaded into the first image input.

    Returns:
        The constructed ``gr.Blocks`` instance; it is not launched here.
    """
    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown('## Image Demo')
        with gr.Row():
            # The heading closes with the matching </h2> tag, and the div style
            # contains only valid CSS declarations.
            gr.HTML(
                """
                <div style="text-align: left;">
                <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
                Description: With 2 input images, you can generate a short video from them.
                </h2>
                </div>
                """)

        with gr.Row():
            with gr.Column():
                img0 = gr.Image(label='Image0')
                img1 = gr.Image(label='Image1')
            with gr.Column():
                result = gr.Video(label="Generated Video")
                with gr.Accordion('Advanced options', open=False):
                    # Passed to img2vid as ``iters`` (number of frame-doubling rounds).
                    ratio = gr.Slider(label='Multiple Ratio',
                                      minimum=4,
                                      maximum=7,
                                      value=6,
                                      step=1)
                    # Passed to img2vid as the output video's frames per second.
                    frame_ratio = gr.Slider(label='Frame Ratio',
                                            minimum=8,
                                            maximum=64,
                                            value=16,
                                            step=1)
                    model_type = gr.Radio(['AMT-S', 'AMT-L', 'AMT-G'],
                                          label='Model Select',
                                          value='AMT-S')
                # A button's caption is its ``value``, not ``label``.
                run_button = gr.Button('Run')

        # Order must match img2vid(model_type, img0, img1, frame_ratio, iters).
        inputs = [
            model_type,
            img0,
            img1,
            frame_ratio,
            ratio,
        ]

        # Sorted so the examples appear in a stable order across runs.
        gr.Examples(examples=sorted(glob.glob("examples/*.png")),
                    inputs=img0,
                    label='Example images (drag them to input windows)',
                    run_on_click=False,
                    )

        run_button.click(fn=img2vid,
                         inputs=inputs,
                         outputs=result,)

    return demo