|
import os |
|
import sys |
|
sys.path.append('.') |
|
import cv2 |
|
import math |
|
import torch |
|
import argparse |
|
import numpy as np |
|
from torch.nn import functional as F |
|
from model.pytorch_msssim import ssim_matlab |
|
from model.RIFE import Model |
|
from skimage.color import rgb2yuv, yuv2rgb |
|
from yuv_frame_io import YUV_Read,YUV_Write |
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
|
model = Model() |
|
model.load_model('train_log') |
|
model.eval() |
|
model.device() |
|
|
|
name_list = [ |
|
('HD_dataset/HD720p_GT/parkrun_1280x720_50.yuv', 720, 1280), |
|
('HD_dataset/HD720p_GT/shields_1280x720_60.yuv', 720, 1280), |
|
('HD_dataset/HD720p_GT/stockholm_1280x720_60.yuv', 720, 1280), |
|
('HD_dataset/HD1080p_GT/BlueSky.yuv', 1080, 1920), |
|
('HD_dataset/HD1080p_GT/Kimono1_1920x1080_24.yuv', 1080, 1920), |
|
('HD_dataset/HD1080p_GT/ParkScene_1920x1080_24.yuv', 1080, 1920), |
|
('HD_dataset/HD1080p_GT/sunflower_1080p25.yuv', 1080, 1920), |
|
('HD_dataset/HD544p_GT/Sintel_Alley2_1280x544.yuv', 544, 1280), |
|
('HD_dataset/HD544p_GT/Sintel_Market5_1280x544.yuv', 544, 1280), |
|
('HD_dataset/HD544p_GT/Sintel_Temple1_1280x544.yuv', 544, 1280), |
|
('HD_dataset/HD544p_GT/Sintel_Temple2_1280x544.yuv', 544, 1280), |
|
] |
|
tot = 0. |
|
for data in name_list: |
|
psnr_list = [] |
|
name = data[0] |
|
h = data[1] |
|
w = data[2] |
|
if 'yuv' in name: |
|
Reader = YUV_Read(name, h, w, toRGB=True) |
|
else: |
|
Reader = cv2.VideoCapture(name) |
|
_, lastframe = Reader.read() |
|
|
|
|
|
for index in range(0, 100, 2): |
|
if 'yuv' in name: |
|
IMAGE1, success1 = Reader.read(index) |
|
gt, _ = Reader.read(index + 1) |
|
IMAGE2, success2 = Reader.read(index + 2) |
|
if not success2: |
|
break |
|
else: |
|
success1, gt = Reader.read() |
|
success2, frame = Reader.read() |
|
IMAGE1 = lastframe |
|
IMAGE2 = frame |
|
lastframe = frame |
|
if not success2: |
|
break |
|
I0 = torch.from_numpy(np.transpose(IMAGE1, (2,0,1)).astype("float32") / 255.).cuda().unsqueeze(0) |
|
I1 = torch.from_numpy(np.transpose(IMAGE2, (2,0,1)).astype("float32") / 255.).cuda().unsqueeze(0) |
|
|
|
if h == 720: |
|
pad = 24 |
|
elif h == 1080: |
|
pad = 4 |
|
else: |
|
pad = 16 |
|
pader = torch.nn.ReplicationPad2d([0, 0, pad, pad]) |
|
I0 = pader(I0) |
|
I1 = pader(I1) |
|
with torch.no_grad(): |
|
pred = model.inference(I0, I1) |
|
pred = pred[:, :, pad: -pad] |
|
out = (np.round(pred[0].detach().cpu().numpy().transpose(1, 2, 0) * 255)).astype('uint8') |
|
|
|
if 'yuv' in name: |
|
diff_rgb = 128.0 + rgb2yuv(gt / 255.)[:, :, 0] * 255 - rgb2yuv(out / 255.)[:, :, 0] * 255 |
|
mse = np.mean((diff_rgb - 128.0) ** 2) |
|
PIXEL_MAX = 255.0 |
|
psnr = 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) |
|
else: |
|
psnr = skim.compare_psnr(gt, out) |
|
psnr_list.append(psnr) |
|
print(np.mean(psnr_list)) |
|
tot += np.mean(psnr_list) |
|
print('avg psnr', tot / len(name_list)) |
|
|