Spaces:
vztu
/
Runtime error

COVER / app.py
nanushio
- [MINOR] [SOURCE] [UPDATE] 1. update app.py
a135217
raw
history blame
6.98 kB
import gradio as gr
import torch
import argparse
import pickle as pkl
import decord
from decord import VideoReader
import numpy as np
import yaml
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from cover.datasets import UnifiedFrameSampler, spatial_temporal_view_decomposition
from cover.models import COVER
import pandas as pd
# Per-channel pixel mean / std in 0-255 units, used to normalise the
# "technical" and "aesthetic" views. The values equal the common ImageNet
# statistics (0.485/0.456/0.406 and 0.229/0.224/0.225) multiplied by 255
# (presumably RGB channel order, matching that convention -- confirm).
mean = torch.FloatTensor([123.675, 116.28, 103.53])
std = torch.FloatTensor([58.395, 57.12, 57.375])

# Per-channel pixel mean / std in 0-255 units for the "semantic" view. The
# values equal the standard CLIP preprocessing statistics multiplied by 255.
mean_clip = torch.FloatTensor([122.77, 116.75, 104.09])
std_clip = torch.FloatTensor([68.50, 66.63, 70.32])

# get_sampler_params divides a video's frame count by this value (rounding to
# the nearest integer) to choose the sampler clip length.
sample_interval = 30

# Raw model-score [min, max] per dimension; normalize_score maps this range
# linearly onto 0-5 for display.
normalization_array = dict(
    semantic=[-0.1477, -0.0181],
    technical=[-1.8762, 1.2428],
    aesthetic=[-1.2899, 0.5290],
    overall=[-3.2538, 1.6728],
)

