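"""Gradio demo for COVER (Comprehensive Video Quality Evaluator).

Scores an input video along COVER's semantic, technical, and aesthetic
branches, fuses them into an overall score, and renders all four as a
bar chart with percentile comparisons against YouTube-UGC.
"""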
import gradio as gr
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import pandas as pd
import torch
import yaml
from decord import VideoReader

from cover.datasets import UnifiedFrameSampler, spatial_temporal_view_decomposition
from cover.models import COVER

# ImageNet mean/std (RGB, 0-255 scale) used by the technical and aesthetic branches
mean, std = (
    torch.FloatTensor([123.675, 116.28, 103.53]),
    torch.FloatTensor([58.395, 57.12, 57.375]),
)

# CLIP mean/std (RGB, 0-255 scale) used by the semantic branch
mean_clip, std_clip = (
    torch.FloatTensor([122.77, 116.75, 104.09]),
    torch.FloatTensor([68.50, 66.63, 70.32]),
)

# Sample roughly one frame per `sample_interval` source frames
sample_interval = 30

# Raw-score [min, max] per branch, used to rescale onto a 0-5 display range
normalization_array = {
    "semantic":  [-0.1477, -0.0181],
    "technical": [-1.8762, 1.2428],
    "aesthetic": [-1.2899, 0.5290],
    "overall":   [-3.2538, 1.6728],
}

# Filled at inference time with per-branch score distributions over YouTube-UGC
comparison_array = {
    "semantic":  [],
    "technical": [],
    "aesthetic": [],
    "overall":   [],
}

def get_sampler_params(video_path):
    # Derive sampling parameters from the video length: one sampled frame per
    # `sample_interval` source frames, rounded to the nearest integer, min 1.
    vr = VideoReader(video_path)
    total_frames = len(vr)
    clip_len = (total_frames + sample_interval // 2) // sample_interval
    if clip_len == 0:
        clip_len = 1
    t_frag = clip_len
    return total_frames, clip_len, t_frag
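
# Example: a 300-frame video with sample_interval = 30 gives
# clip_len = (300 + 15) // 30 = 10 and t_frag = 10, so each sampler built
# below draws clip_len // t_frag frames from each of t_frag temporal fragments.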

def fuse_results(results: list):
    # COVER returns one score per branch; the overall score is their sum
    overall = results[0] + results[1] + results[2]
    return {
        "semantic":  results[0],
        "technical": results[1],
        "aesthetic": results[2],
        "overall":   overall,
    }
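
# Example: fuse_results([-0.05, 0.30, 0.10]) returns
# {"semantic": -0.05, "technical": 0.30, "aesthetic": 0.10, "overall": 0.35}.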

def normalize_score(score, min_score, max_score):
    # Linearly rescale a raw branch score onto a 0-5 display range
    return (score - min_score) / (max_score - min_score) * 5
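
# Example: normalize_score(0.0, -1.0, 1.0) == 2.5. Raw scores outside
# [min_score, max_score] map outside 0-5; they are not clipped.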

def compare_score(score, score_list):
    # Percentile of `score` within the reference YT-UGC score distribution
    better_than = sum(1 for s in score_list if score > s)
    percentage = better_than / len(score_list) * 100
    if percentage > 50:
        return f"Better than {percentage:.0f}% of videos in YT-UGC"
    return f"Worse than {100 - percentage:.0f}% of videos in YT-UGC"

def create_bar_chart(scores, comparisons):
    # `scores` and `comparisons` arrive in the order produced by
    # inference_one_video: semantic, technical, aesthetic, overall.
    labels = ['Semantic', 'Technical', 'Aesthetic', 'Overall']
    base_colors = ['#d62728', '#ff7f0e', '#1f77b4', '#bcbd22']

    fig, ax = plt.subplots(figsize=(10, 5))

    for i, (label, score, comparison, base_color) in enumerate(
        zip(labels, scores, comparisons, base_colors)
    ):
        # Translucent full-width band as the backdrop for this row
        gradient = patches.Rectangle((0, i), 5, 1, color=base_color, alpha=0.5)
        ax.add_patch(gradient)

        # Black vertical marker at the achieved score
        ax.plot([score, score], [i, i + 0.9], color='black', linewidth=2)

        ax.text(score + 0.1, i + 0.5, f'{score:.1f}', va='center', ha='left', color=base_color)
        ax.text(5.1, i + 0.5, comparison, va='center', ha='left', color=base_color)

    ax.set_yticks(range(len(labels)))
    ax.set_yticklabels(labels)
    for tick, color in zip(ax.get_yticklabels(), base_colors):
        tick.set_color(color)
    ax.set_xticks([0, 1, 2, 3, 4, 5])
    ax.set_xticklabels([0, 1, 2, 3, 4, 5])
    ax.set_xlim(0, 5)
    ax.set_xlabel('Score')

    plt.tight_layout()
    image_path = "./bar_chart.png"
    plt.savefig(image_path)
    plt.close()

    return image_path
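
# Example (illustrative values): create_bar_chart([2.5, 3.1, 1.8, 4.0],
# ["Better than 62% of videos in YT-UGC", ...]) writes ./bar_chart.png
# and returns its path for Gradio to display.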

def inference_one_video(input_video):
    """
    BASIC SETTINGS
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    with open("./cover.yml", "r") as f:
        opt = yaml.safe_load(f)

    dopt = opt["data"]["val-ytugc"]["args"]
    temporal_samplers = {}

    total_frames, clip_len, t_frag = get_sampler_params(input_video)

    for stype, sopt in dopt["sample_types"].items():
        sopt["clip_len"] = clip_len
        sopt["t_frag"] = t_frag
        # The technical and aesthetic branches sample twice as many frames
        if stype in ('technical', 'aesthetic') and total_frames > 1:
            sopt["clip_len"] = clip_len * 2
        if stype == 'technical':
            sopt["aligned"] = sopt["clip_len"]
        temporal_samplers[stype] = UnifiedFrameSampler(
            sopt["clip_len"] // sopt["t_frag"],
            sopt["t_frag"],
            sopt["frame_interval"],
            sopt["num_clips"],
        )

    """
    LOAD MODEL
    """
    evaluator = COVER(**opt["model"]["args"]).to(device)
    state_dict = torch.load(opt["test_load_path"], map_location=device)
    evaluator.load_state_dict(state_dict['state_dict'], strict=False)

    """
    TESTING
    """
    # Decompose the video into per-branch spatial/temporal views
    views, _ = spatial_temporal_view_decomposition(
        input_video, dopt["sample_types"], temporal_samplers
    )

    for k, v in views.items():
        # Normalize in (T, H, W, C) layout, restore (C, T, H, W), then split
        # into clips: (num_clips, C, T // num_clips, H, W)
        num_clips = dopt["sample_types"][k].get("num_clips", 1)
        if k in ('technical', 'aesthetic'):
            views[k] = (
                ((v.permute(1, 2, 3, 0) - mean) / std)
                .permute(3, 0, 1, 2)
                .reshape(v.shape[0], num_clips, -1, *v.shape[2:])
                .transpose(0, 1)
                .to(device)
            )
        elif k == 'semantic':
            views[k] = (
                ((v.permute(1, 2, 3, 0) - mean_clip) / std_clip)
                .permute(3, 0, 1, 2)
                .reshape(v.shape[0], num_clips, -1, *v.shape[2:])
                .transpose(0, 1)
                .to(device)
            )

    # Mean score per branch, in order: semantic, technical, aesthetic
    results = [r.mean().item() for r in evaluator(views)]
    pred_score = fuse_results(results)

    # Rescale each raw branch score onto the 0-5 display range
    normalized_scores = [
        normalize_score(pred_score[k], *normalization_array[k])
        for k in ("semantic", "technical", "aesthetic", "overall")
    ]

    # Per-branch COVER score distributions over YouTube-UGC, from bundled CSVs
    comparison_array["semantic"] = pd.read_csv('./prediction_results/youtube_ugc/smos.csv')['Mos']
    comparison_array["technical"] = pd.read_csv('./prediction_results/youtube_ugc/tmos.csv')['Mos']
    comparison_array["aesthetic"] = pd.read_csv('./prediction_results/youtube_ugc/amos.csv')['Mos']
    comparison_array["overall"] = pd.read_csv('./prediction_results/youtube_ugc/overall.csv')['Mos']

    comparisons = [
        compare_score(pred_score[k], comparison_array[k])
        for k in ("semantic", "technical", "aesthetic", "overall")
    ]

    image_path = create_bar_chart(normalized_scores, comparisons)
    return image_path


video_input = gr.Video(label="Input Video")
output_image = gr.Image(label="Scores")

gradio_app = gr.Interface(fn=inference_one_video, inputs=video_input, outputs=output_image)

if __name__ == "__main__":
    gradio_app.launch()
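
# launch() serves the demo on http://127.0.0.1:7860 by default; pass
# share=True (a standard Gradio option) for a temporary public link.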