File size: 3,258 Bytes
feb2918 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
import torch
import argparse
import pickle as pkl
import decord
from decord import VideoReader
import numpy as np
import yaml
from cover.datasets import UnifiedFrameSampler, spatial_temporal_view_decomposition
from cover.models import COVER
# Per-channel normalisation statistics applied to the "technical" and
# "aesthetic" views in __main__.
# NOTE(review): these look like the ImageNet RGB mean/std scaled by 255
# (e.g. 0.485 * 255 = 123.675), implying 0-255 pixel input -- confirm.
mean = torch.FloatTensor([123.675, 116.28, 103.53])
std = torch.FloatTensor([58.395, 57.12, 57.375])

# Per-channel normalisation statistics applied to the "semantic" view.
# NOTE(review): presumably the CLIP preprocessing mean/std scaled by 255,
# matching a CLIP-based semantic branch -- confirm against the backbone.
mean_clip = torch.FloatTensor([122.77, 116.75, 104.09])
std_clip = torch.FloatTensor([68.50, 66.63, 70.32])
def fuse_results(results: list):
    """Label the three branch scores and add their unweighted sum.

    Args:
        results: indexable sequence read positionally -- index 0 is the
            semantic score, 1 the technical score, 2 the aesthetic score.
            Any entries past index 2 are ignored.

    Returns:
        dict with keys "semantic", "technical", "aesthetic" and "overall",
        where "overall" is semantic + technical + aesthetic.
    """
    semantic = results[0]
    technical = results[1]
    aesthetic = results[2]
    fused = {}
    fused["semantic"] = semantic
    fused["technical"] = technical
    fused["aesthetic"] = aesthetic
    fused["overall"] = semantic + technical + aesthetic
    return fused
def parse_args(argv=None):
    """Parse the command-line options for single-video evaluation.

    Args:
        argv: optional list of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``opt`` (path to the YAML option file) and
        ``video_path`` (path to the video whose quality is predicted).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-o", "--opt", type=str, default="./cover.yml", help="the option file"
    )
    # The previous help text described this as an output file; it is the
    # input video that gets scored.
    parser.add_argument(
        "--video_path",
        type=str,
        default="./demo/video_1.mp4",
        help="path to the input video to evaluate",
    )
    return parser.parse_args(argv)
def normalize_view(view, num_clips, channel_mean, channel_std, device):
    """Normalise one sampled view per channel and regroup it into clips.

    Args:
        view: 4-D tensor whose dim 0 is the channel axis (after being moved
            last it must broadcast against ``channel_mean``/``channel_std``)
            and whose dim 1 is split into ``num_clips`` equal chunks.
            Presumably (C, T, H, W) as produced by
            spatial_temporal_view_decomposition -- TODO confirm.
        num_clips: number of clips; dim 1 of ``view`` must be divisible by it.
        channel_mean: per-channel mean subtracted from the view.
        channel_std: per-channel std the view is divided by.
        device: device (or device string) the result is moved to.

    Returns:
        Tensor of shape
        (num_clips, view.shape[0], view.shape[1] // num_clips, *view.shape[2:]).
    """
    # Move channels last so the 1-D stats broadcast, then restore the layout.
    normed = (view.permute(1, 2, 3, 0) - channel_mean) / channel_std
    return (
        normed.permute(3, 0, 1, 2)
        .reshape(view.shape[0], num_clips, -1, *view.shape[2:])
        .transpose(0, 1)
        .to(device)
    )


if __name__ == "__main__":
    import os

    args = parse_args()

    # ---- BASIC SETTINGS ----
    # torch.cuda.current_device() raises on CPU-only machines, which used to
    # make the CPU fallback below unreachable; only touch CUDA if present.
    if torch.cuda.is_available():
        torch.cuda.current_device()
        torch.cuda.empty_cache()
    torch.backends.cudnn.benchmark = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    with open(args.opt, "r") as f:
        opt = yaml.safe_load(f)
    dopt = opt["data"]["val-ytugc"]["args"]
    temporal_samplers = {
        stype: UnifiedFrameSampler(
            sopt["clip_len"] // sopt["t_frag"],
            sopt["t_frag"],
            sopt["frame_interval"],
            sopt["num_clips"],
        )
        for stype, sopt in dopt["sample_types"].items()
    }

    # ---- LOAD MODEL ----
    evaluator = COVER(**opt["model"]["args"]).to(device)
    # NOTE(security): torch.load unpickles the checkpoint; only load
    # checkpoints from trusted sources.
    state_dict = torch.load(opt["test_load_path"], map_location=device)
    # set strict=False here to avoid error of missing
    # weight of prompt_learner in clip-iqa+, cross-gate
    evaluator.load_state_dict(state_dict["state_dict"], strict=False)
    # Inference only: disable training-mode behaviour (dropout, batch-norm
    # statistic updates). Harmless if COVER.forward already does this.
    evaluator.eval()

    # ---- TESTING ----
    views, _ = spatial_temporal_view_decomposition(
        args.video_path, dopt["sample_types"], temporal_samplers
    )
    # Normalisation stats per view key; views with any other key are left
    # untouched, as in the original two-branch version.
    view_stats = {
        "technical": (mean, std),
        "aesthetic": (mean, std),
        "semantic": (mean_clip, std_clip),
    }
    for k, v in views.items():
        if k in view_stats:
            num_clips = dopt["sample_types"][k].get("num_clips", 1)
            channel_mean, channel_std = view_stats[k]
            views[k] = normalize_view(v, num_clips, channel_mean, channel_std, device)

    # Scoring needs no autograd graph.
    with torch.no_grad():
        results = [r.mean().item() for r in evaluator(views)]
    pred_score = fuse_results(results)
    print("path, semantic score, technical score, aesthetic score, overall/final score")
    # NOTE(review): ``:4f`` means minimum width 4 with the default 6-digit
    # precision; ``:.4f`` may have been intended. Kept to preserve output.
    print(
        f"{os.path.basename(args.video_path)},"
        f"{pred_score['semantic']:4f},{pred_score['technical']:4f},"
        f"{pred_score['aesthetic']:4f},{pred_score['overall']:4f}"
    )
|