| import os | |
| import math | |
| import utility | |
| from data import common | |
| import torch | |
| import cv2 | |
| from tqdm import tqdm | |
| class VideoTester(): | |
| def __init__(self, args, my_model, ckp): | |
| self.args = args | |
| self.scale = args.scale | |
| self.ckp = ckp | |
| self.model = my_model | |
| self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo)) | |
| def test(self): | |
| torch.set_grad_enabled(False) | |
| self.ckp.write_log('\nEvaluation on video:') | |
| self.model.eval() | |
| timer_test = utility.timer() | |
| for idx_scale, scale in enumerate(self.scale): | |
| vidcap = cv2.VideoCapture(self.args.dir_demo) | |
| total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) | |
| vidwri = cv2.VideoWriter( | |
| self.ckp.get_path('{}_x{}.avi'.format(self.filename, scale)), | |
| cv2.VideoWriter_fourcc(*'XVID'), | |
| vidcap.get(cv2.CAP_PROP_FPS), | |
| ( | |
| int(scale * vidcap.get(cv2.CAP_PROP_FRAME_WIDTH)), | |
| int(scale * vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT)) | |
| ) | |
| ) | |
| tqdm_test = tqdm(range(total_frames), ncols=80) | |
| for _ in tqdm_test: | |
| success, lr = vidcap.read() | |
| if not success: break | |
| lr, = common.set_channel(lr, n_channels=self.args.n_colors) | |
| lr, = common.np2Tensor(lr, rgb_range=self.args.rgb_range) | |
| lr, = self.prepare(lr.unsqueeze(0)) | |
| sr = self.model(lr, idx_scale) | |
| sr = utility.quantize(sr, self.args.rgb_range).squeeze(0) | |
| normalized = sr * 255 / self.args.rgb_range | |
| ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy() | |
| vidwri.write(ndarr) | |
| vidcap.release() | |
| vidwri.release() | |
| self.ckp.write_log( | |
| 'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True | |
| ) | |
| torch.set_grad_enabled(True) | |
| def prepare(self, *args): | |
| device = torch.device('cpu' if self.args.cpu else 'cuda') | |
| def _prepare(tensor): | |
| if self.args.precision == 'half': tensor = tensor.half() | |
| return tensor.to(device) | |
| return [_prepare(a) for a in args] | |