import gradio as gr

import torch

import argparse
import pickle as pkl

import decord
from decord import VideoReader
import numpy as np
import yaml

from cover.datasets import UnifiedFrameSampler, spatial_temporal_view_decomposition
from cover.models import COVER
|
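# Standard ImageNet normalization values scaled to the 0-255 range, applied to the
# technical and aesthetic views.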
mean, std = (
    torch.FloatTensor([123.675, 116.28, 103.53]),
    torch.FloatTensor([58.395, 57.12, 57.375]),
)

# CLIP normalization values scaled to the 0-255 range, applied to the semantic view.
mean_clip, std_clip = (
    torch.FloatTensor([122.77, 116.75, 104.09]),
    torch.FloatTensor([68.50, 66.63, 70.32]),
)
|
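# Controls the derived clip length: roughly one sampled frame for every
# `sample_interval` frames of the input video.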
sample_interval = 30
|
|
def get_sampler_params(video_path):
    """Derive frame-sampler parameters (clip length and fragment count) from the video length."""
    vr = VideoReader(video_path)
    total_frames = len(vr)
    # Round total_frames / sample_interval to the nearest integer, with a minimum of 1.
    clip_len = (total_frames + sample_interval // 2) // sample_interval
    if clip_len == 0:
        clip_len = 1
    t_frag = clip_len

    return total_frames, clip_len, t_frag
|
|
def fuse_results(results: list):
    """Collect the per-branch scores; the overall score is their sum."""
    x = results[0] + results[1] + results[2]
    return {
        "semantic":  results[0],
        "technical": results[1],
        "aesthetic": results[2],
        "overall":   x,
    }
|
|
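# Entry point called by Gradio for each uploaded video: returns the semantic,
# technical, aesthetic, and overall COVER scores.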
def inference_one_video(input_video):
    """
    BASIC SETTINGS
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    with open("./cover.yml", "r") as f:
        opt = yaml.safe_load(f)

    dopt = opt["data"]["val-ytugc"]["args"]
    temporal_samplers = {}

    total_frames, clip_len, t_frag = get_sampler_params(input_video)

    # Configure one temporal sampler per view type (semantic / technical / aesthetic).
    for stype, sopt in dopt["sample_types"].items():
        sopt["clip_len"] = clip_len
        sopt["t_frag"] = t_frag
        if stype == 'technical' or stype == 'aesthetic':
            if total_frames > 1:
                sopt["clip_len"] = clip_len * 2
            if stype == 'technical':
                sopt["aligned"] = sopt["clip_len"]
        print(sopt["clip_len"], sopt["t_frag"])
        temporal_samplers[stype] = UnifiedFrameSampler(
            sopt["clip_len"] // sopt["t_frag"],
            sopt["t_frag"],
            sopt["frame_interval"],
            sopt["num_clips"],
        )

    """
    LOAD MODEL
    """
    evaluator = COVER(**opt["model"]["args"]).to(device)
    state_dict = torch.load(opt["test_load_path"], map_location=device)
    evaluator.load_state_dict(state_dict['state_dict'], strict=False)

    """
    TESTING
    """
    views, _ = spatial_temporal_view_decomposition(
        input_video, dopt["sample_types"], temporal_samplers
    )

    # Normalize each view and reshape it to (num_clips, C, T_clip, H, W) on the target device.
    for k, v in views.items():
        num_clips = dopt["sample_types"][k].get("num_clips", 1)
        if k == 'technical' or k == 'aesthetic':
            views[k] = (
                ((v.permute(1, 2, 3, 0) - mean) / std)
                .permute(3, 0, 1, 2)
                .reshape(v.shape[0], num_clips, -1, *v.shape[2:])
                .transpose(0, 1)
                .to(device)
            )
        elif k == 'semantic':
            views[k] = (
                ((v.permute(1, 2, 3, 0) - mean_clip) / std_clip)
                .permute(3, 0, 1, 2)
                .reshape(v.shape[0], num_clips, -1, *v.shape[2:])
                .transpose(0, 1)
                .to(device)
            )

    results = [r.mean().item() for r in evaluator(views)]
    pred_score = fuse_results(results)
    return pred_score
|
|
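# Gradio UI: a video upload widget as input and a JSON view of the scores as output.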
video_input = gr.Video(label="Input Video")
output_label = gr.JSON(label="Scores")

gradio_app = gr.Interface(fn=inference_one_video, inputs=video_input, outputs=output_label)

if __name__ == "__main__":
    gradio_app.launch()