"""RGT video super-resolution driver.

Pipeline: split the input video into frames (plus bicubic-upsampled copies),
run basicsr's ``image_sr`` over the extracted frames, then re-encode the
super-resolved frames into a video next to ``--output_dir``.
"""
import argparse
import os
import re
import shutil
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from os import path as osp

import cv2
from basicsr.test_img import image_sr
from PIL import Image
import imageio.v2 as imageio

# Per-scale settings: (upscale factor, LR temp subfolder under output_dir,
# folder under root_path that the super-resolved frames are read back from).
# NOTE(review): the result folder is presumably where image_sr writes its
# output for this scale -- confirm against the basicsr test config.
_SR_SETTINGS = {
    'x2': (2, 'temp_LR/X2', 'results/test_RGT_x2/visualization/Set5'),
    'x4': (4, 'temp_LR/X4', 'results/test_RGT_x4/visualization/Set5'),
}


def replace_filename(original_path, suffix):
    """Return ``original_path`` with ``suffix`` inserted before the extension.

    Example: ``/a/video.mp4`` with suffix ``_x4`` -> ``/a/video_x4.mp4``.
    """
    directory, old_filename = os.path.split(original_path)
    name_part, file_extension = os.path.splitext(old_filename)
    return os.path.join(directory, f"{name_part}{suffix}{file_extension}")


def create_temp_folder(folder_path):
    """Create ``folder_path`` as an empty directory, wiping any existing tree."""
    if os.path.exists(folder_path):
        shutil.rmtree(folder_path)
    os.makedirs(folder_path)


def delete_temp_folder(folder_path, missing_ok=False):
    """Recursively delete ``folder_path``.

    Args:
        folder_path: directory to remove.
        missing_ok: if True, do nothing when the directory does not exist
            (cleanup code may run before the folder was ever created).
    """
    if missing_ok and not os.path.exists(folder_path):
        return
    shutil.rmtree(folder_path)


def extract_number(filename):
    """Return the first run of digits in ``filename`` as an int, or -1 if none.

    Used as a sort key so ``frame_2...`` sorts before ``frame_10...``.
    """
    match = re.search(r'\d+', filename)
    return int(match.group()) if match else -1


def _write_image(path, img):
    """Write ``img`` with OpenCV, raising instead of failing silently.

    Raises:
        OSError: if ``cv2.imwrite`` reports failure (it returns False
            rather than raising).
    """
    if not cv2.imwrite(path, img):
        raise OSError(f"Failed to write image: {path}")


def _bicubic_resize(img, scale_factor):
    """Return ``img`` resized by ``scale_factor`` using bicubic interpolation."""
    original_height, original_width = img.shape[:2]
    new_size = (int(original_width * scale_factor), int(original_height * scale_factor))
    return cv2.resize(img, new_size, interpolation=cv2.INTER_CUBIC)


def bicubic_upsample_opencv(input_image_path, output_image_path, scale_factor):
    """Bicubic-upsample the image at ``input_image_path`` into ``output_image_path``.

    Raises:
        FileNotFoundError: if OpenCV cannot read the input image.
        OSError: if the output image cannot be written.
    """
    img = cv2.imread(input_image_path)
    if img is None:  # cv2.imread signals failure by returning None
        raise FileNotFoundError(f"Could not read image: {input_image_path}")
    _write_image(output_image_path, _bicubic_resize(img, scale_factor))


def process_frame(frame_count, frame, temp_LR_folder_path, temp_HR_folder_path, SR):
    """Save one video frame as an LR PNG and, for x2/x4, an upsampled HR PNG.

    The LR file is ``frame_{frame_count}{SR}.png`` in ``temp_LR_folder_path``.
    The HR file is ``frame_{frame_count}.png`` in ``temp_HR_folder_path``.
    For an ``SR`` outside ``_SR_SETTINGS`` only the LR file is written.
    """
    frame_path = os.path.join(temp_LR_folder_path, f"frame_{frame_count}{SR}.png")
    _write_image(frame_path, frame)

    settings = _SR_SETTINGS.get(SR)
    if settings is None:
        return
    HR_frame_path = os.path.join(temp_HR_folder_path, f"frame_{frame_count}.png")
    # Upsample the in-memory frame rather than re-decoding the PNG we just
    # wrote; PNG is lossless so the pixels are the same.
    _write_image(HR_frame_path, _bicubic_resize(frame, settings[0]))


def _extract_frames(cap, temp_LR_folder_path, temp_HR_folder_path, SR, max_workers):
    """Read every frame from ``cap`` and write LR/HR copies on a thread pool.

    Frames are submitted as they are decoded instead of being buffered first.

    Returns:
        The number of frames read.

    Raises:
        The exception of the earliest frame whose worker failed.
    """
    futures = []
    frame_count = 0
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            futures.append(
                executor.submit(process_frame, frame_count, frame,
                                temp_LR_folder_path, temp_HR_folder_path, SR))
            frame_count += 1
    # Leaving the `with` joined all workers; surface any error they raised
    # instead of silently dropping it.
    for future in futures:
        future.result()
    return frame_count


