import os
import cv2
import glob
import torch
import numpy as np
import gradio as gr
from huggingface_hub import hf_hub_download

from networks.amts import Model as AMTS
from networks.amtl import Model as AMTL
from networks.amtg import Model as AMTG
from utils import img2tensor, tensor2img, InputPadder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_dict = {
    'AMT-S': AMTS, 'AMT-L': AMTL, 'AMT-G': AMTG
}


def vid2vid(model_type, video, iters):
    iters = int(iters)  # slider values can arrive as floats; range() needs an int
    model = model_dict[model_type]()
    model.to(device)
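    # Fetch the pretrained checkpoint for the chosen variant from the
    # Hugging Face Hub (cached locally after the first download).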
    ckpt_path = hf_hub_download(repo_id='lalala125/AMT', filename=f'{model_type.lower()}.pth')
    ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
    model.load_state_dict(ckpt['state_dict'])
    model.eval()
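
    # Open the input clip and record its native frame rate and size.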
    vcap = cv2.VideoCapture(video)
    ori_frame_rate = vcap.get(cv2.CAP_PROP_FPS)
    inputs = []
    h = int(vcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    w = int(vcap.get(cv2.CAP_PROP_FRAME_WIDTH))
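
    # VRAM heuristic (a reading of the anchor constants below): a 1024x512
    # clip is assumed to need ~1.5 GB of working memory plus ~2.5 GB of fixed
    # overhead, and the input is downscaled when total VRAM cannot cover it.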
    if device.type == 'cuda':
        anchor_resolution = 1024 * 512
        anchor_memory = 1500 * 1024**2
        anchor_memory_bias = 2500 * 1024**2
        vram_avail = torch.cuda.get_device_properties(device).total_memory
    else:
        # CPU fallback: anchors chosen so the computed scale always comes out 1.
        anchor_resolution = 8192 * 8192
        anchor_memory = 1
        anchor_memory_bias = 0
        vram_avail = 1
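
    # Turn the memory budget into a spatial scale factor, then snap it so the
    # padded, downscaled frames keep side lengths divisible by 16.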
    scale = anchor_resolution / (h * w) * np.sqrt((vram_avail - anchor_memory_bias) / anchor_memory)
    scale = 1 if scale > 1 else scale
    scale = 1 / np.floor(1 / np.sqrt(scale) * 16) * 16
    if scale < 1:
        print(f"Due to the limited VRAM, the video will be scaled by {scale:.2f}")
    padding = int(16 / scale)
    padder = InputPadder((h, w), padding)
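
    # Decode every frame, convert BGR -> RGB, and pad it to the size the
    # padder requires.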
    while True:
        ret, frame = vcap.read()
        if not ret:
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame_t = img2tensor(frame).to(device)
        frame_t = padder.pad(frame_t)
        inputs.append(frame_t)
    vcap.release()
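
    # embt = 0.5 requests the temporal midpoint between each pair of frames.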
    embt = torch.tensor(1/2).float().view(1, 1, 1, 1).to(device)
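
    # Each pass inserts a midpoint frame between every adjacent pair, turning n
    # frames into 2n - 1, so `iters` passes raise the frame rate by 2**iters
    # (e.g. iters=3 gives 8x). Frames stay padded across passes and are only
    # unpadded once, at write time.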
    for i in range(iters):
        print(f'Iter {i+1}. input_frames={len(inputs)} output_frames={2*len(inputs)-1}')
        outputs = [inputs[0]]
        for in_0, in_1 in zip(inputs[:-1], inputs[1:]):
            with torch.no_grad():
                imgt_pred = model(in_0, in_1, embt, scale_factor=scale, eval=True)['imgt_pred']
            outputs += [imgt_pred, in_1]
        inputs = outputs
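
    # Write the result at 2**iters x the original frame rate, unpadding each
    # frame back to the source resolution first.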
    out_path = 'results'
    os.makedirs(out_path, exist_ok=True)
    writer = cv2.VideoWriter(f'{out_path}/demo_vfi.mp4',
                             cv2.VideoWriter_fourcc(*'mp4v'),
                             ori_frame_rate * 2 ** iters, (w, h))
    for imgt_pred in outputs:
        imgt_pred = tensor2img(padder.unpad(imgt_pred))
        imgt_pred = cv2.cvtColor(imgt_pred, cv2.COLOR_RGB2BGR)
        writer.write(imgt_pred)
    writer.release()
    return f'{out_path}/demo_vfi.mp4'


def demo_vid():
    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown('## Video Demo')
        with gr.Row():
            gr.HTML(
                """
                <div style="text-align: left; margin: auto;">
                    <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
                    Description: Increase the frame rate of a video by 2x, 4x, or 8x. (The input video should be shorter than 10 seconds.)
                    </h2>
                </div>
                """)
        with gr.Row():
            with gr.Column():
                video = gr.Video(label='Video Input')
            with gr.Column():
                result = gr.Video(label="Generated Video")
        with gr.Accordion('Advanced options', open=False):
            # `ratio` is forwarded to vid2vid as `iters`, the number of 2x
            # passes: 1 -> 2x, 2 -> 4x, 3 -> 8x.
            ratio = gr.Slider(label='Multiple Ratio',
                              minimum=1,
                              maximum=3,
                              value=2,
                              step=1)
            model_type = gr.Radio(['AMT-S', 'AMT-L', 'AMT-G'],
                                  label='Model Select',
                                  value='AMT-S')
        run_button = gr.Button('Run')
        inputs = [
            model_type,
            video,
            ratio,
        ]

        gr.Examples(examples=glob.glob("examples/*.mp4"),
                    inputs=video,
                    label='Example videos (drag them to the input window)',
                    run_on_click=False)

        run_button.click(fn=vid2vid,
                         inputs=inputs,
                         outputs=result)
    return demo