Meloo's picture
Upload 4 files
188d68e verified
# (c) Meta Platforms, Inc. and affiliates.
import os
import subprocess
import torch
import torchvision
import imageio
import glob
from SAFMNPP import SAFMNPP
def main(input_path, output_path, video_name, model, frames_root='/frams',
         num_chunks=100):
    """Super-resolve one video with ``model`` and encode the result with FFmpeg.

    The input video is decoded in full and converted to float32. It is passed
    through ``model`` in temporal chunks, and each output frame is written as a
    numbered PNG. The PNGs are then encoded to H.264 (libx264, CRF 12) at the
    source frame rate. The PNGs are deleted after a successful encode.

    Args:
        input_path: Directory containing the source video.
        output_path: Directory the encoded result is written to (must exist).
        video_name: File name of the video inside ``input_path``; the output
            gets the same name inside ``output_path``.
        model: Callable applied to a (1, T, C, H, W) float tensor. Its output
            is assumed to keep a leading batch dim of 1 followed by the frame
            dim — confirm against the model definition.
        frames_root: Directory under which a per-video scratch folder for the
            intermediate PNG frames is created.
        num_chunks: Number of pieces the frame axis is split into with
            ``torch.chunk`` (fewer when the video has fewer frames).

    Returns:
        None. Returns immediately, without reading the input or creating any
        scratch directory, when the output video already exists.
    """
    video_path = os.path.join(output_path, video_name)
    if os.path.exists(video_path):
        return

    # splitext copes with extensions of any length (``[:-4]`` assumed '.xxx').
    stem = os.path.splitext(video_name)[0]
    tmp_path = os.path.join(frames_root, stem)
    os.makedirs(tmp_path, exist_ok=True)

    frames, fps = _load_video(os.path.join(input_path, video_name))
    device = _model_device(model)
    print(f'total frames: {frames.size(1)}')
    _write_sr_frames(model, frames, device, tmp_path, num_chunks)
    _encode_video(tmp_path, video_path, fps)


def _load_video(path):
    """Decode ``path``; return a (1, T, C, H, W) float32 tensor and the fps."""
    video, _, info = torchvision.io.read_video(path)
    frames = video.permute(0, 3, 1, 2)  # THWC -> TCHW
    frames = torchvision.transforms.functional.convert_image_dtype(frames, torch.float32)
    return frames.unsqueeze(0), info['video_fps']


def _model_device(model):
    """Return the device of the model's first parameter, else cuda:0."""
    params = getattr(model, 'parameters', None)
    first = next(iter(params()), None) if callable(params) else None
    return first.device if first is not None else torch.device('cuda', 0)


def _write_sr_frames(model, frames, device, tmp_path, num_chunks):
    """Run ``model`` over ``frames`` chunk by chunk, saving each output frame.

    Frames are named ``%08d.png`` in ``tmp_path``, numbered consecutively
    across chunks. A PNG that already exists (e.g. from an earlier run) is
    not rewritten.
    """
    frame_idx = 0
    with torch.no_grad():
        for clip in frames.chunk(num_chunks, dim=1):
            output = model(clip.to(device)).cpu()
            for frame in output.squeeze(0).unbind(dim=0):
                png_path = os.path.join(tmp_path, f'{frame_idx:08d}.png')
                if os.path.exists(png_path):
                    print('exist frame : ', png_path)
                else:
                    frame = frame.clamp(0, 1)  # keep values in [0, 1] before uint8 conversion
                    frame = torchvision.transforms.functional.convert_image_dtype(frame, torch.uint8)
                    # CHW -> HWC for imageio. squeeze(-1) turns single-channel output into an
                    # HxW greyscale image; it is a no-op for RGB.
                    imageio.imwrite(png_path, frame.permute(1, 2, 0).squeeze(-1).numpy())
                    print('save frames : ', png_path)
                frame_idx += 1


def _encode_video(tmp_path, video_path, fps):
    """Encode the numbered PNGs in ``tmp_path`` into ``video_path``, then delete them."""
    # Argument list without a shell: paths containing spaces or shell
    # metacharacters are passed to ffmpeg verbatim.
    cmd = [
        'ffmpeg', '-r', str(fps),
        '-i', os.path.join(tmp_path, '%08d.png'),
        '-c:v', 'libx264', '-crf', '12', '-preset', 'veryfast',
        video_path,
    ]
    try:
        subprocess.run(cmd, check=True)
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers a missing ffmpeg binary, which the old shell
        # invocation reported as a non-zero exit status instead.
        print(f"An error occurred while trying to run FFmpeg: {e}")
        # A partial output would make the next run skip this video.
        if os.path.exists(video_path):
            os.remove(video_path)
        return
    print("Video created successfully.")
    # Delete the intermediate frame images; escape the directory so glob
    # metacharacters in it are matched literally.
    for frame_filename in glob.glob(os.path.join(glob.escape(tmp_path), '*.png')):
        os.remove(frame_filename)
        print(f"Deleted {frame_filename}")
    try:
        os.rmdir(tmp_path)  # best effort: only succeeds if the dir is now empty
    except OSError:
        pass
if __name__ == '__main__':
    # Batch-run 4x SAFMNPP super-resolution over every file in input_path.
    device = torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu')
    model = SAFMNPP(upscaling_factor=4).to(device)
    model_path = 'light_safmnpp.pth'
    # map_location: load the tensors onto the target device regardless of the
    # device the checkpoint was saved from.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['params'], strict=True)
    # nn.Module starts in training mode; switch to eval mode for inference so
    # any train/eval-dependent layers behave deterministically.
    model.eval()
    input_path = 'ValidationSet-1080p/bitstreams'
    output_path = 'Video_Output_4X'
    os.makedirs(output_path, exist_ok=True)
    # Sorted for a reproducible processing order; skip sub-directories and
    # other non-file entries.
    for video_name in sorted(os.listdir(input_path)):
        if not os.path.isfile(os.path.join(input_path, video_name)):
            continue
        main(input_path, output_path, video_name, model)
        print("Done", video_name)