File size: 5,683 Bytes
76ccfaa
 
b6668e8
 
76ccfaa
b6668e8
 
 
 
 
 
 
 
 
 
76ccfaa
 
b6668e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76ccfaa
 
 
b6668e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76ccfaa
 
 
 
 
 
 
 
 
 
 
 
b6668e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
from typing import Generator, List, Iterable

import numpy as np
import tensorflow as tf
from huggingface_hub import snapshot_download

"""A wrapper class for running a frame interpolation based on the FILM model on TFHub

Usage:
  interpolator = Interpolator()
  result_batch = interpolator(image_batch_0, image_batch_1, batch_dt)
  Where image_batch_0 and image_batch_1 are numpy tensors with TF standard
  (B,H,W,C) layout, batch_dt is the sub-frame time in range [0..1], (B,) layout.
"""

# Repository id handed to huggingface_hub.snapshot_download() when an
# Interpolator is constructed; the downloaded snapshot is loaded as the model.
FILM_REPO_ID = "leonelhs/film"


def _pad_to_align(x, align):
    """Pads an image batch so that its height and width are multiples of align.

    The padding is split between both sides of each axis: the original content
    is placed at an offset of half the added rows/columns (rounded down).

    Args:
      x: Rank-4 image batch; height and width are read from axes -3 and -2.
      align: Positive number the padded height and width must divide by.

    Returns:
      1) The batch padded with tf.image.pad_to_bounding_box so that
        height % align == 0 and width % align == 0.
      2) A dict of keyword arguments for tf.image.crop_to_bounding_box that
        crops the padded batch back to the original height and width.
    """
    # Input checking.
    assert np.ndim(x) == 4
    assert align > 0, 'align must be a positive number.'

    def _shortfall(size):
        # Distance from size up to the next multiple of align (0 if aligned).
        remainder = size % align
        return (align - remainder) if remainder != 0 else 0

    orig_height, orig_width = x.shape[-3:-1]
    extra_height = _shortfall(orig_height)
    extra_width = _shortfall(orig_width)

    # The same offsets are used to place the content when padding and to find
    # it again when cropping.
    top = extra_height // 2
    left = extra_width // 2

    padded_x = tf.image.pad_to_bounding_box(
        x,
        offset_height=top,
        offset_width=left,
        target_height=orig_height + extra_height,
        target_width=orig_width + extra_width)
    bbox_to_crop = dict(
        offset_height=top,
        offset_width=left,
        target_height=orig_height,
        target_width=orig_width)
    return padded_x, bbox_to_crop


