File size: 5,683 Bytes
76ccfaa b6668e8 76ccfaa b6668e8 76ccfaa b6668e8 76ccfaa b6668e8 76ccfaa b6668e8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
from typing import Generator, List, Iterable
import numpy as np
import tensorflow as tf
from huggingface_hub import snapshot_download
"""A wrapper class for running frame interpolation based on the FILM model on TFHub.
Usage:
interpolator = Interpolator()
result_batch = interpolator(image_batch_0, image_batch_1, batch_dt)
Where image_batch_0 and image_batch_1 are numpy tensors with TF standard
(B,H,W,C) layout, batch_dt is the sub-frame time in range [0..1], (B,) layout.
"""
# Hugging Face Hub repository id; Interpolator passes it to snapshot_download
# and loads the downloaded directory with tf.saved_model.load.
FILM_REPO_ID = "leonelhs/film"
def _pad_to_align(x, align):
"""Pads image batch x so width and height divide by align.
Args:
x: Image batch to align.
align: Number to align to.
Returns:
1) An image padded so width % align == 0 and height % align == 0.
2) A bounding box that can be fed readily to tf.image.crop_to_bounding_box
to undo the padding.
"""
# Input checking.
assert np.ndim(x) == 4
assert align > 0, 'align must be a positive number.'
height, width = x.shape[-3:-1]
height_to_pad = (align - height % align) if height % align != 0 else 0
width_to_pad = (align - width % align) if width % align != 0 else 0
bbox_to_pad = {
'offset_height': height_to_pad // 2,
'offset_width': width_to_pad // 2,
'target_height': height + height_to_pad,
'target_width': width + width_to_pad
}
padded_x = tf.image.pad_to_bounding_box(x, **bbox_to_pad)
bbox_to_crop = {
'offset_height': height_to_pad // 2,
'offset_width': width_to_pad // 2,
'target_height': height,
'target_width': width
}
return padded_x, bbox_to_crop
class Interpolator:
    """Generates interpolated frames between two input frames with the FILM model.

    The saved model is downloaded from the Hugging Face Hub repository named by
    FILM_REPO_ID and loaded with tf.saved_model.load.
    """

    def __init__(self, times_to_interpolate: int = 6, align: int = 64) -> None:
        """Downloads and loads the saved model.

        Args:
          times_to_interpolate: Stored on the instance; interpolate_recursively
            reads it as the number of recursive midpoint passes per frame pair.
          align: If not None, inputs are padded so that height and width divide
            by this value before inference and the result is cropped back.
            None disables padding.
        """
        self.times_to_interpolate = times_to_interpolate
        model_path = snapshot_download(FILM_REPO_ID)
        self._model = tf.saved_model.load(model_path)
        # NOTE(review): an earlier revision loaded
        # https://tfhub.dev/google/film/1 via hub.load; presumably the Hub
        # repository mirrors that model -- confirm.
        self._align = align

    def __call__(self, x0: np.ndarray, x1: np.ndarray,
                 dt: np.ndarray) -> np.ndarray:
        """Generates an interpolated frame between given two batches of frames.

        All inputs should be np.float32 datatype.

        Args:
          x0: First image batch. Dimensions: (batch_size, height, width, channels)
          x1: Second image batch. Dimensions: (batch_size, height, width, channels)
          dt: Sub-frame time. Range [0,1]. Dimensions: (batch_size,)

        Returns:
          The result with dimensions (batch_size, height, width, channels).
        """
        needs_padding = self._align is not None
        if needs_padding:
            # Both inputs share a shape, so the crop box from x0 applies to the output.
            x0, bbox_to_crop = _pad_to_align(x0, self._align)
            x1, _ = _pad_to_align(x1, self._align)

        # The model takes the time as (batch_size, 1), hence the trailing axis.
        inputs = {'x0': x0, 'x1': x1, 'time': dt[..., np.newaxis]}
        result = self._model(inputs, training=False)
        image = result['image']

        if needs_padding:
            image = tf.image.crop_to_bounding_box(image, **bbox_to_crop)
        return image.numpy()

    def preview_frames(self, frames: List[np.ndarray]):
        """Returns the first two frames with their interpolated midpoint between them.

        Inference goes through __call__, so the same alignment padding/cropping
        and training=False flag apply here as in every other interpolation.

        Args:
          frames: Sequence holding at least two frames without a batch
            dimension; only frames[0] and frames[1] are used.

        Returns:
          The list [frames[0], midpoint, frames[1]].
        """
        first, second = frames[0], frames[1]
        half_time = np.array([0.5], dtype=np.float32)
        # Add the batch dimension for the model and strip it from the result.
        midpoint = self(
            np.expand_dims(first, axis=0),
            np.expand_dims(second, axis=0),
            half_time,
        )[0]
        return [first, midpoint, second]
def _recursive_generator(
frame1: np.ndarray, frame2: np.ndarray, num_recursions: int,
interpolator: Interpolator) -> Generator[np.ndarray, None, None]:
"""Splits halfway to repeatedly generate more frames.
Args:
frame1: Input image 1.
frame2: Input image 2.
num_recursions: How many times to interpolate the consecutive image pairs.
interpolator: The frame interpolator instance.
Yields:
The interpolated frames, including the first frame (frame1), but excluding
the final frame2.
"""
if num_recursions == 0:
yield frame1
else:
# Adds the batch dimension to all inputs before calling the interpolator,
# and remove it afterwards.
time = np.full(shape=(1,), fill_value=0.5, dtype=np.float32)
mid_frame = interpolator(np.expand_dims(frame1, axis=0), np.expand_dims(frame2, axis=0), time)[0]
yield from _recursive_generator(frame1, mid_frame, num_recursions - 1, interpolator)
yield from _recursive_generator(mid_frame, frame2, num_recursions - 1, interpolator)
def interpolate_recursively(
    frames: List[np.ndarray], interpolator: Interpolator) -> Iterable[np.ndarray]:
    """Generates interpolated frames by repeatedly interpolating the midpoint.

    The recursion depth is taken from interpolator.times_to_interpolate.

    Args:
      frames: List of input frames. Expected shape (H, W, 3). The colors should be
        in the range[0, 1] and in gamma space. May be empty, in which case
        nothing is yielded.
      interpolator: The frame interpolation model to use.

    Yields:
      The interpolated frames (including the inputs).
    """
    # An empty input has no final frame to emit; indexing frames[-1] would raise.
    if len(frames) == 0:
        return

    times_to_interpolate = interpolator.times_to_interpolate
    # Each consecutive pair contributes its first frame plus the in-betweens.
    for start, end in zip(frames, frames[1:]):
        yield from _recursive_generator(start, end, times_to_interpolate, interpolator)
    # Separately yield the final frame.
    yield frames[-1]
|