File size: 4,468 Bytes
c19ca42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
#!/bin/env python

import _thread
import os
import time
from queue import Queue
import cv2
import numpy as np
import torch
from PIL import Image
from torch.nn import functional as F
from tqdm.rich import tqdm
from modules.rife.ssim import ssim_matlab
from modules.rife.model_rife import RifeModel
from modules import devices, shared


# Download URL of the RIFE flownet-v46 weights fetched by load()
model_url = 'https://github.com/vladmandic/rife/raw/main/model/flownet-v46.pkl'
# Module-wide model instance; None until load() assigns it (interpolate() calls load() lazily)
model: RifeModel = None


def load(model_path: str = 'rife/flownet-v46.pkl'):
    """Download (if needed) and load the RIFE flownet-v46 model into the module-level `model`.

    Does nothing when a model is already loaded.

    Args:
        model_path: currently ignored; the weights are always fetched from `model_url`
            into `<shared.models_path>/RIFE/flownet-v46.pkl`. Kept for backward compatibility.
    """
    global model # pylint: disable=global-statement
    if model is not None:
        return
    from modules import modelloader
    model_dir = os.path.join(shared.models_path, 'RIFE')
    model_path = modelloader.load_file_from_url(url=model_url, model_dir=model_dir, file_name='flownet-v46.pkl')
    shared.log.debug(f'RIFE load model: file="{model_path}"')
    rife = RifeModel()
    rife.load_model(model_path, -1)
    rife.eval()
    rife.device()
    # Publish only a fully initialized model: if any step above raises, `model` stays None
    # and the next call retries instead of silently using a half-loaded instance.
    model = rife


def interpolate(images: list, count: int = 2, scale: float = 1.0, pad: int = 1, change: float = 0.3):
    """Insert RIFE-generated intermediate frames between consecutive images.

    Args:
        images: sequence of PIL images; the size of the first image is used for cropping and padding.
        count: `count - 1` frames are synthesized between each pair of consecutive (non-duplicate) images.
        scale: passed to the model's inference call; also sets the padding multiple
            (`max(128, int(128 / scale))`).
        pad: how many times the first and last frames are repeated at the start/end of the output,
            and how many copies of each side are emitted instead of interpolating on a scene change.
            The first frame is always emitted at least once, even when `pad` is 0.
        change: SSIM threshold; consecutive frames with SSIM below it are treated as a scene change.

    Returns:
        List of PIL images in playback order, or `[]` when fewer than two images are given.

    Raises:
        Any exception raised while converting frames back to PIL images in the writer thread.
    """
    if images is None or len(images) < 2:
        return []
    import threading
    if model is None:
        load()
    interpolated = []
    errors = []
    h = images[0].height
    w = images[0].width
    t0 = time.time()

    def write(buffer):
        # Consumer thread: turn BGR uint8 arrays back into RGB PIL images until the None sentinel.
        while True:
            item = buffer.get()
            if item is None:
                return
            if errors:
                continue # keep draining so the producer never blocks on a full queue
            try:
                interpolated.append(Image.fromarray(item[:, :, ::-1]))
            except Exception as e: # pylint: disable=broad-exception-caught
                errors.append(e) # re-raised in the caller after join()

    def execute(I0, I1, n):
        # Return n intermediate frames between I0 and I1 (empty list when n <= 0).
        if n <= 0:
            return []
        if model.version >= 3.9:
            return [model.inference(I0, I1, (i + 1) * 1. / (n + 1), scale) for i in range(n)]
        # older models only produce a midpoint, so subdivide recursively
        middle = model.inference(I0, I1, scale)
        if n == 1:
            return [middle]
        first_half = execute(I0, middle, n=n // 2)
        second_half = execute(middle, I1, n=n // 2)
        if n % 2:
            return [*first_half, middle, *second_half]
        return [*first_half, *second_half]

    tmp = max(128, int(128 / scale))
    ph = ((h - 1) // tmp + 1) * tmp
    pw = ((w - 1) // tmp + 1) * tmp
    padding = (0, pw - w, 0, ph - h) # pad right/bottom up to a multiple of `tmp`

    def to_tensor(frame):
        # HWC uint8 array -> 1xCxHxW float tensor in [0, 1], padded and cast to the device dtype
        tensor = torch.from_numpy(np.transpose(frame, (2, 0, 1))).to(devices.device, non_blocking=True).unsqueeze(0).float() / 255.
        return F.pad(tensor, padding).to(devices.dtype) # pylint: disable=not-callable

    buffer = Queue(maxsize=8192)
    writer = threading.Thread(target=write, args=(buffer,), daemon=True)
    writer.start()
    try:
        frame = cv2.cvtColor(np.array(images[0]), cv2.COLOR_RGB2BGR)
        for _i in range(max(pad, 1)): # fill starting frames; always emit the first frame at least once
            buffer.put(frame)
        I1 = to_tensor(frame)
        with torch.no_grad(), tqdm(total=len(images), desc='Interpolate', unit='frame') as pbar:
            pbar.update(1) # first frame already emitted above
            for image in images[1:]:
                frame = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
                I0 = I1
                I1 = to_tensor(frame)
                I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False).to(torch.float32)
                I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False).to(torch.float32)
                ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
                pbar.update(1)
                if ssim > 0.99: # skip duplicate frames
                    continue
                if ssim < change: # scene change: repeat both sides instead of interpolating across the cut
                    output = [I0] * pad + [I1] * pad
                else:
                    output = execute(I0, I1, count - 1)
                for mid in output:
                    mid = (mid[0] * 255.).byte().cpu().numpy().transpose(1, 2, 0)
                    buffer.put(mid[:h, :w]) # crop away the padding
                buffer.put(frame)
        for _i in range(pad): # fill ending frames
            buffer.put(frame)
    finally:
        buffer.put(None) # sentinel: lets the writer thread finish draining and exit
        writer.join()
    if errors:
        raise errors[0]
    t1 = time.time()
    shared.log.info(f'RIFE interpolate: input={len(images)} frames={len(interpolated)} resolution={w}x{h} interpolate={count} scale={scale} pad={pad} change={change} time={round(t1 - t0, 2)}')
    return interpolated