|
|
|
import os |
|
import subprocess |
|
import torch |
|
import torchvision |
|
import imageio |
|
import glob |
|
|
|
from SAFMNPP import SAFMNPP |
|
|
|
def _model_device(model):
    """Return the device of ``model``'s first parameter, or ``cuda:0`` if unknown."""
    params = getattr(model, 'parameters', None)
    if callable(params):
        first = next(iter(params()), None)
        if first is not None:
            return first.device
    return torch.device('cuda', 0)


def main(input_path, output_path, video_name, model, tmp_root='/frams', device=None):
    """Super-resolve one video with ``model`` and encode the result with FFmpeg.

    Steps:
    - Decode ``input_path/video_name`` with torchvision.
    - Run it through ``model`` in temporal chunks.
    - Write every output frame as an 8-digit PNG into a per-video scratch
      directory.
    - Encode those PNGs with FFmpeg (libx264, CRF 12) into
      ``output_path/video_name``.
    - Delete the PNGs once the encode succeeds.

    The call does nothing if the output video already exists. PNGs already
    present in the scratch directory (e.g. from an interrupted run) are kept
    and not overwritten.

    Args:
        input_path: Directory containing the source video.
        output_path: Directory the encoded video is written to (same file name).
        video_name: File name of the video to process.
        model: Callable applied to float32 frames in [0, 1] laid out as
            (1, T, C, H, W). Its output is expected in the same layout; values
            are clamped to [0, 1] before saving.
        tmp_root: Parent directory for the per-video scratch frame folder.
            Defaults to the previously hard-coded '/frams'.
        device: Device the input chunks are moved to. Defaults to the device
            of the model's first parameter, falling back to ``cuda:0``.
    """
    # Check for finished output first so skipped videos don't leave empty
    # scratch directories behind.
    video_path = os.path.join(output_path, video_name)
    if os.path.exists(video_path):
        return

    # pts_unit='sec' still reads the whole clip, and it avoids torchvision's
    # warning about pts_unit='pts'.
    frames_thwc, _, info = torchvision.io.read_video(
        os.path.join(input_path, video_name), pts_unit='sec')
    num_frames = frames_thwc.size(0)
    print(f'total frames: {num_frames}')
    if num_frames == 0:
        print(f'No video frames decoded from {video_name}, skipping.')
        return

    # splitext handles extensions of any length, and names without one. The
    # old video_name[:-4] could collapse tmp_path onto tmp_root itself.
    tmp_path = os.path.join(tmp_root, os.path.splitext(video_name)[0])
    os.makedirs(tmp_path, exist_ok=True)

    # (T, H, W, C) uint8 -> (T, C, H, W) float32 in [0, 1] -> add batch dim.
    input_data = torchvision.transforms.functional.convert_image_dtype(
        frames_thwc.permute(0, 3, 1, 2), torch.float32).unsqueeze(0)

    if device is None:
        device = _model_device(model)

    frame_idx = 0
    with torch.no_grad():
        # NOTE: chunk(100) splits the clip into at most 100 temporal pieces
        # (not pieces of 100 frames), so the piece length grows with the clip.
        for xi in input_data.chunk(100, dim=1):
            output = model(xi.to(device)).detach().cpu()
            for frame in output.squeeze(0).unbind(dim=0):
                frame_file = os.path.join(tmp_path, f'{frame_idx:08d}.png')
                frame_idx += 1
                if os.path.exists(frame_file):
                    print('exist frame : ', frame_file)
                    continue
                frame = torchvision.transforms.functional.convert_image_dtype(
                    frame.clamp(0, 1), torch.uint8)
                # (C, H, W) -> (H, W, C) for imageio.
                imageio.imwrite(frame_file, frame.permute(1, 2, 0).numpy())
                print('save frames : ', frame_file)

    fps = info['video_fps']
    # An argv list with shell=False passes paths with spaces or shell
    # metacharacters through verbatim instead of letting /bin/sh parse them.
    cmd = [
        'ffmpeg', '-r', str(fps),
        '-i', os.path.join(tmp_path, '%08d.png'),
        '-c:v', 'libx264', '-crf', '12', '-preset', 'veryfast',
        video_path,
    ]
    try:
        subprocess.run(cmd, check=True)
    except (subprocess.CalledProcessError, FileNotFoundError) as e:
        # FileNotFoundError: ffmpeg is not on PATH. Without a shell this is
        # raised directly instead of a non-zero exit status.
        print(f"An error occurred while trying to run FFmpeg: {e}")
        # A partial output file would make the early return above skip this
        # video on every future run, so remove it.
        if os.path.exists(video_path):
            os.remove(video_path)
        return
    print("Video created successfully.")

    # Frames are deleted only after a successful encode. A failed encode keeps
    # them, and a rerun reuses them instead of overwriting.
    for frame_filename in glob.glob(os.path.join(tmp_path, '*.png')):
        os.remove(frame_filename)
        print(f"Deleted {frame_filename}")
|
|
|
|
|
if __name__ == '__main__':
    # Prefer the first GPU, but still run (slowly) where CUDA is unavailable.
    device = torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu')
    model = SAFMNPP(upscaling_factor=4).to(device)
    model_path = 'light_safmnpp.pth'
    # map_location lets a checkpoint saved on another device load onto `device`.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['params'], strict=True)
    # Inference only: switch any dropout / batch-norm layers to eval behavior.
    model.eval()

    input_path = r'ValidationSet-1080p/bitstreams'
    output_path = r'Video_Output_4X'
    os.makedirs(output_path, exist_ok=True)

    # Sorted for a deterministic processing order; skip non-file entries.
    for video_name in sorted(os.listdir(input_path)):
        if not os.path.isfile(os.path.join(input_path, video_name)):
            continue
        main(input_path, output_path, video_name, model)
        print("Done", video_name)
|
|
|
|