# Source: AniPortrait_official/src/utils/frame_interpolation.py
# Captured from the Hugging Face file view (uploader: cocktailpeanut,
# commit "update" fea70af, 3.45 kB).
import os
import cv2
import numpy as np
import torch
import bisect
import shutil
# Choose the compute device once, at import time: Apple MPS first, then CUDA,
# then CPU. The getattr guard keeps the import from failing on PyTorch builds
# whose torch.backends module has no "mps" attribute.
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def init_frame_interpolation_model():
    """Load the frame interpolation model and prepare it for inference.

    The checkpoint is read from ``./pretrained_model/film_net_fp16.pt``
    (resolved against the current working directory), deserialized on the
    CPU, switched to eval mode, converted to half precision and finally
    moved to the module-level ``device``.

    Returns:
        The loaded model, in eval mode, in float16, on ``device``.
    """
    print("Initializing frame interpolation model")
    checkpoint_name = "./pretrained_model/film_net_fp16.pt"
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # that come from a trusted source.
    # NOTE(review): on PyTorch releases where torch.load defaults to
    # weights_only=True, loading a whole serialized module may be rejected;
    # confirm against the torch version this project pins.
    model = torch.load(checkpoint_name, map_location='cpu')
    model.eval()
    model = model.half()
    model = model.to(device=device)
    return model
def _read_rgb_tensor(image_path):
    """Read an image file into a float32 RGB tensor shaped (1, 3, H, W) in [0, 1]."""
    bgr = cv2.imread(image_path)
    if bgr is None:
        # cv2.imread reports an unreadable file by returning None, not by raising.
        raise ValueError(f"Could not read image file: {image_path}")
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / np.float32(255)
    return torch.from_numpy(rgb).unsqueeze(0).permute(0, 3, 1, 2)


def _interpolate_pair(model, first, last, inter_frames):
    """Return [first, *inter_frames predicted frames, last] as CPU float32 tensors."""
    results = [first, last]
    # idxes[i] is the timeline slot (0 .. inter_frames + 1) occupied by results[i].
    idxes = [0, inter_frames + 1]
    remains = list(range(1, inter_frames + 1))
    splits = torch.linspace(0, 1, inter_frames + 2)

    for _ in range(len(remains)):
        starts = splits[idxes[:-1]]
        ends = splits[idxes[1:]]
        # Position of every missing slot relative to every known interval;
        # the slot closest to the middle (0.5) of some interval is generated next.
        relative = (splits[None, remains] - starts[:, None]) / (ends[:, None] - starts[:, None])
        distances = (relative - .5).abs()
        matrix = torch.argmin(distances).item()
        start_i, step = np.unravel_index(matrix, distances.shape)
        end_i = start_i + 1

        # The model is called with half-precision inputs on the module-level device.
        x0 = results[start_i].half().to(device)
        x1 = results[end_i].half().to(device)

        # dt is the target slot's fractional position between x0 and x1.
        offset = splits[remains[step]] - splits[idxes[start_i]]
        span = splits[idxes[end_i]] - splits[idxes[start_i]]
        dt = x0.new_full((1, 1), offset) / span

        with torch.no_grad():
            prediction = model(x0, x1, dt)

        # Keep idxes and results sorted by timeline slot.
        insert_position = bisect.bisect_left(idxes, remains[step])
        idxes.insert(insert_position, remains[step])
        results.insert(insert_position, prediction.clamp(0, 1).cpu().float())
        del remains[step]

    return results


def batch_images_interpolation_tool(input_file, model, fps, inter_frames=1):
    """Interpolate between consecutive images of a directory and encode a video.

    The entries of ``input_file`` are processed in sorted filename order. For
    every consecutive pair, ``inter_frames`` in-between frames are predicted
    with ``model``. All frames are staged as PNG files in ``input_file + '_tmp'``
    and then encoded with the ``mp4v`` codec into ``input_file + '.mp4'``.
    The staging directory is removed afterwards, including when an error occurs.

    Args:
        input_file: Path of the directory holding the input images.
        model: Callable invoked as ``model(x0, x1, dt)`` with two half-precision
            image tensors on the module-level ``device`` and a ``(1, 1)`` tensor
            holding the fractional time of the frame to predict.
        fps: Frame rate passed to ``cv2.VideoWriter``.
        inter_frames: Number of frames to insert between each pair of inputs;
            converted with ``int()``.

    Returns:
        The path of the written video, ``input_file + '.mp4'``.

    Raises:
        ValueError: If ``inter_frames`` is negative, if the directory holds
            fewer than two entries, or if an entry cannot be read as an image.
    """
    inter_frames = int(inter_frames)
    if inter_frames < 0:
        raise ValueError(f"inter_frames must be >= 0, got {inter_frames}")

    input_img_list = sorted(os.listdir(input_file))
    if len(input_img_list) < 2:
        raise ValueError(
            f"Need at least two images in {input_file!r} to interpolate, "
            f"found {len(input_img_list)}"
        )

    image_save_dir = input_file + '_tmp'
    video_save_dir = input_file + '.mp4'

    # Start from an empty staging directory so frames left behind by an
    # interrupted earlier run cannot end up in this video.
    shutil.rmtree(image_save_dir, ignore_errors=True)
    os.makedirs(image_save_dir, exist_ok=True)
    try:
        frames_per_pair = inter_frames + 1
        # Each image is decoded once and reused as the start of the next pair.
        previous = _read_rgb_tensor(os.path.join(input_file, input_img_list[0]))
        for idx, image_name in enumerate(input_img_list[1:]):
            current = _read_rgb_tensor(os.path.join(input_file, image_name))
            results = _interpolate_pair(model, previous, current, inter_frames)
            for sub_idx, tensor in enumerate(results):
                # Round before the uint8 cast: .byte() truncates, which would
                # darken values such as 2.9999 to 2. flip(0) reverses the
                # channel axis (RGB -> BGR) for cv2.imwrite.
                frame = (tensor[0] * 255).round().byte().flip(0).permute(1, 2, 0).numpy().copy()
                # The last frame of one pair and the first frame of the next
                # share a filename; both are the same input image.
                img_path = os.path.join(image_save_dir, f'{sub_idx + idx * frames_per_pair:06d}.png')
                cv2.imwrite(img_path, frame)
            previous = current

        # Stream the staged frames into the video instead of holding them all in memory.
        writer = None
        try:
            for item in sorted(os.listdir(image_save_dir)):
                frame = cv2.imread(os.path.join(image_save_dir, item))
                if writer is None:
                    # The video size is taken from the first frame.
                    w, h = frame.shape[1::-1]
                    fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
                    writer = cv2.VideoWriter(video_save_dir, fourcc, fps, (w, h))
                writer.write(frame)
        finally:
            if writer is not None:
                writer.release()
    finally:
        shutil.rmtree(image_save_dir, ignore_errors=True)

    return video_save_dir