AMT / demo_img.py
lalala125's picture
Update demo_img.py
506e597
raw
history blame
3.88 kB
import cv2
import glob
import torch
import gradio as gr
from huggingface_hub import hf_hub_download
from networks.amts import Model as AMTS
from networks.amtl import Model as AMTL
from networks.amtg import Model as AMTG
from utils import img2tensor, tensor2img, InputPadder
# Inference device: the default CUDA device when PyTorch can see one, else CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Maps the model names offered in the UI to the AMT network classes.
# The lower-cased key is also used as the checkpoint filename on the Hub.
model_dict = {
    'AMT-S': AMTS,
    'AMT-L': AMTL,
    'AMT-G': AMTG,
}
# Networks already built and loaded, keyed by model type, so each checkpoint
# is downloaded/deserialized at most once per process.
_MODEL_CACHE = {}


def _load_model(model_type):
    """Return the eval-mode network for ``model_type`` with weights loaded.

    Raises:
        KeyError: if ``model_type`` is not a key of ``model_dict``.
    """
    model = _MODEL_CACHE.get(model_type)
    if model is None:
        model = model_dict[model_type]()
        model.to(device)
        ckpt_path = hf_hub_download(repo_id='lalala125/AMT',
                                    filename=f'{model_type.lower()}.pth')
        # NOTE(review): torch.load unpickles the file; acceptable only because
        # the checkpoint comes from the fixed repo above, never from user input.
        ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
        model.load_state_dict(ckpt['state_dict'])
        model.eval()
        _MODEL_CACHE[model_type] = model
    return model


def _interpolate(model, img0_t, img1_t, iters):
    """Run ``iters`` frame-doubling passes between two *padded* frames.

    Each pass inserts the model's midpoint prediction (``embt`` = 0.5) between
    every adjacent pair, so 2 frames become ``2**iters + 1``. Frames stay padded
    for the whole recursion so every model call sees padded inputs, and they
    are parked on the CPU between passes; only the current pair is moved to
    ``device``.
    """
    embt = torch.tensor(1 / 2).float().view(1, 1, 1, 1).to(device)
    frames = [img0_t.cpu(), img1_t.cpu()]
    for i in range(iters):
        print(f'Iter {i+1}. input_frames={len(frames)} output_frames={2*len(frames)-1}')
        outputs = [frames[0]]
        for in_0, in_1 in zip(frames[:-1], frames[1:]):
            with torch.no_grad():
                imgt_pred = model(in_0.to(device), in_1.to(device), embt,
                                  eval=True)['imgt_pred']
            outputs += [imgt_pred.cpu(), in_1]
        frames = outputs
    return frames


def _write_video(frames, padder, fps, out_path):
    """Unpad ``frames`` and encode them as an MP4 (``mp4v``) at ``out_path``.

    Raises:
        RuntimeError: if OpenCV cannot open a writer for ``out_path``.
    """
    first = padder.unpad(frames[0])
    height, width = first.shape[-2:]
    writer = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*'mp4v'), fps,
                             (int(width), int(height)))
    # VideoWriter does not raise on failure; it silently writes nothing.
    if not writer.isOpened():
        raise RuntimeError(f'Could not open video writer for {out_path}')
    try:
        for frame in frames:
            rgb = tensor2img(padder.unpad(frame))
            writer.write(cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))
    finally:
        writer.release()
    return out_path


def img2vid(model_type, img0, img1, frame_ratio, iters):
    """Interpolate between two images and render the result as an MP4 video.

    Args:
        model_type: key of ``model_dict`` ('AMT-S', 'AMT-L' or 'AMT-G').
        img0, img1: images from the ``gr.Image`` inputs, converted with
            ``utils.img2tensor`` (presumably HxWxC RGB arrays — TODO confirm).
        frame_ratio: frames per second of the output video.
        iters: number of frame-doubling passes; the video has
            ``2**iters + 1`` frames.

    Returns:
        The path of the written video, ``'results/demo.mp4'``.

    Raises:
        gr.Error: if an image is missing or the two images differ in size.
    """
    import os

    if img0 is None or img1 is None:
        raise gr.Error('Please provide both Image0 and Image1.')

    model = _load_model(model_type)
    img0_t = img2tensor(img0)
    img1_t = img2tensor(img1)
    if img0_t.shape != img1_t.shape:
        raise gr.Error('Image0 and Image1 must have the same resolution.')

    # Pad both inputs so their spatial size is divisible by 16; the padding is
    # removed from every frame only when the video is written.
    padder = InputPadder(img0_t.shape, 16)
    img0_t, img1_t = padder.pad(img0_t, img1_t)

    # Slider values may arrive as floats; range() needs an int.
    frames = _interpolate(model, img0_t, img1_t, int(iters))

    out_dir = 'results'
    os.makedirs(out_dir, exist_ok=True)
    return _write_video(frames, padder, frame_ratio, f'{out_dir}/demo.mp4')
def demo_img():
    """Build the Gradio "Image Demo" interface.

    Two uploaded images plus the advanced options are forwarded to
    ``img2vid`` when Run is clicked, and the returned video path is shown in
    the "Generated Video" player.

    Returns:
        The ``gr.Blocks`` instance containing the layout.
    """
    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown('## Image Demo')
        with gr.Row():
            gr.HTML(
                """
                <div style="text-align: left;">
                <h3 style="font-weight: 450; font-size: 1rem; margin: 0rem">
                    Description: With 2 input images, you can generate a short video from them.
                </h3>
                </div>
                """)
        with gr.Row():
            with gr.Column():
                img0 = gr.Image(label='Image0')
                img1 = gr.Image(label='Image1')
            with gr.Column():
                result = gr.Video(label="Generated Video")
                with gr.Accordion('Advanced options', open=False):
                    # Number of frame-doubling passes (img2vid's `iters`):
                    # the video gets 2**ratio + 1 frames.
                    ratio = gr.Slider(label='Multiple Ratio',
                                      minimum=4,
                                      maximum=7,
                                      value=6,
                                      step=1)
                    # Frames per second of the written video.
                    frame_ratio = gr.Slider(label='Frame Ratio',
                                            minimum=8,
                                            maximum=64,
                                            value=16,
                                            step=1)
                    model_type = gr.Radio(['AMT-S', 'AMT-L', 'AMT-G'],
                                          label='Model Select',
                                          value='AMT-S')
                # The button text is its `value`; `label` is not a Button
                # argument and is rejected by newer Gradio versions.
                run_button = gr.Button('Run')

        # Order must match img2vid(model_type, img0, img1, frame_ratio, iters).
        inputs = [
            model_type,
            img0,
            img1,
            frame_ratio,
            ratio,
        ]
        # Sorted for a stable gallery order; skipped when no examples exist.
        example_paths = sorted(glob.glob("examples/*.png"))
        if example_paths:
            gr.Examples(examples=example_paths,
                        inputs=img0,
                        label='Example images (drag them to input windows)',
                        run_on_click=False,
                        )
        run_button.click(fn=img2vid,
                         inputs=inputs,
                         outputs=result,)
    return demo