def _frames_to_video(frame_dir, video_output_path, fps):
    """Encode the images in ``frame_dir`` into a video at ``video_output_path``.

    Frames are ordered by the first number in their file name
    (``extract_number``). They are streamed one at a time into the writer
    rather than loading the whole clip into memory.
    """
    frame_files = sorted(os.listdir(frame_dir), key=extract_number)
    with imageio.get_writer(video_output_path, mode='I', fps=fps, quality=9) as writer:
        for frame_file in frame_files:
            writer.append_data(imageio.imread(os.path.join(frame_dir, frame_file)))


def video_sr(args):
    """Super-resolve the video at ``args.input_dir`` and write the result.

    Reads ``args.SR`` ('x2' or 'x4'), ``args.input_dir``, ``args.output_dir``,
    ``args.root_path`` and ``args.mul_numwork``. ``args`` is also passed
    unchanged to ``image_sr``, which presumably reads ``ckpt_path`` /
    ``use_chop`` (not visible here).

    The output is ``<output_dir>/<input name>_<SR><ext>``. The temp folders
    under ``output_dir`` and ``<root_path>/results`` are deleted afterwards,
    including when a step fails.

    Raises:
        ValueError: if ``args.SR`` is not a supported scale.
    """
    if args.SR not in _SR_SETTINGS:
        raise ValueError(
            f"Unsupported SR {args.SR!r}; expected one of {sorted(_SR_SETTINGS)}")
    _, lr_subdir, result_subdir = _SR_SETTINGS[args.SR]

    file_name = os.path.basename(args.input_dir)
    video_output_path = replace_filename(
        os.path.join(args.output_dir, file_name), f'_{args.SR}')
    temp_LR_folder_path = os.path.join(args.output_dir, lr_subdir)
    temp_HR_folder_path = os.path.join(args.output_dir, 'temp_HR')
    result_temp = osp.join(args.root_path, result_subdir)

    # Open the video before touching the filesystem so a bad input path
    # does not leave empty temp folders behind.
    cap = cv2.VideoCapture(args.input_dir)
    if not cap.isOpened():
        print("Error opening video file.")
        cap.release()
        return

    create_temp_folder(temp_LR_folder_path)
    create_temp_folder(temp_HR_folder_path)
    try:
        t1 = time.time()
        total_frames = _extract_frames(cap, temp_LR_folder_path, temp_HR_folder_path,
                                       args.SR, args.mul_numwork)
        fps = cap.get(cv2.CAP_PROP_FPS)
        print("total frames:", total_frames)
        print("fps :", fps)
        t2 = time.time()
        print('mul threads: ', t2 - t1, 's')

        if total_frames == 0:
            print("No frames read from video; nothing to super-resolve.")
            return

        # Process all frames of the video.
        image_sr(args)
        t3 = time.time()
        print('image super resolution: ', t3 - t2, 's')

        # Recover the video from all frames.
        _frames_to_video(result_temp, video_output_path, fps)
        t4 = time.time()
        print('tranformer frames to video: ', t4 - t3, 's')
    finally:
        # Release all resources, even if a step above raised.
        cap.release()
        delete_temp_folder(os.path.dirname(temp_LR_folder_path), missing_ok=True)
        delete_temp_folder(temp_HR_folder_path, missing_ok=True)
        delete_temp_folder(os.path.join(args.root_path, 'results'), missing_ok=True)
    t5 = time.time()
    print('delete time: ', t5 - t4, 's')


def _str2bool(value):
    """argparse ``type`` for boolean options.

    Plain ``type=bool`` would turn any non-empty string, including "False",
    into True.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0', 'off'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="RGT for Video Super-Resolution")
    # Make sure --SR matches the scale of the checkpoint given in --ckpt_path.
    parser.add_argument("--SR", type=str, choices=['x2', 'x4'], default='x4', help='image resolution')
    parser.add_argument("--ckpt_path", type=str, default="/remote-home/lzy/RGT/experiments/pretrained_models/RGT_x4.pth")
    parser.add_argument("--root_path", type=str, default="/remote-home/lzy/RGT")
    parser.add_argument("--input_dir", type=str, default="/remote-home/lzy/RGT/datasets/video/video_test1.mp4")
    parser.add_argument("--output_dir", type=str, default="/remote-home/lzy/RGT/datasets/video_output")
    parser.add_argument("--mul_numwork", type=int, default=16, help='max_workers to execute Multi')
    parser.add_argument("--use_chop", type=_str2bool, default=True, help='use_chop: True # True to save memory, if img too large')
    args = parser.parse_args()
    video_sr(args)