# Hugging Face file-viewer residue (not part of the program):
# fffiloni's picture
# Upload 244 files
# b3f324b verified
# raw
# history blame contribute delete
# No virus
# 5.23 kB
import cv2
import argparse
from basicsr.test_img import image_sr
from os import path as osp
import os
import shutil
from PIL import Image
import re
import imageio.v2 as imageio
import threading
from concurrent.futures import ThreadPoolExecutor
import time
def replace_filename(original_path, suffix):
    """Return ``original_path`` with ``suffix`` inserted before the extension.

    The directory part is kept unchanged; only the file name is altered,
    e.g. ``dir/clip.mp4`` with suffix ``_x4`` becomes ``dir/clip_x4.mp4``.
    """
    parent, leaf = os.path.split(original_path)
    stem, ext = os.path.splitext(leaf)
    return os.path.join(parent, f"{stem}{suffix}{ext}")
def create_temp_folder(folder_path):
    """Create ``folder_path`` as a fresh, empty directory.

    If something already exists at that path it is removed recursively
    first, so previous contents never survive the call. Missing parent
    directories are created as needed.
    """
    already_there = os.path.exists(folder_path)
    if already_there:
        # Start from a clean slate: drop whatever a previous run left behind.
        shutil.rmtree(folder_path)
    os.makedirs(folder_path)
def delete_temp_folder(folder_path):
    """Recursively delete ``folder_path`` and everything inside it.

    Errors raised by ``shutil.rmtree`` (for example when the folder does
    not exist) are not suppressed and propagate to the caller.
    """
    shutil.rmtree(folder_path, ignore_errors=False)
def extract_number(filename):
    """Return the first run of digits in ``filename`` as an int.

    Returns ``-1`` when the name contains no digits. Used as a sort key so
    that frame files are ordered numerically rather than lexically.
    """
    first_digits = re.search(r'\d+', filename)
    if first_digits is None:
        return -1
    return int(first_digits.group())
