zejunyang committed on
Commit
0c9dedf
1 Parent(s): 202b7b1

update frame interpolation model

Browse files
Files changed (1) hide show
  1. src/utils/frame_interpolation.py +90 -0
src/utils/frame_interpolation.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import torch
5
+ import bisect
6
+ import shutil
7
+
8
def init_frame_interpolation_model(
    checkpoint_name="./pretrained_model/film_net_fp16.pt",
    device="cuda",
):
    """Load the frame interpolation model and prepare it for fp16 inference.

    Args:
        checkpoint_name: Path of the serialized model passed to ``torch.load``.
            Defaults to the location this module has always used.
        device: Device the model is moved to after loading. Defaults to
            ``"cuda"``, matching the previous hard-coded behavior.

    Returns:
        The loaded model, switched to eval mode, converted to half precision
        and moved to ``device``.

    Raises:
        FileNotFoundError: If ``checkpoint_name`` is not an existing file.
    """
    print("Initializing frame interpolation model")

    # Fail early with a message that names the missing path.
    if not os.path.isfile(checkpoint_name):
        raise FileNotFoundError(
            f"Frame interpolation checkpoint not found: {checkpoint_name}"
        )

    # NOTE(security): torch.load may unpickle arbitrary Python objects, so only
    # checkpoints from a trusted source should ever be loaded here.
    # Load on the CPU first so the half-precision conversion happens before
    # the weights are transferred to the target device.
    model = torch.load(checkpoint_name, map_location='cpu')
    model.eval()
    model = model.half()
    model = model.to(device=device)
    return model
17
+
18
+
19
def _read_rgb_tensor(image_path):
    """Load an image file as a float32 RGB tensor of shape (1, 3, H, W) in [0, 1]."""
    bgr = cv2.imread(image_path)
    if bgr is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise ValueError(f"Could not read image: {image_path}")
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / np.float32(255)
    return torch.from_numpy(rgb).unsqueeze(0).permute(0, 3, 1, 2)


def _tensor_to_bgr_frame(tensor):
    """Convert a (1, 3, H, W) RGB tensor in [0, 1] to an H x W x 3 uint8 BGR array."""
    # flip(0) reverses the channel axis (RGB -> BGR) before moving it last.
    return (tensor[0] * 255).byte().flip(0).permute(1, 2, 0).numpy().copy()


def _model_device_and_dtype(model):
    """Return the device/dtype of the model's first parameter (CUDA fp16 fallback)."""
    try:
        reference = next(iter(model.parameters()))
    except (AttributeError, StopIteration, TypeError):
        # No inspectable parameters: keep the historical hard-coded behavior.
        return torch.device("cuda"), torch.float16
    if not reference.is_floating_point():
        return reference.device, torch.float16
    return reference.device, reference.dtype


def _interpolate_pair(model, image1, image2, inter_frames, device, dtype):
    """Return [image1, <inter_frames predicted frames>, image2] in temporal order."""
    results = [image1, image2]
    # idxes holds the timeline positions already available, kept sorted and
    # aligned with results; remains holds the positions still to synthesize.
    idxes = [0, inter_frames + 1]
    remains = list(range(1, inter_frames + 1))
    splits = torch.linspace(0, 1, inter_frames + 2)

    while remains:
        starts = splits[idxes[:-1]]
        ends = splits[idxes[1:]]
        # For every (existing adjacent pair, remaining position) combination,
        # measure how far the position is from the midpoint of that pair and
        # pick the combination closest to a midpoint.
        distances = (
            (splits[None, remains] - starts[:, None])
            / (ends[:, None] - starts[:, None])
            - .5
        ).abs()
        flat_index = torch.argmin(distances).item()
        start_i, step = (int(v) for v in np.unravel_index(flat_index, distances.shape))
        end_i = start_i + 1

        x0 = results[start_i].to(device=device, dtype=dtype)
        x1 = results[end_i].to(device=device, dtype=dtype)

        t_start = splits[idxes[start_i]]
        t_end = splits[idxes[end_i]]
        # Relative position of the target frame between x0 and x1.
        dt = x0.new_full((1, 1), (splits[remains[step]] - t_start)) / (t_end - t_start)

        with torch.no_grad():
            prediction = model(x0, x1, dt)

        insert_position = bisect.bisect_left(idxes, remains[step])
        idxes.insert(insert_position, remains[step])
        results.insert(insert_position, prediction.clamp(0, 1).cpu().float())
        del remains[step]

    return results


def batch_images_interpolation_tool(input_file, model, fps, inter_frames=1):
    """Interpolate between consecutive images of a directory and write an mp4.

    The directory entries are processed in sorted name order. Between every
    pair of neighbouring images, ``inter_frames`` frames are synthesized with
    ``model`` and the resulting sequence is written to ``<input_file>.mp4``.
    Frames are streamed straight into the video writer, so no temporary
    directory is created and the full sequence is never held in memory.
    A directory with a single image yields a one-frame video.

    Args:
        input_file: Path of the directory containing the input images. A
            trailing path separator is ignored when deriving the output name.
        model: Callable invoked as ``model(x0, x1, dt)``. Inputs are moved to
            the device/dtype of its first parameter when that can be
            determined, otherwise to CUDA half precision.
        fps: Frame rate handed to ``cv2.VideoWriter``.
        inter_frames: Number of frames to synthesize between each input pair.
            Converted with ``int``; must not be negative.

    Returns:
        The path of the written video file.

    Raises:
        ValueError: If ``inter_frames`` is negative, the directory is empty,
            or an entry cannot be decoded as an image.
    """
    inter_frames = int(inter_frames)
    if inter_frames < 0:
        raise ValueError(f"inter_frames must be >= 0, got {inter_frames}")

    # Strip trailing separators so "frames/" maps to "frames.mp4" rather than
    # a hidden "frames/.mp4" inside the input directory.
    separators = os.sep + (os.altsep or "")
    base_path = input_file.rstrip(separators) or input_file
    video_save_dir = base_path + '.mp4'

    input_img_list = sorted(os.listdir(input_file))
    if not input_img_list:
        raise ValueError(f"No images found in directory: {input_file}")

    device, dtype = _model_device_and_dtype(model)

    previous = _read_rgb_tensor(os.path.join(input_file, input_img_list[0]))
    height, width = previous.shape[2:]

    fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
    writer = cv2.VideoWriter(video_save_dir, fourcc, fps, (width, height))
    try:
        for image_name in input_img_list[1:]:
            current = _read_rgb_tensor(os.path.join(input_file, image_name))
            sequence = _interpolate_pair(
                model, previous, current, inter_frames, device, dtype
            )
            # The last element is `current`; it is emitted as the first frame
            # of the next pair (or after the loop), so skip it here.
            for tensor in sequence[:-1]:
                writer.write(_tensor_to_bgr_frame(tensor))
            # Reuse the decoded image instead of reading it from disk again.
            previous = current
        writer.write(_tensor_to_bgr_frame(previous))
    finally:
        # Always finalize the container, even if interpolation fails midway.
        writer.release()

    return video_save_dir