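"""Gradio demo for COVER video quality assessment.

Upload a video and receive semantic, technical, aesthetic, and fused overall quality scores.
"""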
import gradio as gr
import torch
import argparse
import pickle as pkl
import decord
from decord import VideoReader
import numpy as np
import yaml
from cover.datasets import UnifiedFrameSampler, spatial_temporal_view_decomposition
from cover.models import COVER
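
# Per-channel RGB mean/std on the 0-255 scale: ImageNet statistics for the
# technical/aesthetic views, CLIP preprocessing statistics for the semantic view.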
mean, std = (
    torch.FloatTensor([123.675, 116.28, 103.53]),
    torch.FloatTensor([58.395, 57.12, 57.375]),
)
mean_clip, std_clip = (
    torch.FloatTensor([122.77, 116.75, 104.09]),
    torch.FloatTensor([68.50, 66.63, 70.32]),
)
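
# Fuse the three branch scores; the overall score is simply their sum.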
def fuse_results(results: list):
    x = results[0] + results[1] + results[2]
    return {
        "semantic":  results[0],
        "technical": results[1],
        "aesthetic": results[2],
        "overall":   x,
    }
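
# Score a single uploaded video with COVER and return the fused results as a dict.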
def inference_one_video(input_video):
    """
    BASIC SETTINGS
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    with open("./cover.yml", "r") as f:
        opt = yaml.safe_load(f)
    dopt = opt["data"]["val-ytugc"]["args"]
    temporal_samplers = {}
    for stype, sopt in dopt["sample_types"].items():
        temporal_samplers[stype] = UnifiedFrameSampler(
            sopt["clip_len"] // sopt["t_frag"],
            sopt["t_frag"],
            sopt["frame_interval"],
            sopt["num_clips"],
        )
"""
LOAD MODEL
"""
evaluator = COVER(**opt["model"]["args"]).to(device)
state_dict = torch.load(opt["test_load_path"], map_location=device)
# set strict=False here to avoid error of missing
# weight of prompt_learner in clip-iqa+, cross-gate
evaluator.load_state_dict(state_dict['state_dict'], strict=False)
"""
TESTING
"""
views, _ = spatial_temporal_view_decomposition(
input_video, dopt["sample_types"], temporal_samplers
)
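    # Normalize each view with its branch-specific statistics, split the frame stack into
    # clips, reorder to (num_clips, C, frames_per_clip, H, W), and move it to the device.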
    for k, v in views.items():
        num_clips = dopt["sample_types"][k].get("num_clips", 1)
        if k == "technical" or k == "aesthetic":
            views[k] = (
                ((v.permute(1, 2, 3, 0) - mean) / std)
                .permute(3, 0, 1, 2)
                .reshape(v.shape[0], num_clips, -1, *v.shape[2:])
                .transpose(0, 1)
                .to(device)
            )
        elif k == "semantic":
            views[k] = (
                ((v.permute(1, 2, 3, 0) - mean_clip) / std_clip)
                .permute(3, 0, 1, 2)
                .reshape(v.shape[0], num_clips, -1, *v.shape[2:])
                .transpose(0, 1)
                .to(device)
            )
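    # The evaluator yields one score tensor per branch (semantic, technical, aesthetic);
    # reduce each to a single scalar by averaging before fusing.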
    results = [r.mean().item() for r in evaluator(views)]
    pred_score = fuse_results(results)
    return pred_score

# Define the input and output types for Gradio using the new API
video_input = gr.Video(label="Input Video")
output_label = gr.JSON(label="Scores")

# Create the Gradio interface
gradio_app = gr.Interface(fn=inference_one_video, inputs=video_input, outputs=output_label)

if __name__ == "__main__":
    gradio_app.launch()