File size: 3,449 Bytes
0c9dedf
 
 
 
 
 
 
ce91763
fea70af
 
ce91763
 
 
 
 
0c9dedf
 
 
 
 
 
 
ce91763
0c9dedf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce91763
 
0c9dedf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce91763
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
import cv2
import numpy as np
import torch
import bisect
import shutil

# Pick the compute device once at import time. Priority: the MPS backend,
# then CUDA, and finally the CPU when neither accelerator is available.
if not torch.backends.mps.is_available():
    device = "cuda" if torch.cuda.is_available() else "cpu"
else:
    device = "mps"

def init_frame_interpolation_model():
    """Load the frame interpolation network and prepare it for inference.

    The checkpoint at ``./pretrained_model/film_net_fp16.pt`` is deserialized
    onto the CPU, switched to evaluation mode, cast to half precision and
    finally moved to the module-level ``device``.

    Returns:
        The loaded model, in eval mode, in fp16, on ``device``.
    """
    print("Initializing frame interpolation model")
    weights_path = "./pretrained_model/film_net_fp16.pt"

    # Deserialize on the CPU first; the transfer to the target device happens
    # only after the half-precision cast below.
    net = torch.load(weights_path, map_location='cpu')
    net.eval()
    return net.half().to(device=device)


def _read_rgb_tensor(image_path):
    """Load an image file as a float32 RGB tensor scaled to [0, 1].

    The array decoded by OpenCV is converted BGR -> RGB, given a leading
    batch axis and permuted from (1, H, W, C) to (1, C, H, W).

    Raises:
        ValueError: if OpenCV cannot decode ``image_path``.
    """
    bgr = cv2.imread(image_path)
    if bgr is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise ValueError(f"Could not read image file: {image_path}")
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / np.float32(255)
    return torch.from_numpy(rgb).unsqueeze(0).permute(0, 3, 1, 2)


def _tensor_to_bgr_frame(tensor):
    """Convert a (1, C, H, W) RGB tensor in [0, 1] to an (H, W, C) uint8 BGR array."""
    # flip(0) reverses the channel axis (RGB -> BGR) for OpenCV.
    return (tensor[0] * 255).byte().flip(0).permute(1, 2, 0).numpy().copy()


def _interpolate_pair(model, image1, image2, inter_frames):
    """Generate ``inter_frames`` in-between frames for one pair of frames.

    Time steps are evenly spaced on [0, 1]. At every iteration the pending
    time step that lies closest to the middle of an already available
    bracketing pair of frames is synthesized next, so each prediction is made
    from the two nearest frames produced so far.

    Returns:
        A list of ``inter_frames + 2`` CPU float tensors in temporal order:
        ``image1``, the predicted frames, ``image2``.
    """
    results = [image1, image2]

    idxes = [0, inter_frames + 1]                # time-step indices already available
    remains = list(range(1, inter_frames + 1))   # time-step indices still to generate

    splits = torch.linspace(0, 1, inter_frames + 2)

    for _ in range(len(remains)):
        starts = splits[idxes[:-1]]
        ends = splits[idxes[1:]]
        # Row = available interval, column = pending time step; value = how far
        # the step's relative position inside that interval is from the midpoint.
        distances = ((splits[None, remains] - starts[:, None]) / (ends[:, None] - starts[:, None]) - .5).abs()
        matrix = torch.argmin(distances).item()
        start_i, step = np.unravel_index(matrix, distances.shape)
        end_i = start_i + 1

        # The model is run in half precision on the module-level device.
        x0 = results[start_i].half().to(device)
        x1 = results[end_i].half().to(device)

        # Relative position of the target time step between x0 and x1.
        dt = x0.new_full((1, 1), (splits[remains[step]] - splits[idxes[start_i]])) / (splits[idxes[end_i]] - splits[idxes[start_i]])

        with torch.no_grad():
            prediction = model(x0, x1, dt)

        # Keep idxes/results sorted by time so neighbours stay adjacent.
        insert_position = bisect.bisect_left(idxes, remains[step])
        idxes.insert(insert_position, remains[step])
        results.insert(insert_position, prediction.clamp(0, 1).cpu().float())
        del remains[step]

    return results


def batch_images_interpolation_tool(input_file, model, fps, inter_frames=1):
    """Interpolate between consecutive images of a directory and save a video.

    The entries of the directory ``input_file`` are processed in sorted
    filename order. Between every pair of consecutive images,
    ``inter_frames`` additional frames are predicted with ``model`` and the
    whole sequence is encoded with the ``mp4v`` codec. Frames are streamed
    directly to the video writer, so only the frames of the current pair are
    held in memory and no temporary image directory is needed.

    Args:
        input_file: Path of the directory that contains the input images
            (despite the name, this is a directory, not a file).
        model: Callable invoked as ``model(x0, x1, dt)`` with two half
            precision (1, C, H, W) tensors and a (1, 1) time tensor.
        fps: Frame rate handed to ``cv2.VideoWriter``.
        inter_frames: Number of frames to generate between each pair of
            consecutive input images; converted with ``int()``.

    Returns:
        The path of the written video, ``input_file + '.mp4'``.

    Raises:
        ValueError: if ``inter_frames`` is negative, the directory holds
            fewer than two entries, or an entry cannot be decoded as an image.
        RuntimeError: if the video writer cannot be opened.

    Note:
        If an error occurs after encoding has started, the writer is released
        but a partially written video file may remain at the output path.
    """
    inter_frames = int(inter_frames)
    if inter_frames < 0:
        raise ValueError(f"inter_frames must be non-negative, got {inter_frames}")

    input_img_list = sorted(os.listdir(input_file))
    if len(input_img_list) < 2:
        raise ValueError(
            f"Need at least two images in {input_file!r} to interpolate, "
            f"found {len(input_img_list)}"
        )

    # Decode the first image before creating any output so that an unreadable
    # input fails without leaving a file behind.
    previous = _read_rgb_tensor(os.path.join(input_file, input_img_list[0]))
    height, width = previous.shape[-2:]

    video_save_dir = input_file + '.mp4'
    fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
    writer = cv2.VideoWriter(video_save_dir, fourcc, fps, (width, height))
    try:
        if not writer.isOpened():
            raise RuntimeError(f"Could not open video writer for {video_save_dir}")

        for img_name in input_img_list[1:]:
            current = _read_rgb_tensor(os.path.join(input_file, img_name))
            sequence = _interpolate_pair(model, previous, current, inter_frames)

            # The last frame of this pair is the first frame of the next one,
            # so it is written only once (after the loop for the final pair).
            for tensor in sequence[:-1]:
                writer.write(_tensor_to_bgr_frame(tensor))

            previous = current

        writer.write(_tensor_to_bgr_frame(previous))
    finally:
        # Always finalize/close the container, even when interpolation fails.
        writer.release()

    return video_save_dir