# coding=utf-8
# Copyright 2022 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python2, python3
"""Minimal Reference implementation for the Frechet Video Distance (FVD).

FVD is a metric for the quality of video generation models. It is inspired by
the FID (Frechet Inception Distance) used for images, but uses a different
embedding to be better suitable for videos.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


import six
import tensorflow.compat.v1 as tf
import tensorflow_gan as tfgan
import tensorflow_hub as hub


def preprocess(videos, target_resolution):
  """Runs some preprocessing on the videos for I3D model.

  Args:
    videos: <T>[batch_size, num_frames, height, width, depth] The videos to be
      preprocessed. The dtype can be anything that tf.image.resize_bilinear
      accepts; values are expected to be in the range [0, 255].
    target_resolution: (height, width): target video resolution.

  Returns:
    videos: <float32>[batch_size, num_frames, height, width, depth]
  """
  videos_shape = list(videos.shape)
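  # Merge the batch and time dimensions so every frame is resized in one call.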
  all_frames = tf.reshape(videos, [-1] + videos_shape[-3:])
  resized_videos = tf.image.resize_bilinear(all_frames, size=target_resolution)
  target_shape = [videos_shape[0], -1] + list(target_resolution) + [3]
  output_videos = tf.reshape(resized_videos, target_shape)
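  # Rescale pixel values from [0, 255] to [-1, 1], the range I3D expects.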
  scaled_videos = 2. * tf.cast(output_videos, tf.float32) / 255. - 1
  return scaled_videos


def _is_in_graph(tensor_name):
  """Checks whether a given tensor does exists in the graph."""
  try:
    tf.get_default_graph().get_tensor_by_name(tensor_name)
  except KeyError:
    return False
  return True


def create_id3_embedding(videos, warmup=False, batch_size=16):
  """Embeds the given videos using the Inflated 3D Convolution network.

  Downloads the graph of the I3D from tf.hub and adds it to the graph on the
  first call.

  Args:
    videos: <float32>[batch_size, num_frames, height=224, width=224, depth=3].
      Expected range is [-1, 1].
    warmup: if True, check that the static batch size of `videos` matches
      `batch_size` before the module is imported into the graph.
    batch_size: expected number of videos per batch.

  Returns:
    embedding: <float32>[batch_size, embedding_size]. embedding_size depends
               on the model used.

  Raises:
    AssertionError: if `warmup` is set and the static batch size of `videos`
      does not match `batch_size`.
  """

  module_spec = "https://tfhub.dev/deepmind/i3d-kinetics-400/1"

  # Make sure that we import the graph separately for
  # each different input video tensor.
  module_name = "fvd_kinetics-400_id3_module_" + six.ensure_str(
      videos.name).replace(":", "_")

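  # Sanity-check that the videos were preprocessed into [-1, 1] (with a small
  # tolerance) and arrive with the expected batch size.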
  assert_ops = [
      tf.Assert(
          tf.reduce_max(videos) <= 1.001,
          ["max value in frame is > 1", videos]),
      tf.Assert(
          tf.reduce_min(videos) >= -1.001,
          ["min value in frame is < -1", videos]),
      tf.assert_equal(
          tf.shape(videos)[0],
          batch_size, ["invalid frame batch size: ",
                       tf.shape(videos)],
          summarize=6),
  ]
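  # Attach the assertions to the video tensor so they run whenever it is used.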
  with tf.control_dependencies(assert_ops):
    videos = tf.identity(videos)

  module_scope = "%s_apply_default/" % module_name

  # To check whether the module has already been loaded into the graph, we look
  # for a given tensor name. If this tensor name exists, we assume the function
  # has been called before and the graph was imported. Otherwise we import it.
  # Note: in theory, the tensor could exist but have the wrong shape.
  # This will happen if create_id3_embedding is called with a frames_placeholder
  # of the wrong size/batch size, because even though that triggers a tf.Assert
  # at graph-execution time, it still inserts the tensor (with the wrong shape)
  # into the graph. This is why we need the following assert.
  if warmup:
    video_batch_size = int(videos.shape[0])
    assert video_batch_size in [batch_size, -1, None], (
        "Invalid batch size %d" % video_batch_size)
  tensor_name = module_scope + "RGB/inception_i3d/Mean:0"
  if not _is_in_graph(tensor_name):
    i3d_model = hub.Module(module_spec, name=module_name)
    i3d_model(videos)

  # Fetch the kinetics-i3d-400 logits layer.
  tensor = tf.get_default_graph().get_tensor_by_name(tensor_name)
  return tensor


def calculate_fvd(real_activations,
                  generated_activations):
  """Returns a list of ops that compute metrics as funcs of activations.

  Args:
    real_activations: <float32>[num_samples, embedding_size]
    generated_activations: <float32>[num_samples, embedding_size]

  Returns:
    A scalar that contains the requested FVD.
  """
  return tfgan.eval.frechet_classifier_distance_from_activations(
      real_activations, generated_activations)
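

# A minimal usage sketch (an illustrative addition, not part of the original
# API): computes the FVD between two synthetic video batches. The shapes below
# (16 videos of 15 64x64 RGB frames, values in [0, 255]) are assumptions chosen
# to match create_id3_embedding's default batch_size of 16; running it requires
# TF1-style graph execution, which tensorflow.compat.v1 provides.
if __name__ == "__main__":
  with tf.Graph().as_default():
    first_set_of_videos = tf.zeros([16, 15, 64, 64, 3])
    second_set_of_videos = tf.ones([16, 15, 64, 64, 3]) * 255

    result = calculate_fvd(
        create_id3_embedding(preprocess(first_set_of_videos, (224, 224))),
        create_id3_embedding(preprocess(second_set_of_videos, (224, 224))))

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(tf.tables_initializer())
      print("FVD is: %.2f." % sess.run(result))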