def bicubic_upsample_opencv(input_image_path, output_image_path, scale_factor):
    """Upscale an image file with bicubic interpolation and save the result.

    Args:
        input_image_path: Path of the image to read.
        output_image_path: Path the resized image is written to.
        scale_factor: Multiplier applied to both width and height; the
            resulting dimensions are truncated to integers.

    Raises:
        OSError: If the input image cannot be read or the output image
            cannot be written.
    """
    img = cv2.imread(input_image_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with a clear message rather than on ``img.shape`` below.
        raise OSError(f"Could not read image: {input_image_path}")

    original_height, original_width = img.shape[:2]
    # cv2.resize expects the target size as (width, height).
    new_size = (
        int(original_width * scale_factor),
        int(original_height * scale_factor),
    )
    upsampled_img = cv2.resize(img, new_size, interpolation=cv2.INTER_CUBIC)

    # cv2.imwrite reports failure through its boolean return value.
    if not cv2.imwrite(output_image_path, upsampled_img):
        raise OSError(f"Could not write image: {output_image_path}")
def process_frame(frame_count, frame, temp_LR_folder_path, temp_HR_folder_path, SR):
    """Save one video frame and, for a known scale, its bicubic upscale.

    The frame is written to ``temp_LR_folder_path`` as
    ``frame_<frame_count><SR>.png``. When ``SR`` is ``'x4'`` or ``'x2'`` the
    saved file is then upscaled by 4 or 2 respectively into
    ``temp_HR_folder_path`` as ``frame_<frame_count>.png``. For any other
    ``SR`` value only the low-resolution frame is written.
    """
    lr_path = os.path.join(temp_LR_folder_path, f"frame_{frame_count}{SR}.png")
    cv2.imwrite(lr_path, frame)

    hr_path = os.path.join(temp_HR_folder_path, f"frame_{frame_count}.png")
    # Map each supported scale tag to its numeric upscaling factor.
    for tag, factor in (('x4', 4), ('x2', 2)):
        if SR == tag:
            bicubic_upsample_opencv(lr_path, hr_path, factor)
            break
def video_sr(args):
    """Super-resolve a video frame by frame and write the resulting video.

    Pipeline:
      1. Decode ``args.input_dir`` (a video file path) and hand every frame
         to ``process_frame`` on a thread pool, which stores it in temporary
         folders under ``args.output_dir``.
      2. Call ``image_sr(args)``. NOTE(review): presumably this reads the
         temporary frame folders and writes its output images into
         ``<root_path>/results/test_RGT_<SR>/visualization/Set5`` -- confirm
         against ``basicsr.test_img``.
      3. Assemble the images found in that results folder, in numeric
         order, into ``<output_dir>/<input name>_<SR>.<ext>`` at the source
         video's frame rate.
      4. Delete the temporary folders and ``<root_path>/results``.

    Args:
        args: Namespace providing ``SR`` (``'x2'`` or ``'x4'``),
            ``input_dir``, ``output_dir``, ``root_path`` and ``mul_numwork``
            (thread-pool size). It is passed unchanged to ``image_sr``.

    Raises:
        ValueError: If ``args.SR`` is neither ``'x2'`` nor ``'x4'``.
        Exception: Any error raised while saving a frame in a worker thread
            is re-raised instead of being silently discarded.
    """
    scale = args.SR
    if scale not in ('x2', 'x4'):
        # Without this check the scale-dependent paths below would stay
        # undefined and fail later with an unrelated-looking error.
        raise ValueError(f"Unsupported SR factor {scale!r}; expected 'x2' or 'x4'.")

    file_name = os.path.basename(args.input_dir)
    video_output_path = replace_filename(
        os.path.join(args.output_dir, file_name), f'_{scale}'
    )
    temp_LR_folder_path = os.path.join(args.output_dir, f'temp_LR/{scale.upper()}')
    temp_HR_folder_path = os.path.join(args.output_dir, 'temp_HR')
    result_temp = osp.join(args.root_path, f'results/test_RGT_{scale}/visualization/Set5')

    # Open the video before touching the temp folders so that a bad input
    # path does not wipe or create anything on disk.
    cap = cv2.VideoCapture(args.input_dir)
    if not cap.isOpened():
        print("Error opening video file.")
        cap.release()
        return

    t1 = time.time()
    frame_count = 0
    pending = []
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        create_temp_folder(temp_LR_folder_path)
        create_temp_folder(temp_HR_folder_path)

        # Submit each frame as soon as it is decoded so workers can start
        # immediately; leaving the ``with`` block waits for all of them.
        with ThreadPoolExecutor(max_workers=args.mul_numwork) as executor:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                pending.append(
                    executor.submit(
                        process_frame,
                        frame_count,
                        frame,
                        temp_LR_folder_path,
                        temp_HR_folder_path,
                        args.SR,
                    )
                )
                frame_count += 1
    finally:
        # The capture is no longer needed once all frames are read (or an
        # error occurred); release it in every case.
        cap.release()

    # Surface worker failures: a missing frame would otherwise only show up
    # as a corrupted or shortened output video.
    for future in pending:
        future.result()

    print("total frames:", frame_count)
    print("fps :", fps)
    t2 = time.time()
    print('mul threads: ', t2 - t1, 's')

    if frame_count == 0:
        # Nothing to super-resolve; remove the empty temp folders and stop.
        print("No frames could be read from the video.")
        delete_temp_folder(os.path.dirname(temp_LR_folder_path))
        delete_temp_folder(temp_HR_folder_path)
        return

    # Process all frames of the video.
    image_sr(args)
    t3 = time.time()
    print('image super resolution: ', t3 - t2, 's')

    # Rebuild the video from all result frames, ordered by frame number.
    frame_files = sorted(os.listdir(result_temp), key=extract_number)
    video_frames = [
        imageio.imread(os.path.join(result_temp, frame_file))
        for frame_file in frame_files
    ]
    imageio.mimwrite(video_output_path, video_frames, fps=fps, quality=9)
    t4 = time.time()
    print('tranformer frames to video: ', t4 - t3, 's')

    # Remove all temporary data, including the whole results folder.
    delete_temp_folder(os.path.dirname(temp_LR_folder_path))
    delete_temp_folder(temp_HR_folder_path)
    delete_temp_folder(os.path.join(args.root_path, 'results'))
    t5 = time.time()
    print('delete time: ', t5 - t4, 's')
def str2bool(value):
    """Convert a command-line token such as ``"False"`` or ``"1"`` to a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string as ``True``
    (so ``--use_chop False`` would enable the option); this parser interprets
    the text instead.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_arg_parser():
    """Build the command-line parser for the video super-resolution script."""
    parser = argparse.ArgumentParser(description="RGT for Video Super-Resolution")
    # Make sure the SR factor matches the checkpoint given in ckpt_path.
    parser.add_argument("--SR", type=str, choices=['x2', 'x4'], default='x4', help='image resolution')
    parser.add_argument("--ckpt_path", type=str, default="/remote-home/lzy/RGT/experiments/pretrained_models/RGT_x4.pth")
    parser.add_argument("--root_path", type=str, default="/remote-home/lzy/RGT")
    parser.add_argument("--input_dir", type=str, default="/remote-home/lzy/RGT/datasets/video/video_test1.mp4")
    parser.add_argument("--output_dir", type=str, default="/remote-home/lzy/RGT/datasets/video_output")
    parser.add_argument("--mul_numwork", type=int, default=16, help='max_workers to execute Multi')
    parser.add_argument("--use_chop", type=str2bool, default=True, help='use_chop: True # True to save memory, if img too large')
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()
    video_sr(args)