import glob
import os

import cv2
import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download

from networks.amts import Model as AMTS
from networks.amtl import Model as AMTL
from networks.amtg import Model as AMTG
from utils import (
    img2tensor, tensor2img, InputPadder,
    check_dim_and_resize
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Maps the model name shown in the UI to the network class to instantiate.
model_dict = {
    'AMT-S': AMTS, 'AMT-L': AMTL, 'AMT-G': AMTG
}


def img2vid(model_type, img0, img1, frame_ratio, iters):
    """Interpolate between two images and write the result as a video.

    Args:
        model_type: Key of ``model_dict`` ('AMT-S', 'AMT-L' or 'AMT-G').
            Its lower-cased form is also used as the checkpoint file name.
        img0: First input image, as accepted by ``img2tensor``
            (presumably the array produced by ``gr.Image`` -- TODO confirm).
        img1: Second input image, same format as ``img0``.
        frame_ratio: Frame rate handed to ``cv2.VideoWriter``.
        iters: Number of interpolation passes. Each pass inserts one
            predicted frame between every pair of neighbouring frames, so
            N frames become 2 * N - 1.

    Returns:
        The path of the written video, ``'results/demo.mp4'``.

    Raises:
        gr.Error: If either input image is missing.
    """
    if img0 is None or img1 is None:
        raise gr.Error('Please provide both input images.')

    # Slider values may arrive as floats; range() below needs an int.
    iters = int(iters)

    model = model_dict[model_type]()
    model.to(device)
    ckpt_path = hf_hub_download(repo_id='lalala125/AMT',
                                filename=f'{model_type.lower()}.pth')
    ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
    model.load_state_dict(ckpt['state_dict'])
    model.eval()

    img0_t = img2tensor(img0).to(device)
    img1_t = img2tensor(img1).to(device)
    inputs = [img0_t, img1_t]

    # NOTE: ``device`` is a torch.device, so it must be compared through
    # ``.type``; comparing it with the string 'cuda' is always False and
    # would silently disable the VRAM-aware scaling below.
    if device.type == 'cuda':
        # Reference figures used to estimate how much the frames have to
        # be downscaled to fit in the available GPU memory.
        anchor_resolution = 1024 * 512
        anchor_memory = 1500 * 1024**2
        anchor_memory_bias = 2500 * 1024**2
        vram_avail = torch.cuda.get_device_properties(device).total_memory
    else:
        # Do not resize in cpu mode
        anchor_resolution = 8192*8192
        anchor_memory = 1
        anchor_memory_bias = 0
        vram_avail = 1

    # Time embedding fixed at 1/2 for every prediction.
    embt = torch.tensor(1/2).float().view(1, 1, 1, 1).to(device)

    inputs = check_dim_and_resize(inputs)
    h, w = inputs[0].shape[-2:]
    scale = anchor_resolution / (h * w) * np.sqrt(
        (vram_avail - anchor_memory_bias) / anchor_memory)
    scale = 1 if scale > 1 else scale
    scale = 1 / np.floor(1 / np.sqrt(scale) * 16) * 16
    if scale < 1:
        print(f"Due to the limited VRAM, the video will be scaled by {scale:.2f}")
    padding = int(16 / scale)
    padder = InputPadder(inputs[0].shape, padding)
    inputs = padder.pad(*inputs)

    for i in range(iters):
        print(f'Iter {i+1}. input_frames={len(inputs)} output_frames={2*len(inputs)-1}')
        outputs = [inputs[0]]
        for in_0, in_1 in zip(inputs[:-1], inputs[1:]):
            in_0 = in_0.to(device)
            in_1 = in_1.to(device)
            with torch.no_grad():
                imgt_pred = model(in_0, in_1, embt,
                                  scale_factor=scale, eval=True)['imgt_pred']
            # Keep accumulated frames on the CPU between passes.
            outputs += [imgt_pred.cpu(), in_1.cpu()]
        inputs = outputs

    # Unpad ``inputs`` (== the last pass's outputs) so that iters == 0
    # still yields the two original frames instead of a NameError.
    outputs = padder.unpad(*inputs)

    out_path = 'results'
    # cv2.VideoWriter fails silently when the target directory is missing.
    os.makedirs(out_path, exist_ok=True)
    video_path = f'{out_path}/demo.mp4'

    # Last two tensor dims are (h, w); VideoWriter expects (w, h).
    size = outputs[0].shape[2:][::-1]
    writer = cv2.VideoWriter(video_path,
                             cv2.VideoWriter_fourcc(*'mp4v'),
                             frame_ratio, size)
    try:
        for imgt_pred in outputs:
            imgt_pred = tensor2img(imgt_pred)
            imgt_pred = cv2.cvtColor(imgt_pred, cv2.COLOR_RGB2BGR)
            writer.write(imgt_pred)
    finally:
        # Always finalize the file, even if a frame conversion fails.
        writer.release()
    return video_path


def demo_img():
    """Build the Gradio Blocks UI for the two-image interpolation demo.

    Returns:
        The ``gr.Blocks`` instance. The caller is responsible for
        launching it.
    """
    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown('## Image Demo')
        with gr.Row():
            gr.HTML(
                """

Description: With 2 input images, you can generate a short video from them.

""")

        with gr.Row():
            with gr.Column():
                img0 = gr.Image(label='Image0')
                img1 = gr.Image(label='Image1')
            with gr.Column():
                result = gr.Video(label="Generated Video")
                with gr.Accordion('Advanced options', open=False):
                    # Passed to img2vid as ``iters``.
                    ratio = gr.Slider(label='Multiple Ratio',
                                      minimum=4, maximum=7, value=6, step=1)
                    frame_ratio = gr.Slider(label='Frame Ratio',
                                            minimum=8, maximum=64, value=16,
                                            step=1)
                    model_type = gr.Radio(['AMT-S', 'AMT-L', 'AMT-G'],
                                          label='Model Select',
                                          value='AMT-S')
                # The button caption is its value, not a label.
                run_button = gr.Button('Run')
        # Order must match the positional parameters of img2vid.
        inputs = [
            model_type,
            img0,
            img1,
            frame_ratio,
            ratio,
        ]

        gr.Examples(examples=glob.glob("examples/*.png"),
                    inputs=img0,
                    label='Example images (drag them to input windows)',
                    run_on_click=False,
                    )

        run_button.click(fn=img2vid,
                         inputs=inputs,
                         outputs=result,)
    return demo