# vishred18's picture
# Upload 364 files
# d5ee97c
# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Strategy util functions"""
import tensorflow as tf
def return_strategy():
    """Pick a ``tf.distribute`` strategy from the number of visible GPUs.

    Returns:
        ``tf.distribute.OneDeviceStrategy`` pinned to ``/cpu:0`` when no GPU
        is listed, ``tf.distribute.OneDeviceStrategy`` pinned to ``/gpu:0``
        when exactly one GPU is listed, and ``tf.distribute.MirroredStrategy``
        when more than one GPU is listed.
    """
    gpu_count = len(tf.config.list_physical_devices("GPU"))

    # Several GPUs: mirror across the devices.
    if gpu_count > 1:
        return tf.distribute.MirroredStrategy()

    # Zero or one GPU: a single-device strategy on the matching device.
    target_device = "/gpu:0" if gpu_count == 1 else "/cpu:0"
    return tf.distribute.OneDeviceStrategy(device=target_device)
def calculate_3d_loss(y_gt, y_pred, loss_fn):
    """Calculate 3d loss, normally it's mel-spectrogram loss.

    The two inputs are first aligned along axis 1 by slicing the longer one
    down to the length of the shorter one, then ``loss_fn(y_gt, y_pred)`` is
    evaluated and averaged over every axis except the first.

    Args:
        y_gt: Ground-truth tensor. The slicing path treats it as rank 3
            (presumably ``[B, T, D]`` -- confirm against callers).
        y_pred: Predicted tensor, same layout as ``y_gt``.
        loss_fn: Callable invoked as ``loss_fn(y_gt, y_pred)``. It may return
            a single tensor, or a tuple/list of tensors.

    Returns:
        A tensor averaged over all axes but the first (shape ``[B]``) when
        ``loss_fn`` returns a single tensor; a ``list`` of such tensors, in
        the same order, when ``loss_fn`` returns a tuple or list.
    """
    y_gt_T = tf.shape(y_gt)[1]
    y_pred_T = tf.shape(y_pred)[1]

    # There is a length mismatch when training on multiple GPUs.
    # We need to slice the longer tensor to make sure the loss
    # is calculated correctly.
    # NOTE: both lengths are scalar tensors, so inside a graph this branch
    # relies on AutoGraph converting the Python ``if``.
    if y_gt_T > y_pred_T:
        y_gt = tf.slice(y_gt, [0, 0, 0], [-1, y_pred_T, -1])
    elif y_pred_T > y_gt_T:
        y_pred = tf.slice(y_pred, [0, 0, 0], [-1, y_gt_T, -1])

    loss = loss_fn(y_gt, y_pred)

    if isinstance(loss, (tuple, list)):
        # Multi-output loss: reduce each component independently.
        # The axis list is kept explicit (possibly empty) on purpose; do not
        # replace it with ``None``, which would also reduce the batch axis.
        loss = [
            tf.reduce_mean(part, list(range(1, len(part.shape))))
            for part in loss
        ]  # each element: shape = [B]
    else:
        loss = tf.reduce_mean(loss, list(range(1, len(loss.shape))))  # shape = [B]

    return loss
def calculate_2d_loss(y_gt, y_pred, loss_fn):
    """Calculate 2d loss, normally it's durations/f0s/energies loss.

    The two inputs are first aligned along axis 1 by slicing the longer one
    down to the length of the shorter one, then ``loss_fn(y_gt, y_pred)`` is
    evaluated and averaged over every axis except the first.

    Args:
        y_gt: Ground-truth tensor. The slicing path treats it as rank 2
            (presumably ``[B, T]`` -- confirm against callers).
        y_pred: Predicted tensor, same layout as ``y_gt``.
        loss_fn: Callable invoked as ``loss_fn(y_gt, y_pred)``. It may return
            a single tensor, or a tuple/list of tensors.

    Returns:
        A tensor averaged over all axes but the first (shape ``[B]``) when
        ``loss_fn`` returns a single tensor; a ``list`` of such tensors, in
        the same order, when ``loss_fn`` returns a tuple or list.
    """
    y_gt_T = tf.shape(y_gt)[1]
    y_pred_T = tf.shape(y_pred)[1]

    # There is a length mismatch when training on multiple GPUs.
    # We need to slice the longer tensor to make sure the loss
    # is calculated correctly.
    # NOTE: both lengths are scalar tensors, so inside a graph this branch
    # relies on AutoGraph converting the Python ``if``.
    if y_gt_T > y_pred_T:
        y_gt = tf.slice(y_gt, [0, 0], [-1, y_pred_T])
    elif y_pred_T > y_gt_T:
        y_pred = tf.slice(y_pred, [0, 0], [-1, y_gt_T])

    loss = loss_fn(y_gt, y_pred)

    if isinstance(loss, (tuple, list)):
        # Multi-output loss: reduce each component independently.
        # The axis list is kept explicit (possibly empty) on purpose; do not
        # replace it with ``None``, which would also reduce the batch axis.
        loss = [
            tf.reduce_mean(part, list(range(1, len(part.shape))))
            for part in loss
        ]  # each element: shape = [B]
    else:
        loss = tf.reduce_mean(loss, list(range(1, len(loss.shape))))  # shape = [B]

    return loss