class Interpolator:
    """Generates interpolated frames between two input frames.

    Wraps a saved model that is downloaded from the Hugging Face Hub repository
    named by FILM_REPO_ID and invoked with a dict of 'x0', 'x1' and 'time'.
    """

    def __init__(self, times_to_interpolate: int = 6, align: int = 64) -> None:
        """Downloads the model snapshot and loads it as a saved model.

        Args:
          times_to_interpolate: Stored on the instance as-is; read by
            interpolate_recursively as the midpoint recursion depth.
          align: If not None, inputs are padded so that height and width divide
            by this value before inference, and the result is cropped back.
        """
        self.times_to_interpolate = times_to_interpolate
        model_path = snapshot_download(FILM_REPO_ID)
        self._model = tf.saved_model.load(model_path)
        # Alternative source (disabled): hub.load("https://tfhub.dev/google/film/1")
        self._align = align

    def __call__(self, x0: np.ndarray, x1: np.ndarray,
                 dt: np.ndarray) -> np.ndarray:
        """Generates an interpolated frame between given two batches of frames.

        All inputs should be np.float32 datatype.

        Args:
          x0: First image batch. Dimensions: (batch_size, height, width, channels)
          x1: Second image batch. Dimensions: (batch_size, height, width, channels)
          dt: Sub-frame time. Range [0,1]. Dimensions: (batch_size,)

        Returns:
          The result with dimensions (batch_size, height, width, channels).
        """
        bbox_to_crop = None
        if self._align is not None:
            # Both inputs share one shape, so the crop box from x0 serves both.
            x0, bbox_to_crop = _pad_to_align(x0, self._align)
            x1, _ = _pad_to_align(x1, self._align)

        # The model takes time with a trailing axis: (batch_size, 1).
        inputs = {'x0': x0, 'x1': x1, 'time': dt[..., np.newaxis]}
        result = self._model(inputs, training=False)
        image = result['image']

        if bbox_to_crop is not None:
            # Undo the alignment padding so the output matches the input size.
            image = tf.image.crop_to_bounding_box(image, **bbox_to_crop)
        return image.numpy()

    def preview_frames(self, frames: List[np.ndarray]) -> List[np.ndarray]:
        """Returns the first two frames with their interpolated midpoint between.

        Inference is routed through __call__, so the alignment padding/cropping
        and the training=False flag used there also apply to the preview.

        Args:
          frames: Sequence holding at least two frames without a batch
            dimension. Only frames[0] and frames[1] are used.

        Returns:
          [frames[0], midpoint, frames[1]], where midpoint is the model output
          for time 0.5 with the batch dimension removed.

        Raises:
          IndexError: If fewer than two frames are given.
        """
        first, second = frames[0], frames[1]
        time = np.array([0.5], dtype=np.float32)

        # Add the batch dimension for the model call and strip it afterwards.
        mid = self(np.expand_dims(first, axis=0),
                   np.expand_dims(second, axis=0),
                   time)[0]
        return [first, mid, second]


def _recursive_generator(
        frame1: np.ndarray, frame2: np.ndarray, num_recursions: int,
        interpolator: "Interpolator") -> Generator[np.ndarray, None, None]:
    """Splits halfway to repeatedly generate more frames.

    Each level interpolates the midpoint of its pair and recurses into both
    halves, so num_recursions levels yield 2**num_recursions frames.

    Args:
      frame1: Input image 1, without a batch dimension.
      frame2: Input image 2, without a batch dimension.
      num_recursions: How many times to interpolate the consecutive image pairs.
        Values <= 0 yield frame1 only.
      interpolator: The frame interpolator instance; called as
        interpolator(x0, x1, dt) with batched inputs.

    Yields:
      The interpolated frames, including the first frame (frame1), but excluding
      the final frame2.
    """
    # "<= 0" rather than "== 0": a negative count would otherwise never reach
    # the base case and would run an inference at every level until Python's
    # recursion limit is hit.
    if num_recursions <= 0:
        yield frame1
        return

    # Adds the batch dimension to all inputs before calling the interpolator,
    # and removes it afterwards.
    time = np.full(shape=(1,), fill_value=0.5, dtype=np.float32)
    mid_frame = interpolator(
        np.expand_dims(frame1, axis=0), np.expand_dims(frame2, axis=0), time)[0]
    yield from _recursive_generator(frame1, mid_frame, num_recursions - 1, interpolator)
    yield from _recursive_generator(mid_frame, frame2, num_recursions - 1, interpolator)


def interpolate_recursively(
        frames: List[np.ndarray], interpolator: "Interpolator") -> Iterable[np.ndarray]:
    """Generates interpolated frames by repeatedly interpolating the midpoint.

    The recursion depth is read from interpolator.times_to_interpolate.

    Args:
      frames: List of input frames. Expected shape (H, W, 3). The colors should be
        in the range[0, 1] and in gamma space. An empty list yields nothing.
      interpolator: The frame interpolation model to use.

    Yields:
      The interpolated frames (including the inputs).
    """
    # Nothing to interpolate; also avoids an IndexError on frames[-1] below.
    if len(frames) == 0:
        return

    times_to_interpolate = interpolator.times_to_interpolate

    # Each consecutive pair contributes its first frame plus the in-between
    # frames; the pair's second frame is emitted by the next pair.
    for previous, current in zip(frames, frames[1:]):
        yield from _recursive_generator(previous, current, times_to_interpolate, interpolator)
    # Separately yield the final frame.
    yield frames[-1]