ai / training /config /film_net-VGG.gin
CHEN11102's picture
Upload 47 files
2061d64 verified
raw
history blame
2.82 kB
# Copyright 2022 Google LLC
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Model selection and architecture hyperparameters.
# `model.name` picks the model by name; the remaining bindings set parameters
# on the `film_net` configurable.
# NOTE(review): the exact meaning of each level/filter count is defined by the
# `film_net` implementation, which is not visible in this file — confirm there
# before changing any of these values.
model.name = 'film_net'
film_net.pyramid_levels = 7
film_net.fusion_pyramid_levels = 5
film_net.specialized_levels = 3
film_net.sub_levels = 4
# `flow_convs` and `flow_filters` are both 4-element lists; presumably they are
# indexed in parallel (one entry per flow-estimator level) — TODO confirm.
film_net.flow_convs = [3, 3, 3, 3]
film_net.flow_filters = [32, 64, 128, 256]
film_net.filters = 64
# Learning-rate schedule and total training length.
# Initial rate is 1e-4; decay parameters are 750k steps, rate 0.464158, with
# staircase enabled. Total run length is 3M steps, i.e. four 750k-step decay
# intervals.
# Arithmetic note: 0.464158 ~= 10^(-1/3), so three decay intervals (2.25M steps)
# multiply the rate by roughly 0.1.
# NOTE(review): presumably these feed an exponential-decay schedule inside
# `training` — the consumer is not visible here; confirm.
training.learning_rate = 0.0001
training.learning_rate_decay_steps = 750000
training.learning_rate_decay_rate = 0.464158
training.learning_rate_staircase = True
training.num_steps = 3000000
# Training dataset: input file, batch size and crop size.
# in the sweep
# NOTE(review): the original comment "in the sweep" is ambiguous — presumably
# it means the values below are overridden by a hyperparameter sweep; confirm
# with the launch configuration.
# NOTE(review): the `@200` suffix looks like a sharded-file spec (200 shards) —
# TODO confirm how `training_dataset` expands it.
training_dataset.file = 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/vimeo_interp_train.tfrecord@200'
training_dataset.batch_size = 8
training_dataset.crop_size = 256
# Evaluation datasets.
# As configured, `files` and `names` are both empty lists, so no evaluation
# dataset is specified by this file. The commented-out bindings below list five
# dataset files with five matching names (same order) and can be restored in
# place of the empty lists.
eval_datasets.batch_size = 1
# NOTE(review): -1 presumably means "no limit on examples" — confirm in the
# `eval_datasets` implementation.
eval_datasets.max_examples = -1
# eval_datasets.files = ['gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/vimeo_interp_test.tfrecord@3',
# 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/middlebury_other.tfrecord@3',
# 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/UCF101_interp_test.tfrecord@2',
# 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/xiph_2K.tfrecord@2',
# 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/xiph_4K.tfrecord@2']
# eval_datasets.names = ['vimeo90K', 'middlebury', 'ucf101', 'xiph2K', 'xiph4K']
eval_datasets.files = []
eval_datasets.names = []
# Training augmentation (in addition to random crop).
# Four augmentations are selected by name.
# NOTE(review): how each name maps to an operation is defined by the
# `data_augmentation` configurable, not visible here — confirm the names are
# still valid there before editing this list.
data_augmentation.names = ['random_image_rot90', 'random_flip', 'random_rotate', 'random_reverse']
# Loss functions.
# `loss_names`, `loss_weight_schedules` and `loss_weight_parameters` are
# parallel lists: entry i of each describes the same loss ('l1' first, 'vgg'
# second). Keep all three the same length and in the same order.
training_losses.loss_names = ['l1', 'vgg']
# `@name` without trailing parentheses is a Gin reference to the configurable
# itself (not a call result); the parameter dicts below supply its arguments.
training_losses.loss_weight_schedules = [
@tf.keras.optimizers.schedules.PiecewiseConstantDecay,
@tf.keras.optimizers.schedules.PiecewiseConstantDecay]
# Decrease the weight of VGG loss at 1.5M steps.
# 'l1': both values are 1.0, so the weight is constant at 1.0.
# 'vgg': weight is 1.0 up to step 1,500,000 and 0.25 afterwards.
training_losses.loss_weight_parameters = [
{'boundaries':[0], 'values':[1.0, 1.0]},
{'boundaries':[1500000], 'values':[1.0, 0.25]}]
# Test-time losses/metrics: three names, each with weight 1.0 (parallel lists).
test_losses.loss_names = ['l1', 'psnr', 'ssim']
test_losses.loss_weights = [1.0, 1.0, 1.0]
# Location of the VGG model file bound to `vgg.vgg_model_file`. The file name
# indicates ImageNet VGG-19 weights in .mat format.
# NOTE(review): presumably consumed by the 'vgg' training loss — confirm; the
# gs:// bucket must be readable from the training job.
vgg.vgg_model_file = 'gs://xcloud-shared/fitsumreda/frame_interpolation/pretrained_models/vgg/imagenet-vgg-verydeep-19.mat'