File size: 2,096 Bytes
2061d64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# Copyright 2022 Google LLC

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A library for instantiating the model for training frame interpolation.

All models are expected to use three inputs: input image batches 'x0' and 'x1'
and 'time', the fractional time where the output should be generated.

The models are expected to output the prediction as a dictionary that contains
at least the predicted image batch as 'image' plus optional data for debug,
analysis or custom losses.
"""

import gin.tf
from ..models.film_net import interpolator as film_net_interpolator
from ..models.film_net import options as film_net_options

import tensorflow as tf


@gin.configurable('model')
def create_model(name: str) -> tf.keras.Model:
  """Creates the frame interpolation model based on given model name."""
  if name == 'film_net':
    return _create_film_net_model()  # pylint: disable=no-value-for-parameter
  else:
    raise ValueError(f'Model {name} not implemented.')


def _create_film_net_model() -> tf.keras.Model:
  """Creates the film_net interpolator."""
  # Options are gin-configured in the Options class directly.
  options = film_net_options.Options()

  x0 = tf.keras.Input(
      shape=(None, None, 3), batch_size=None, dtype=tf.float32, name='x0')
  x1 = tf.keras.Input(
      shape=(None, None, 3), batch_size=None, dtype=tf.float32, name='x1')
  time = tf.keras.Input(
      shape=(1,), batch_size=None, dtype=tf.float32, name='time')

  return film_net_interpolator.create_model(x0, x1, time, options)