|
import os
|
|
import math
|
|
|
|
import utility
|
|
from data import common
|
|
|
|
import torch
|
|
import cv2
|
|
|
|
from tqdm import tqdm
|
|
|
|
class VideoTester:
    """Run a super-resolution model over every frame of a demo video.

    For each scale in ``args.scale`` the video at ``args.dir_demo`` is decoded
    with OpenCV, every frame is upscaled by ``my_model``, and the result is
    written as an XVID ``.avi`` file at ``ckp.get_path('<name>_x<scale>.avi')``.

    ``args`` attributes read by this class: ``scale``, ``dir_demo``,
    ``n_colors``, ``rgb_range``, ``cpu``, ``precision``.
    """

    def __init__(self, args, my_model, ckp):
        """Store configuration, model and checkpoint helper.

        Args:
            args: option namespace (see class docstring for attributes used).
            my_model: callable invoked as ``my_model(lr_batch, idx_scale)``;
                its ``eval()`` method is called before processing.
            ckp: checkpoint/logging helper providing ``write_log`` and
                ``get_path``.
        """
        self.args = args
        self.scale = args.scale
        self.ckp = ckp
        self.model = my_model
        # Base name of the input video without extension; used to name outputs.
        self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo))

    def test(self):
        """Upscale the demo video once per configured scale and log total time.

        Raises:
            IOError: if ``args.dir_demo`` cannot be opened as a video.
        """
        self.ckp.write_log('\nEvaluation on video:')
        self.model.eval()

        timer_test = utility.timer()
        # no_grad (instead of a bare set_grad_enabled(False)/(True) pair)
        # restores the caller's autograd state even if a frame raises.
        with torch.no_grad():
            for idx_scale, scale in enumerate(self.scale):
                self._process_video(idx_scale, scale)

        self.ckp.write_log(
            'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True
        )

    def _process_video(self, idx_scale, scale):
        """Decode ``args.dir_demo``, upscale each frame and write the AVI for ``scale``."""
        vidcap = cv2.VideoCapture(self.args.dir_demo)
        if not vidcap.isOpened():
            vidcap.release()
            raise IOError('Cannot open video: {}'.format(self.args.dir_demo))

        vidwri = None
        try:
            # CAP_PROP_FRAME_COUNT is an estimate (may be 0/negative or off);
            # use it only for the progress bar and read until EOF instead.
            total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
            vidwri = cv2.VideoWriter(
                self.ckp.get_path('{}_x{}.avi'.format(self.filename, scale)),
                cv2.VideoWriter_fourcc(*'XVID'),
                vidcap.get(cv2.CAP_PROP_FPS),
                (
                    int(scale * vidcap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                    int(scale * vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
                ),
            )

            progress_total = total_frames if total_frames > 0 else None
            with tqdm(total=progress_total, ncols=80) as pbar:
                while True:
                    success, lr = vidcap.read()
                    if not success:
                        break
                    vidwri.write(self._upscale_frame(lr, idx_scale))
                    pbar.update(1)
        finally:
            # Release even on failure so handles are freed and the AVI finalized.
            vidcap.release()
            if vidwri is not None:
                vidwri.release()

    def _upscale_frame(self, lr, idx_scale):
        """Run one OpenCV (BGR, uint8) frame through the model; return a BGR uint8 frame."""
        # OpenCV decodes to BGR; the set_channel/np2Tensor helpers and the
        # model presumably expect RGB (as images loaded elsewhere in the
        # project) -- NOTE(review): confirm against data.common.
        lr = cv2.cvtColor(lr, cv2.COLOR_BGR2RGB)
        lr, = common.set_channel(lr, n_channels=self.args.n_colors)
        lr, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)
        lr, = self.prepare(lr.unsqueeze(0))
        sr = self.model(lr, idx_scale)
        sr = utility.quantize(sr, self.args.rgb_range).squeeze(0)
        return self._to_bgr_frame(sr, self.args.rgb_range)

    @staticmethod
    def _to_bgr_frame(sr, rgb_range):
        """Convert a (C, H, W) tensor in ``[0, rgb_range]`` into an (H, W, 3) uint8 BGR array.

        A 3-channel input is treated as RGB and reordered to BGR; a 1-channel
        input is replicated to 3 channels (gray) so a color VideoWriter
        accepts it. The returned array is C-contiguous.
        """
        # Round before the uint8 cast: after a float round-trip a value meant
        # to be k can land at k - epsilon, which .byte() would truncate to k-1.
        frame = (sr * 255 / rgb_range).round().clamp(0, 255).byte()
        if frame.size(0) == 1:
            frame = frame.expand(3, -1, -1)
        elif frame.size(0) == 3:
            frame = frame.flip(0)  # RGB -> BGR for OpenCV
        return frame.permute(1, 2, 0).contiguous().cpu().numpy()

    def prepare(self, *args):
        """Move each tensor to the configured device, casting to half if requested.

        Returns:
            A list with one converted tensor per positional argument, in order.
        """
        device = torch.device('cpu' if self.args.cpu else 'cuda')

        def _prepare(tensor):
            if self.args.precision == 'half':
                tensor = tensor.half()
            return tensor.to(device)

        return [_prepare(a) for a in args]
|
|
|
|
|