import functools
import tempfile

import gradio as gr
import torch
import argparse
import pickle as pkl
import decord
from decord import VideoReader
import numpy as np
import yaml
import matplotlib.pyplot as plt
import matplotlib.patches as patches

from cover.datasets import UnifiedFrameSampler, spatial_temporal_view_decomposition
from cover.models import COVER
import pandas as pd

# Per-channel pixel statistics used to normalize the technical/aesthetic views.
mean, std = (
    torch.FloatTensor([123.675, 116.28, 103.53]),
    torch.FloatTensor([58.395, 57.12, 57.375]),
)

# Per-channel pixel statistics used to normalize the semantic (CLIP) view.
mean_clip, std_clip = (
    torch.FloatTensor([122.77, 116.75, 104.09]),
    torch.FloatTensor([68.50, 66.63, 70.32]),
)

# Roughly one sampled frame per `sample_interval` frames (see get_sampler_params).
sample_interval = 30

# [min, max] raw-score range per dimension, mapped onto 0-5 by normalize_score.
normalization_array = {
    "semantic": [-0.1477, -0.0181],
    "technical": [-1.8762, 1.2428],
    "aesthetic": [-1.2899, 0.5290],
    "overall": [-3.2538, 1.6728],
}

# Reference scores from YT-UGC, populated from the CSV files listed below on
# the first inference call.
comparison_array = {
    "semantic": [],
    "technical": [],
    "aesthetic": [],
    "overall": [],
}

# Order in which the dimensions are normalized, compared and charted.
_SCORE_KEYS = ("semantic", "technical", "aesthetic", "overall")

_CONFIG_PATH = "./cover.yml"

_COMPARISON_CSVS = {
    "semantic": "./prediction_results/youtube_ugc/smos.csv",
    "technical": "./prediction_results/youtube_ugc/tmos.csv",
    "aesthetic": "./prediction_results/youtube_ugc/amos.csv",
    "overall": "./prediction_results/youtube_ugc/overall.csv",
}