# Reference score collections for the percentile captions. These are empty
# placeholders; inference_one_video replaces each entry with the "Mos" column
# of the matching YT-UGC prediction CSV on every call.
comparison_array = {
    key: [] for key in ("semantic", "technical", "aesthetic", "overall")
}
def get_sampler_params(video_path):
    """Choose temporal-sampler sizes from the length of a video.

    Args:
        video_path: Anything ``decord.VideoReader`` accepts (here, the path of
            the uploaded video).

    Returns:
        A tuple ``(total_frames, clip_len, t_frag)``:
        - ``total_frames`` is the frame count reported by decord.
        - ``clip_len`` is ``total_frames / sample_interval`` rounded half-up
          and never less than 1.
        - ``t_frag`` always equals ``clip_len``.
    """
    frame_count = len(VideoReader(video_path))
    # Integer round-half-up of frame_count / sample_interval, floored at one
    # clip so that very short videos are still sampled.
    n_clips = max(1, (frame_count + sample_interval // 2) // sample_interval)
    return frame_count, n_clips, n_clips
def fuse_results(results: list):
    """Label the per-branch scores and add an "overall" score.

    Args:
        results: Sequence whose first three items are the semantic, technical
            and aesthetic scores, in that order. Any further items are ignored.

    Returns:
        Dict with keys "semantic", "technical", "aesthetic" and "overall".
        "overall" is the plain sum of the first three scores.
    """
    semantic = results[0]
    technical = results[1]
    aesthetic = results[2]
    return {
        "semantic": semantic,
        "technical": technical,
        "aesthetic": aesthetic,
        "overall": semantic + technical + aesthetic,
    }
def normalize_score(score, min_score, max_score):
    """Rescale ``score`` linearly so ``min_score`` maps to 0 and ``max_score`` to 5.

    The result is not clamped: scores outside ``[min_score, max_score]``
    land outside ``[0, 5]``.
    """
    span = max_score - min_score
    offset = score - min_score
    return offset / span * 5
def compare_score(score, score_list):
    """Describe how ``score`` ranks against a set of YT-UGC reference scores.

    Args:
        score: Predicted score for the current video, on the same scale as
            the references.
        score_list: Reference scores. Any sized iterable works, e.g. a list
            or a pandas Series.

    Returns:
        One of the following strings:
        - "Better than N% videos in YT-UGC" when strictly more than half of
          the references are below ``score``.
        - "Worse than M% videos in YT-UGC" otherwise, where M is the share
          of references that ``score`` does not beat.
        - A "no reference" message when ``score_list`` is empty. Previously
          this case raised ZeroDivisionError.
    """
    # Use len() rather than truthiness: a pandas Series has no unambiguous
    # boolean value.
    total = len(score_list)
    if total == 0:
        return "No YT-UGC reference scores available for comparison"
    better_than = sum(1 for s in score_list if score > s)
    percentage = better_than / total * 100
    if percentage > 50:
        return f"Better than {percentage:.0f}% videos in YT-UGC"
    return f"Worse than {100 - percentage:.0f}% videos in YT-UGC"
def create_bar_chart(scores, comparisons, image_path="./bar_chart.png"):
    """Draw one 0-5 score band per quality dimension and save it as a PNG.

    Args:
        scores: Normalized scores ordered semantic, technical, aesthetic,
            overall. This is the order ``fuse_results`` and
            ``inference_one_video`` use.
        comparisons: One caption per score, in the same order. Each caption is
            drawn to the right of its band.
        image_path: Destination of the PNG. Defaults to ``./bar_chart.png``.
            With the default, every caller writes to the same file.

    Returns:
        ``image_path``.
    """
    # Labels and colours must follow the order of ``scores``. The previous
    # list had Technical and Aesthetic swapped, so each showed the other's
    # score. Each label keeps its original colour.
    labels = ['Semantic', 'Technical', 'Aesthetic', 'Overall']
    base_colors = ['#d62728', '#1f77b4', '#ff7f0e', '#bcbd22']
    n_rows = len(labels)

    fig, ax = plt.subplots(figsize=(10, 5))
    rows = zip(scores, comparisons, base_colors)
    for i, (score, comparison, base_color) in enumerate(rows):
        # Translucent band covering the whole 0-5 scale; row i spans [i, i + 1].
        ax.add_patch(patches.Rectangle((0, i), 5, 1, color=base_color, alpha=0.5))
        # Vertical marker at the score, with its value printed beside it.
        ax.plot([score, score], [i, i + 0.9], color='black', linewidth=2)
        ax.text(score + 0.1, i + 0.5, f'{score:.1f}', va='center', ha='left', color=base_color)
        # Percentile caption just past the right edge of the 0-5 axis.
        ax.text(5.1, i + 0.5, comparison, va='center', ha='left', color=base_color)

    # Centre each label on its band (ticks at i put it on the band's bottom edge).
    ax.set_yticks([i + 0.5 for i in range(n_rows)])
    ax.set_yticklabels(labels)
    for tick, color in zip(ax.get_yticklabels(), base_colors):
        tick.set_color(color)
    ax.set_ylim(0, n_rows)
    ax.set_xticks([0, 1, 2, 3, 4, 5])
    ax.set_xticklabels([0, 1, 2, 3, 4, 5])
    ax.set_xlim(0, 5)
    ax.set_xlabel('Score')

    fig.tight_layout()
    fig.savefig(image_path)
    plt.close(fig)
    return image_path
def _to_model_input(view, num_clips, mu, sigma, device):
    """Normalise one sampled view and regroup it into clips for the model.

    ``view`` is indexed as (channel, time, height, width). The channel axis is
    moved last so that the per-channel ``mu``/``sigma`` broadcast, then moved
    back. The time axis is split into ``num_clips`` equal chunks. The result
    has shape (num_clips, channel, frames_per_clip, height, width) and is
    moved to ``device``.
    """
    normalised = ((view.permute(1, 2, 3, 0) - mu) / sigma).permute(3, 0, 1, 2)
    return (
        normalised.reshape(view.shape[0], num_clips, -1, *view.shape[2:])
        .transpose(0, 1)
        .to(device)
    )


def inference_one_video(input_video):
    """Score one video with COVER and render the scores as a bar chart.

    Args:
        input_video: Video passed from the ``gr.Video`` component. It is
            handed directly to decord's ``VideoReader``, so it is expected to
            be a file path.

    Returns:
        Path of the PNG written by ``create_bar_chart``.

    Raises:
        gr.Error: If no video was supplied.

    Side effects:
        Replaces the entries of the module-level ``comparison_array`` with
        the reference CSV columns.
    """
    if input_video is None:
        # Without this check, decord fails with an opaque error when the
        # user submits without uploading.
        raise gr.Error("Please upload a video first.")

    # --- Basic settings --------------------------------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): the config and weights are reloaded on every request;
    # consider caching them if latency matters.
    with open("./cover.yml", "r") as f:
        opt = yaml.safe_load(f)
    dopt = opt["data"]["val-ytugc"]["args"]

    # --- Temporal samplers, sized from the video length --------------------
    total_frames, clip_len, t_frag = get_sampler_params(input_video)
    temporal_samplers = {}
    for stype, sopt in dopt["sample_types"].items():
        sopt["clip_len"] = clip_len
        sopt["t_frag"] = t_frag
        if stype in ('technical', 'aesthetic') and total_frames > 1:
            # These two views sample twice as many frames as the others.
            sopt["clip_len"] = clip_len * 2
        if stype == 'technical':
            sopt["aligned"] = sopt["clip_len"]
        temporal_samplers[stype] = UnifiedFrameSampler(
            sopt["clip_len"] // sopt["t_frag"],
            sopt["t_frag"],
            sopt["frame_interval"],
            sopt["num_clips"],
        )

    # --- Model -------------------------------------------------------------
    evaluator = COVER(**opt["model"]["args"]).to(device)
    state_dict = torch.load(opt["test_load_path"], map_location=device)
    # strict=False avoids errors about missing weights for the
    # prompt_learner in clip-iqa+ and for the cross-gate.
    evaluator.load_state_dict(state_dict['state_dict'], strict=False)
    evaluator.eval()

    # --- Views: sample, normalise, split into clips ------------------------
    views, _ = spatial_temporal_view_decomposition(
        input_video, dopt["sample_types"], temporal_samplers
    )
    norm_stats = {
        'technical': (mean, std),
        'aesthetic': (mean, std),
        'semantic': (mean_clip, std_clip),
    }
    for k, v in views.items():
        if k not in norm_stats:
            continue  # other views are passed through unchanged
        num_clips = dopt["sample_types"][k].get("num_clips", 1)
        mu, sigma = norm_stats[k]
        views[k] = _to_model_input(v, num_clips, mu, sigma, device)

    # --- Scoring -----------------------------------------------------------
    # Pure inference: do not build an autograd graph.
    with torch.no_grad():
        results = [r.mean().item() for r in evaluator(views)]
    pred_score = fuse_results(results)

    # Single dimension order shared by the normalised scores, the comparison
    # captions and create_bar_chart.
    dims = ("semantic", "technical", "aesthetic", "overall")
    normalized_scores = [
        normalize_score(pred_score[d], *normalization_array[d]) for d in dims
    ]

    reference_csvs = {
        "semantic": './prediction_results/youtube_ugc/smos.csv',
        "technical": './prediction_results/youtube_ugc/tmos.csv',
        "aesthetic": './prediction_results/youtube_ugc/amos.csv',
        "overall": './prediction_results/youtube_ugc/overall.csv',
    }
    for d in dims:
        comparison_array[d] = pd.read_csv(reference_csvs[d])['Mos']
    comparisons = [compare_score(pred_score[d], comparison_array[d]) for d in dims]

    return create_bar_chart(normalized_scores, comparisons)
# Gradio wiring: the uploaded video is passed to inference_one_video, and the
# PNG path it returns is displayed in the image component.
video_input = gr.Video(label="Input Video")
output_image = gr.Image(label="Scores")

gradio_app = gr.Interface(
    fn=inference_one_video,
    inputs=video_input,
    outputs=output_image,
)

if __name__ == "__main__":
    gradio_app.launch()