# AMT / demo_vid.py
# Source: Hugging Face repository file view; author zzl, commit 1947ad8
# ("release"), 4.73 kB.
import cv2
import glob
import torch
import numpy as np
import gradio as gr
from huggingface_hub import hf_hub_download
from networks.amts import Model as AMTS
from networks.amtl import Model as AMTL
from networks.amtg import Model as AMTG
from utils import img2tensor, tensor2img, InputPadder
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Model name shown in the UI -> network class implementing that AMT variant.
model_dict = dict(zip(
    ('AMT-S', 'AMT-L', 'AMT-G'),
    (AMTS, AMTL, AMTG),
))
def _load_model(model_type):
    """Instantiate the AMT variant *model_type* with its pretrained weights.

    The checkpoint ``<model_type lower-cased>.pth`` is downloaded from the
    ``lalala125/AMT`` Hugging Face repo and its ``'state_dict'`` entry is
    loaded. The model is moved to ``device`` and switched to eval mode.
    """
    model = model_dict[model_type]()
    model.to(device)
    ckpt_path = hf_hub_download(repo_id='lalala125/AMT', filename=f'{model_type.lower()}.pth')
    # Deserialize on the CPU; load_state_dict copies values into the model's
    # parameters wherever they live.
    ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
    model.load_state_dict(ckpt['state_dict'])
    model.eval()
    return model


def _inference_scale(h, w):
    """Return the ``scale_factor`` passed to the model for ``h`` x ``w`` frames.

    The pixel-count ratio ``anchor_resolution / (h * w)`` is weighted by the
    square root of the device-memory ratio and capped at 1. It is then turned
    into a per-side factor with a square root and snapped to ``16 / k`` for an
    integer ``k >= 16``, so the result is at most 1. On the CPU the anchors
    make the result exactly 1 for frames up to 8192 x 8192 pixels, i.e.
    frames are not downscaled.
    """
    if device.type == 'cuda':
        # Memory anchors; their derivation is not shown here (presumably
        # empirical measurements). NOTE(review): confirm before tuning.
        anchor_resolution = 1024 * 512
        anchor_memory = 1500 * 1024**2
        anchor_memory_bias = 2500 * 1024**2
        vram_avail = torch.cuda.get_device_properties(device).total_memory
    else:
        # Do not resize in cpu mode
        anchor_resolution = 8192 * 8192
        anchor_memory = 1
        anchor_memory_bias = 0
        vram_avail = 1
    scale = anchor_resolution / (h * w) * np.sqrt((vram_avail - anchor_memory_bias) / anchor_memory)
    scale = 1 if scale > 1 else scale
    # Equivalent to 16 / floor(16 / sqrt(scale)).
    return 1 / np.floor(1 / np.sqrt(scale) * 16) * 16


