|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Strategy util functions""" |
|
import tensorflow as tf |
|
|
|
|
|
def return_strategy():
    """Choose a ``tf.distribute`` strategy from the GPUs TensorFlow can see.

    Returns:
        * ``OneDeviceStrategy("/cpu:0")`` when no physical GPU is listed,
        * ``OneDeviceStrategy("/gpu:0")`` when exactly one GPU is listed,
        * a default-constructed ``MirroredStrategy`` when more than one
          GPU is listed.
    """
    gpu_count = len(tf.config.list_physical_devices("GPU"))

    if gpu_count > 1:
        return tf.distribute.MirroredStrategy()

    # Zero or one GPU: pin everything to a single device.
    target_device = "/gpu:0" if gpu_count == 1 else "/cpu:0"
    return tf.distribute.OneDeviceStrategy(device=target_device)
|
|
|
|
|
def calculate_3d_loss(y_gt, y_pred, loss_fn):
    """Calculate 3d loss, normally it's mel-spectrogram loss.

    ``y_gt`` and ``y_pred`` may differ in length along axis 1. Both are
    truncated to the shorter length before ``loss_fn`` is applied. Axis 1 is
    presumably the time/frame axis; confirm against callers.

    Args:
        y_gt: Ground-truth tensor of rank >= 2.
        y_pred: Predicted tensor of rank >= 2.
        loss_fn: Callable ``loss_fn(y_gt, y_pred)`` returning either a single
            tensor or a tuple/list of tensors.

    Returns:
        If ``loss_fn`` returns a single tensor, that tensor averaged over every
        axis except axis 0 (presumably the batch axis). If it returns a
        tuple/list, a list with each element reduced the same way.
    """
    # Branch-free truncation: slicing to one's own length is a no-op, so
    # clipping both tensors to the minimum matches the original if/elif.
    # Unlike a Python `if` on tensors, this also works in graph mode
    # without AutoGraph.
    min_len = tf.minimum(tf.shape(y_gt)[1], tf.shape(y_pred)[1])
    y_gt = y_gt[:, :min_len]
    y_pred = y_pred[:, :min_len]

    loss = loss_fn(y_gt, y_pred)

    if isinstance(loss, (tuple, list)):
        # Multi-term loss: reduce each term independently.
        return [
            tf.reduce_mean(term, list(range(1, len(term.shape))))
            for term in loss
        ]

    # Average over all non-leading axes; an empty axis list (rank <= 1)
    # leaves the tensor unchanged.
    return tf.reduce_mean(loss, list(range(1, len(loss.shape))))
|
|
|
|
|
def calculate_2d_loss(y_gt, y_pred, loss_fn):
    """Calculate 2d loss, normally it's durations/F0s/energies loss.

    ``y_gt`` and ``y_pred`` may differ in length along axis 1. Both are
    truncated to the shorter length before ``loss_fn`` is applied. Axis 1 is
    presumably the time/token axis; confirm against callers.

    Args:
        y_gt: Ground-truth tensor of rank >= 2.
        y_pred: Predicted tensor of rank >= 2.
        loss_fn: Callable ``loss_fn(y_gt, y_pred)`` returning either a single
            tensor or a tuple/list of tensors.

    Returns:
        If ``loss_fn`` returns a single tensor, that tensor averaged over every
        axis except axis 0 (presumably the batch axis). If it returns a
        tuple/list, a list with each element reduced the same way.
    """
    # Branch-free truncation: slicing to one's own length is a no-op, so
    # clipping both tensors to the minimum matches the original if/elif.
    # Unlike a Python `if` on tensors, this also works in graph mode
    # without AutoGraph.
    min_len = tf.minimum(tf.shape(y_gt)[1], tf.shape(y_pred)[1])
    y_gt = y_gt[:, :min_len]
    y_pred = y_pred[:, :min_len]

    loss = loss_fn(y_gt, y_pred)

    if isinstance(loss, (tuple, list)):
        # Multi-term loss: reduce each term independently.
        return [
            tf.reduce_mean(term, list(range(1, len(term.shape))))
            for term in loss
        ]

    # Average over all non-leading axes; an empty axis list (rank <= 1)
    # leaves the tensor unchanged.
    return tf.reduce_mean(loss, list(range(1, len(loss.shape))))
|
|