def get_sampler_params(video_path):
    """Derive frame-sampler parameters from the length of a video.

    Args:
        video_path: Path to a video file readable by decord.

    Returns:
        ``(total_frames, clip_len, t_frag)``. ``clip_len`` is
        ``total_frames / sample_interval`` rounded to the nearest integer,
        with a minimum of 1. ``t_frag`` equals ``clip_len``.
    """
    vr = VideoReader(video_path)
    total_frames = len(vr)
    # Integer division with half-interval offset == round-to-nearest.
    clip_len = max((total_frames + sample_interval // 2) // sample_interval, 1)
    t_frag = clip_len
    return total_frames, clip_len, t_frag


def fuse_results(results: list):
    """Map the three per-branch scores to named dimensions plus their sum.

    Args:
        results: Scores in the order ``[semantic, technical, aesthetic]``.

    Returns:
        Dict with keys ``semantic``, ``technical``, ``aesthetic`` and
        ``overall`` (the sum of the first three).
    """
    x = results[0] + results[1] + results[2]
    return {
        "semantic": results[0],
        "technical": results[1],
        "aesthetic": results[2],
        "overall": x,
    }


def normalize_score(score, min_score, max_score):
    """Linearly map ``score`` from ``[min_score, max_score]`` onto ``[0, 5]``.

    Values outside the input range are not clipped.
    """
    return (score - min_score) / (max_score - min_score) * 5


def compare_score(score, score_list):
    """Describe how ``score`` ranks against a set of YT-UGC reference scores.

    Args:
        score: The raw score of the current video.
        score_list: Reference scores (any sized iterable, e.g. a pandas Series).

    Returns:
        A human-readable "Better than X%" / "Worse than Y%" string. If there
        are no reference scores, a fallback message is returned instead of
        dividing by zero.
    """
    # len() rather than truthiness: a pandas Series has no unambiguous bool value.
    if len(score_list) == 0:
        return "No YT-UGC reference scores available"
    better_than = sum(1 for s in score_list if score > s)
    percentage = better_than / len(score_list) * 100
    if percentage > 50:
        return f"Better than {percentage:.0f}% videos in YT-UGC"
    return f"Worse than {100 - percentage:.0f}% videos in YT-UGC"


def create_bar_chart(scores, comparisons):
    """Draw one horizontal 0-5 bar per dimension and save it to a PNG.

    Args:
        scores: Normalized scores in the order semantic, technical, aesthetic,
            overall (the order ``inference_one_video`` produces).
        comparisons: Text to print to the right of each bar, in the same order.

    Returns:
        Path of a newly created PNG file. A unique temporary file is used so
        that concurrent requests do not overwrite each other's chart.
    """
    # Labels and colours must follow the order of ``scores``. The previous
    # version had "Aesthetic" and "Technical" swapped relative to the data.
    labels = ["Semantic", "Technical", "Aesthetic", "Overall"]
    base_colors = ["#d62728", "#1f77b4", "#ff7f0e", "#bcbd22"]

    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        for i, (score, comparison, base_color) in enumerate(
            zip(scores, comparisons, base_colors)
        ):
            ax.add_patch(
                patches.Rectangle((0, i), 5, 1, color=base_color, alpha=0.5)
            )
            # Keep the marker inside the visible 0-5 range; the label still
            # shows the true (possibly out-of-range) value.
            marker_x = min(max(score, 0.0), 5.0)
            ax.plot([marker_x, marker_x], [i, i + 0.9], color="black", linewidth=2)
            ax.text(marker_x + 0.1, i + 0.5, f"{score:.1f}",
                    va="center", ha="left", color=base_color)
            ax.text(5.1, i + 0.5, comparison,
                    va="center", ha="left", color=base_color)

        # Centre each label on its bar (bars span [i, i + 1]).
        ax.set_yticks([i + 0.5 for i in range(len(labels))])
        ax.set_yticklabels(labels)
        for tick, color in zip(ax.get_yticklabels(), base_colors):
            tick.set_color(color)

        ax.set_xticks([0, 1, 2, 3, 4, 5])
        ax.set_xticklabels([0, 1, 2, 3, 4, 5])
        ax.set_xlim(0, 5)
        ax.set_xlabel("Score")
        fig.tight_layout()

        with tempfile.NamedTemporaryFile(
            prefix="cover_scores_", suffix=".png", delete=False
        ) as tmp:
            image_path = tmp.name
        fig.savefig(image_path)
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)
    return image_path


def _load_config():
    """Read the COVER YAML config (a fresh dict each call, since callers mutate it)."""
    with open(_CONFIG_PATH, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


@functools.lru_cache(maxsize=1)
def _load_model(device_type):
    """Build the COVER evaluator and load its checkpoint once per device type."""
    opt = _load_config()
    device = torch.device(device_type)
    evaluator = COVER(**opt["model"]["args"]).to(device)
    # NOTE(review): torch.load unpickles the checkpoint; only point
    # test_load_path at trusted files.
    state_dict = torch.load(opt["test_load_path"], map_location=device)
    # strict=False avoids errors about missing weights of prompt_learner in
    # clip-iqa+ and the cross-gate.
    evaluator.load_state_dict(state_dict["state_dict"], strict=False)
    # Inference only: disable dropout / batch-norm updates.
    evaluator.eval()
    return evaluator


@functools.lru_cache(maxsize=1)
def _load_comparison_scores():
    """Read the YT-UGC reference "Mos" columns once, keyed by dimension."""
    return {key: pd.read_csv(path)["Mos"] for key, path in _COMPARISON_CSVS.items()}


def _build_temporal_samplers(video_path, sample_types):
    """Configure ``sample_types`` in place for this video and build its samplers."""
    total_frames, clip_len, t_frag = get_sampler_params(video_path)
    samplers = {}
    for stype, sopt in sample_types.items():
        sopt["clip_len"] = clip_len
        sopt["t_frag"] = t_frag
        # Technical/aesthetic branches sample twice as many frames when the
        # video has more than one frame.
        if stype in ("technical", "aesthetic") and total_frames > 1:
            sopt["clip_len"] = clip_len * 2
        if stype == "technical":
            sopt["aligned"] = sopt["clip_len"]
        samplers[stype] = UnifiedFrameSampler(
            sopt["clip_len"] // sopt["t_frag"],
            sopt["t_frag"],
            sopt["frame_interval"],
            sopt["num_clips"],
        )
    return samplers


def _normalize_views(views, sample_types, device):
    """Normalize each known view with its channel stats and split it into clips.

    The subtraction after ``permute(1, 2, 3, 0)`` broadcasts the 3-element
    mean over the last axis, so each view's first axis is the channel axis.
    Views whose key is not technical/aesthetic/semantic are left unchanged.
    """
    stats = {
        "technical": (mean, std),
        "aesthetic": (mean, std),
        "semantic": (mean_clip, std_clip),
    }
    for key, video in views.items():
        if key not in stats:
            continue
        view_mean, view_std = stats[key]
        num_clips = sample_types[key].get("num_clips", 1)
        views[key] = (
            ((video.permute(1, 2, 3, 0) - view_mean) / view_std)
            .permute(3, 0, 1, 2)
            .reshape(video.shape[0], num_clips, -1, *video.shape[2:])
            .transpose(0, 1)
            .to(device)
        )
    return views


def inference_one_video(input_video):
    """Score one uploaded video with COVER and return a bar-chart image path.

    Args:
        input_video: File path supplied by the Gradio video component.

    Returns:
        Path of a PNG showing normalized semantic/technical/aesthetic/overall
        scores and how each ranks against YT-UGC.

    Raises:
        gr.Error: If no video was provided.
    """
    if not input_video:
        raise gr.Error("Please upload a video first.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    opt = _load_config()
    dopt = opt["data"]["val-ytugc"]["args"]

    temporal_samplers = _build_temporal_samplers(input_video, dopt["sample_types"])
    evaluator = _load_model(device.type)

    views, _ = spatial_temporal_view_decomposition(
        input_video, dopt["sample_types"], temporal_samplers
    )
    views = _normalize_views(views, dopt["sample_types"], device)

    with torch.no_grad():
        results = [r.mean().item() for r in evaluator(views)]
    pred_score = fuse_results(results)

    normalized_scores = [
        normalize_score(pred_score[key], *normalization_array[key])
        for key in _SCORE_KEYS
    ]

    comparison_array.update(_load_comparison_scores())
    comparisons = [
        compare_score(pred_score[key], comparison_array[key]) for key in _SCORE_KEYS
    ]

    return create_bar_chart(normalized_scores, comparisons)


# Define the input and output types for Gradio using the new API
video_input = gr.Video(label="Input Video")
output_image = gr.Image(label="Scores")

# Create the Gradio interface
gradio_app = gr.Interface(fn=inference_one_video, inputs=video_input, outputs=output_image)

if __name__ == "__main__":
    gradio_app.launch()