Spaces:
Runtime error
Runtime error
File size: 1,922 Bytes
ba9de54 0fc4944 ba9de54 9108e1f f1f6dd4 ba9de54 9108e1f ba9de54 9108e1f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
import os
import torch
from torchvision.io import read_video, read_video_timestamps
from videogpt import download, load_vqvae
from videogpt.data import preprocess
import imageio
import gradio as gr
from moviepy.editor import *
# Run every model call on the CPU.
device = torch.device('cpu')
# Pretrained VQ-VAE, loaded by checkpoint name through videogpt's load_vqvae.
vqvae = load_vqvae('kinetics_stride2x4x4', device=device).to(device)
# Input resolution stored in the checkpoint's saved args (passed to preprocess).
resolution = vqvae.args.resolution
# Number of leading frames of each uploaded video that get reconstructed.
sequence_length = 16
def _source_fps(path):
    """Return the frame rate of the video at *path*, as reported by moviepy."""
    # The context manager closes the reader moviepy opens for the clip;
    # without it one reader would be leaked per request.
    with VideoFileClip(path) as clip:
        return clip.fps


def _read_first_frames(path, num_frames):
    """Decode the first *num_frames* frames of the video at *path*.

    Returns the frame tensor from torchvision's ``read_video``.

    Raises:
        ValueError: If the video contains fewer than *num_frames* frames.
    """
    pts = read_video_timestamps(path, pts_unit='sec')[0]
    # Guard explicitly: indexing pts below would otherwise fail with a bare
    # IndexError on short (or empty) videos.
    if len(pts) < num_frames:
        raise ValueError(
            f"Input video has {len(pts)} frame(s); "
            f"at least {num_frames} are required."
        )
    return read_video(
        path, pts_unit='sec', start_pts=pts[0], end_pts=pts[num_frames - 1]
    )[0]


def _reconstruct(video):
    """Encode and decode *video* with the VQ-VAE; return uint8 frames (THWC)."""
    with torch.no_grad():
        encodings = vqvae.encode(video)
        video_recon = vqvae.decode(encodings)
    # Clamp to [-0.5, 0.5], then shift and scale into [0, 255] for 8-bit output.
    video_recon = torch.clamp(video_recon, -0.5, 0.5)
    frames = video_recon[0].permute(1, 2, 3, 0)  # CTHW -> THWC
    return ((frames + 0.5) * 255).cpu().numpy().astype('uint8')


def vgpt(invid):
    """Reconstruct the start of a video through the VQ-VAE and save the result.

    The first ``sequence_length`` frames of *invid* are preprocessed to
    ``resolution``, encoded to VQ-VAE codes, decoded back, and written to
    ``output.mp4`` in the working directory at the source frame rate.

    Args:
        invid: Path to the input video file (supplied by the Gradio video input).

    Returns:
        ``'./output.mp4'``, the path of the reconstructed video.

    Raises:
        ValueError: If the video has fewer than ``sequence_length`` frames.
    """
    # Remove any previous result; a missing file is fine.
    try:
        os.remove("output.mp4")
    except FileNotFoundError:
        pass

    rate = _source_fps(invid)
    video = _read_first_frames(invid, sequence_length)
    video = preprocess(video, resolution, sequence_length).unsqueeze(0).to(device)
    frames = _reconstruct(video)
    # Pass the frame rate through unmodified: truncating with int() would
    # turn e.g. 29.97 fps into 29 fps and slow down playback.
    imageio.mimwrite('output.mp4', frames, fps=rate)
    return './output.mp4'
# NOTE(review): gr.inputs / gr.outputs and Interface(enable_queue=...) are the
# legacy Gradio API (deprecated in Gradio 3, removed in Gradio 4) — confirm the
# pinned gradio version still provides them.
inputs = gr.inputs.Video(label="Input Video")
outputs = gr.outputs.Video(label="Output Video")
title = "VideoGPT"
description = "Gradio demo for VideoGPT: Video Generation using VQ-VAE and Transformers for video reconstruction. To use it, simply upload your video, or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2104.10157' target='_blank'>VideoGPT: Video Generation using VQ-VAE and Transformers</a> | <a href='https://github.com/wilson1yan/VideoGPT' target='_blank'>Github Repo</a></p>"
# One row per example, one value per input component (a single video path);
# presumably these files ship alongside the app — verify.
examples = [
    ['bear.mp4'],
    ['breakdance.mp4'],
]
gr.Interface(
    vgpt,
    inputs,
    outputs,
    title=title,
    description=description,
    article=article,
    examples=examples,
    enable_queue=True,
).launch(debug=True)