def vid2vid(model_type, video, iters):
    """Interpolate *video* to ``2 ** iters`` times its frame rate with AMT.

    Each pass inserts one predicted frame (``embt = 1/2``, presumably the
    temporal midpoint) between every pair of neighbouring frames, so N frames
    become 2N - 1. ``iters`` passes are run. The result is encoded with the
    ``mp4v`` FourCC at the source FPS times ``2 ** iters``.

    Args:
        model_type: Key of ``model_dict`` ('AMT-S', 'AMT-L' or 'AMT-G').
        video: Path of the input video, opened with ``cv2.VideoCapture``.
        iters: Number of 2x passes. Coerced with ``int()`` because UI sliders
            may deliver floats.

    Returns:
        Path of the written video, ``'results/demo_vfi.mp4'``.

    Raises:
        ValueError: If no frame can be decoded from *video*.
        RuntimeError: If the output video cannot be opened for writing.
    """
    import os

    iters = int(iters)
    model = _load_model(model_type)

    # Decode every frame as RGB; release the capture even if decoding fails.
    vcap = cv2.VideoCapture(video)
    try:
        ori_frame_rate = vcap.get(cv2.CAP_PROP_FPS)
        rgb_frames = []
        while True:
            ret, frame = vcap.read()
            if not ret:
                break
            rgb_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        vcap.release()
    if not rgb_frames:
        raise ValueError(f'Could not read any frames from video: {video!r}')

    # A decoded image is (rows, cols, channels), i.e. (height, width, 3).
    h, w = rgb_frames[0].shape[:2]
    scale = _inference_scale(h, w)
    if scale < 1:
        print(f"Due to the limited VRAM, the video will be scaled by {scale:.2f}")
    padding = int(16 / scale)
    # NOTE(review): assumes InputPadder takes (height, width) and pads to a
    # multiple of `padding`; confirm in utils.
    padder = InputPadder((h, w), padding)
    # Frames stay padded and on the CPU for the whole run. Each pair is moved
    # to `device` only while it is being interpolated.
    inputs = [padder.pad(img2tensor(f)) for f in rgb_frames]
    del rgb_frames

    embt = torch.tensor(1/2).float().view(1, 1, 1, 1).to(device)
    with torch.no_grad():
        for i in range(iters):
            print(f'Iter {i+1}. input_frames={len(inputs)} output_frames={2*len(inputs)-1}')
            outputs = [inputs[0]]
            for in_0, in_1 in zip(inputs[:-1], inputs[1:]):
                imgt_pred = model(in_0.to(device), in_1.to(device), embt,
                                  scale_factor=scale, eval=True)['imgt_pred']
                # Keep predictions padded. The next pass feeds these frames
                # back into the model, and every frame is unpadded exactly
                # once when the video is written.
                outputs += [imgt_pred.cpu(), in_1]
            inputs = outputs

    out_path = 'results'
    # cv2.VideoWriter does not create missing directories; it just fails.
    os.makedirs(out_path, exist_ok=True)
    out_file = f'{out_path}/demo_vfi.mp4'
    # Assumes frame tensors are (1, C, H, W); VideoWriter wants (width, height).
    first = padder.unpad(inputs[0])
    size = (int(first.shape[-1]), int(first.shape[-2]))
    writer = cv2.VideoWriter(out_file,
                             cv2.VideoWriter_fourcc(*'mp4v'),
                             ori_frame_rate * 2 ** iters, size)
    try:
        if not writer.isOpened():
            raise RuntimeError(f'Could not open {out_file!r} for writing')
        for frame_t in inputs:
            img = tensor2img(padder.unpad(frame_t))
            writer.write(cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
    finally:
        writer.release()
    return out_file
def demo_vid():
    """Build the Gradio UI for the video frame-interpolation demo.

    Layout:
        - A title and a short description.
        - Input and output video players.
        - An "Advanced options" accordion with the number of 2x passes and
          the model variant.
        - A Run button wired to ``vid2vid``.
        - Example clips taken from ``examples/*.mp4``.

    Returns:
        The ``gr.Blocks`` app. The caller is responsible for launching it.
    """
    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown('## Video Demo')
        with gr.Row():
            gr.HTML(
                """
                <div style="text-align: left;">
                <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
                Description: You can increase the frame rate of the video by 2, 4, 8, or 16 times (a 'Multiple Ratio' of N gives 2^N times the frame rate). (The video should be less than 10 seconds.)
                </h2>
                </div>
                """)
        with gr.Row():
            with gr.Column():
                video = gr.Video(label='Video Input')
            with gr.Column():
                result = gr.Video(label="Generated Video")
        with gr.Accordion('Advanced options', open=False):
            # This is the number of 2x interpolation passes, not the ratio
            # itself: vid2vid multiplies the frame rate by 2 ** value.
            ratio = gr.Slider(label='Multiple Ratio',
                              minimum=1,
                              maximum=4,
                              value=2,
                              step=1)
            model_type = gr.Radio(['AMT-S', 'AMT-L', 'AMT-G'],
                                  label='Model Select',
                                  value='AMT-S')
        # The first positional argument (`value`) is the button text;
        # newer Gradio releases reject a `label` keyword on Button.
        run_button = gr.Button('Run')
        # Order must match vid2vid(model_type, video, iters).
        inputs = [
            model_type,
            video,
            ratio,
        ]
        gr.Examples(examples=sorted(glob.glob("examples/*.mp4")),
                    inputs=video,
                    label='Example videos (drag them to the input window)',
                    run_on_click=False,
                    )
        run_button.click(fn=vid2vid,
                         inputs=inputs,
                         outputs=result,)
    return demo