import bisect
import os
import shutil

import cv2
import numpy as np
import torch

# Pick the compute device once at import time: Apple MPS first, then CUDA,
# then CPU. `torch.backends.mps` is looked up defensively so that torch builds
# which do not ship that backend fall through to the other options instead of
# raising AttributeError on import.
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


def init_frame_interpolation_model():
    """Load the frame interpolation model and move it to the selected device.

    The checkpoint is read from ``./pretrained_model/film_net_fp16.pt``
    (relative to the current working directory), switched to eval mode,
    converted to half precision and moved to the module-level ``device``.

    Returns:
        The loaded model, ready to be called as ``model(x0, x1, dt)``.
    """
    print("Initializing frame interpolation model")
    checkpoint_name = os.path.join("./pretrained_model/film_net_fp16.pt")
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints obtained from a trusted source.
    model = torch.load(checkpoint_name, map_location='cpu')
    model.eval()
    model = model.half()
    model = model.to(device=device)
    return model


def _load_rgb_tensor(path):
    """Read an image as a float32 RGB tensor of shape (1, C, H, W) in [0, 1].

    Returns ``None`` when OpenCV cannot decode ``path`` (e.g. a directory or a
    non-image file), so callers can skip such entries.
    """
    bgr = cv2.imread(path)
    if bgr is None:
        return None
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / np.float32(255)
    return torch.from_numpy(rgb).unsqueeze(0).permute(0, 3, 1, 2)


def _tensor_to_bgr_frame(tensor):
    """Convert a (1, C, H, W) RGB tensor in [0, 1] to an HxWxC uint8 BGR array."""
    # flip(0) reverses the channel axis (RGB -> BGR) before moving it last.
    return (tensor[0] * 255).byte().flip(0).permute(1, 2, 0).numpy().copy()


def _interpolate_pair(model, first, second, inter_frames, splits):
    """Generate ``inter_frames`` in-between frames for one pair of images.

    Args:
        model: Callable invoked as ``model(x0, x1, dt)``.
        first: Start frame tensor (CPU, float32).
        second: End frame tensor (CPU, float32).
        inter_frames: Number of frames to synthesise between the two inputs.
        splits: ``torch.linspace(0, 1, inter_frames + 2)`` timestamps.

    Returns:
        List of ``inter_frames + 2`` CPU float32 tensors ordered in time,
        starting with ``first`` and ending with ``second``.
    """
    results = [first, second]
    idxes = [0, inter_frames + 1]  # timestamps (indices into splits) already available
    remains = list(range(1, inter_frames + 1))  # timestamps still to be generated

    # Exactly one timestamp is removed from `remains` per iteration.
    while remains:
        starts = splits[idxes[:-1]]
        ends = splits[idxes[1:]]
        # For every (existing interval, remaining timestamp) pair, measure how
        # far the timestamp is from the middle of that interval; the pair with
        # the smallest distance is generated next.
        distances = ((splits[None, remains] - starts[:, None]) / (ends[:, None] - starts[:, None]) - .5).abs()
        matrix = torch.argmin(distances).item()
        start_i, step = np.unravel_index(matrix, distances.shape)
        end_i = start_i + 1

        x0 = results[start_i].half().to(device)
        x1 = results[end_i].half().to(device)

        # Relative position of the target timestamp between x0 and x1.
        dt = x0.new_full((1, 1), (splits[remains[step]] - splits[idxes[start_i]])) / (splits[idxes[end_i]] - splits[idxes[start_i]])

        with torch.no_grad():
            prediction = model(x0, x1, dt)

        # Keep `idxes` and `results` sorted by timestamp.
        insert_position = bisect.bisect_left(idxes, remains[step])
        idxes.insert(insert_position, remains[step])
        results.insert(insert_position, prediction.clamp(0, 1).cpu().float())
        del remains[step]

    return results


def batch_images_interpolation_tool(input_file, model, fps, inter_frames=1):
    """Interpolate between consecutive images of a directory and write an MP4.

    The images in ``input_file`` are processed in sorted filename order. For
    every consecutive pair, ``inter_frames`` intermediate frames are generated
    with ``model``; all frames are staged as PNG files in ``<input_file>_tmp``
    and then encoded to ``<input_file>.mp4`` with the ``mp4v`` codec. The
    staging directory is removed afterwards, also when an error occurs.

    Directory entries that OpenCV cannot decode are skipped.

    Args:
        input_file: Path (string) of the directory containing the input images.
            A trailing path separator is ignored when deriving output paths.
        model: Interpolation model, called as ``model(x0, x1, dt)``.
        fps: Frame rate passed to ``cv2.VideoWriter``.
        inter_frames: Number of frames to insert between each pair of images.

    Returns:
        Path of the written video file.

    Raises:
        ValueError: If fewer than two readable images are found.
    """
    inter_frames = int(inter_frames)

    # Strip trailing separators so "frames/" yields "frames_tmp"/"frames.mp4"
    # next to the input directory rather than hidden entries inside it.
    base_path = input_file.rstrip(os.sep + (os.altsep or "")) or input_file
    image_save_dir = base_path + '_tmp'
    video_save_dir = base_path + '.mp4'

    # List the inputs before creating the staging directory so that it can
    # never show up among the input entries.
    input_img_list = sorted(os.listdir(input_file))

    # Drop leftovers of an earlier, interrupted run; stale frames would
    # otherwise be encoded into the new video.
    if os.path.isdir(image_save_dir):
        shutil.rmtree(image_save_dir)
    os.makedirs(image_save_dir, exist_ok=True)

    splits = torch.linspace(0, 1, inter_frames + 2)
    writer = None
    try:
        previous = None
        pair_idx = 0
        for name in input_img_list:
            current = _load_rgb_tensor(os.path.join(input_file, name))
            if current is None:
                continue  # not a decodable image
            if previous is not None:
                results = _interpolate_pair(model, previous, current, inter_frames, splits)
                # The last frame of one pair and the first frame of the next
                # share an index, so the shared input image is stored once.
                for sub_idx, tensor in enumerate(results):
                    img_path = os.path.join(image_save_dir, f'{sub_idx + pair_idx * (inter_frames + 1):06d}.png')
                    cv2.imwrite(img_path, _tensor_to_bgr_frame(tensor))
                pair_idx += 1
            previous = current

        final_img_list = sorted(os.listdir(image_save_dir))
        if not final_img_list:
            raise ValueError(
                f"Need at least two readable images in {input_file!r} to interpolate"
            )

        # Stream frames to the encoder one at a time instead of holding the
        # whole sequence in memory.
        for item in final_img_list:
            frame = cv2.imread(os.path.join(image_save_dir, item))
            if writer is None:
                w, h = frame.shape[1::-1]
                fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
                writer = cv2.VideoWriter(video_save_dir, fourcc, fps, (w, h))
            writer.write(frame)
    finally:
        if writer is not None:
            writer.release()
        shutil.rmtree(image_save_dir, ignore_errors=True)

    return video_save_dir