test/modules/rife/__init__.py
bilegentile's picture
Upload folder using huggingface_hub
c19ca42 verified
#!/bin/env python
import _thread
import os
import time
from queue import Queue
import cv2
import numpy as np
import torch
from PIL import Image
from torch.nn import functional as F
from tqdm.rich import tqdm
from modules.rife.ssim import ssim_matlab
from modules.rife.model_rife import RifeModel
from modules import devices, shared
# URL that load() downloads the RIFE flownet v4.6 weights from.
model_url = 'https://github.com/vladmandic/rife/raw/main/model/flownet-v46.pkl'
# Module-level model instance; stays None until load() creates it, and
# interpolate() calls load() when it finds it unset.
model: RifeModel = None
def load(model_path: str = 'rife/flownet-v46.pkl'):
    """Create the module-level RIFE model if it has not been created yet.

    The weights file is obtained through ``modelloader.load_file_from_url``
    using ``model_url`` and stored under ``<shared.models_path>/RIFE``.
    Calling this again once ``model`` is set is a no-op.

    Args:
        model_path: Accepted for interface compatibility; the incoming value
            is not read because it is replaced by the path returned from the
            downloader.
    """
    global model # pylint: disable=global-statement
    if model is not None:
        return
    from modules import modelloader
    model_path = modelloader.load_file_from_url(
        url=model_url,
        model_dir=os.path.join(shared.models_path, 'RIFE'),
        file_name='flownet-v46.pkl',
    )
    shared.log.debug(f'RIFE load model: file="{model_path}"')
    # The global is assigned before the weights are read, as in the original.
    model = RifeModel()
    model.load_model(model_path, -1)
    model.eval()
    model.device()
def interpolate(images: list, count: int = 2, scale: float = 1.0, pad: int = 1, change: float = 0.3):
    """Insert RIFE-generated frames between consecutive images.

    Frames are converted to tensors, compared with a 32x32 SSIM check and
    either skipped (near-duplicates), hard-cut (scene change) or interpolated
    by the model. Finished frames are handed to a background writer thread
    that converts them to PIL images.

    Args:
        images: Input frames. Each must expose ``width``/``height`` and be
            convertible with ``np.array``; presumably same-sized RGB PIL
            images -- confirm against callers.
        count: ``count - 1`` intermediate frames are requested for every
            pair of frames that is interpolated.
        scale: Forwarded to ``model.inference``; also determines the multiple
            (``max(128, int(128 / scale))``) that frames are padded to.
        pad: How many times the first frame is queued up front and the last
            frame is queued at the end; also the number of copies of each
            neighbour emitted when a scene change is detected.
        change: SSIM threshold below which two frames are treated as a scene
            change and are not interpolated.

    Returns:
        List of PIL images in output order, or ``[]`` when fewer than two
        input images are given. Every queued frame, including the final one,
        is present in the list when the function returns.
    """
    if images is None or len(images) < 2:
        return []
    if model is None:
        load()
    import threading  # local import: only needed for the writer thread

    interpolated = []
    h = images[0].height
    w = images[0].width
    t0 = time.time()

    def write(buffer):
        # Drain the queue until the None sentinel arrives. Each frame is
        # appended right after conversion so the last queued frame is not
        # left waiting for a successor that never comes.
        while True:
            item = buffer.get()
            if item is None:
                break
            # Queued frames are in BGR channel order; reverse to RGB for PIL.
            interpolated.append(Image.fromarray(item[:, :, ::-1]))

    def execute(I0, I1, n):
        # Nothing to generate; also stops the recursive branch below from
        # recursing forever when n is 0 or negative.
        if n <= 0:
            return []
        if model.version >= 3.9:
            # Newer models take the timestep directly.
            return [model.inference(I0, I1, (i + 1) * 1. / (n + 1), scale) for i in range(n)]
        # Older models only produce a midpoint, so bisect recursively.
        middle = model.inference(I0, I1, scale)
        if n == 1:
            return [middle]
        first_half = execute(I0, middle, n=n // 2)
        second_half = execute(middle, I1, n=n // 2)
        if n % 2:
            return [*first_half, middle, *second_half]
        return [*first_half, *second_half]

    def f_pad(img):
        return F.pad(img, padding).to(devices.dtype) # pylint: disable=not-callable

    def to_tensor(bgr):
        # HWC uint8 frame -> padded 1xCxHxW tensor scaled to [0, 1].
        chw = torch.from_numpy(np.transpose(bgr, (2, 0, 1)))
        return f_pad(chw.to(devices.device, non_blocking=True).unsqueeze(0).float() / 255.)

    # Pad height/width up to the next multiple of tmp.
    tmp = max(128, int(128 / scale))
    ph = ((h - 1) // tmp + 1) * tmp
    pw = ((w - 1) // tmp + 1) * tmp
    padding = (0, pw - w, 0, ph - h)

    buffer = Queue(maxsize=8192)
    writer = threading.Thread(target=write, args=(buffer,), daemon=True)
    writer.start()
    try:
        frame = cv2.cvtColor(np.array(images[0]), cv2.COLOR_RGB2BGR)
        for _i in range(pad): # fill starting frames
            buffer.put(frame)
        I1 = to_tensor(frame)
        with torch.no_grad(), tqdm(total=len(images), desc='Interpolate', unit='frame') as pbar:
            for image in images:
                frame = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
                I0 = I1
                I1 = to_tensor(frame)
                I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False).to(torch.float32)
                I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False).to(torch.float32)
                ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
                if ssim > 0.99: # skip duplicate frames
                    pbar.update(1) # still count the frame so the bar can complete
                    continue
                if ssim < change:
                    # fill frames if change rate is above threshold
                    output = [I0] * pad + [I1] * pad
                else:
                    output = execute(I0, I1, count - 1)
                for mid in output:
                    mid = (mid[0] * 255.).byte().cpu().numpy().transpose(1, 2, 0)
                    buffer.put(mid[:h, :w]) # crop the padding back off
                buffer.put(frame)
                pbar.update(1)
        for _i in range(pad): # fill ending frames
            buffer.put(frame)
    finally:
        # Always stop the writer and wait for it, so the thread is not leaked
        # and the result list is complete before it is logged and returned.
        buffer.put(None)
        writer.join()
    t1 = time.time()
    shared.log.info(f'RIFE interpolate: input={len(images)} frames={len(interpolated)} resolution={w}x{h} interpolate={count} scale={scale} pad={pad} change={change} time={round(t1 - t0, 2)}')
    return interpolated