"""NSynth & WaveNet Audio Style Transfer."""
import argparse
import glob
import os

import librosa
import numpy as np
import tensorflow as tf
from magenta.models.nsynth.utils import mu_law, inv_mu_law_numpy
from magenta.models.nsynth.wavenet import masked

from audio_style_transfer import utils
def compute_wavenet_encoder_features(content, style):
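    """Compute WaveNet-encoder content and style features.

    Builds the non-causal dilated-convolution stack of the NSynth WaveNet
    encoder, restores pretrained weights from './model.ckpt-200000', and
    runs the content and style audio through it. Content is represented by
    the raw per-layer activations, style by each layer's Gram matrix.

    Args:
        content: Content audio chopped into frames, shaped
            (n_frames, n_samples).
        style: Style audio chopped the same way.

    Returns:
        (content_features, style_features) lists, one entry per layer.
    """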
ae_hop_length = 512
ae_bottleneck_width = 16
ae_num_stages = 10
ae_num_layers = 30
ae_filter_length = 3
ae_width = 128
n_frames = content.shape[0]
n_samples = content.shape[1]
content_tf = np.ascontiguousarray(content)
style_tf = np.ascontiguousarray(style)
g = tf.Graph()
content_features = []
style_features = []
layers = []
with g.as_default(), g.device('/cpu:0'), tf.Session() as sess:
x = tf.placeholder('float32', [n_frames, n_samples], name="x")
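        # Encode the input with 8-bit mu-law and scale it into [-1, 1].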
x_quantized = mu_law(x)
x_scaled = tf.cast(x_quantized, tf.float32) / 128.0
x_scaled = tf.expand_dims(x_scaled, 2)
en = masked.conv1d(
x_scaled,
causal=False,
num_filters=ae_width,
filter_length=ae_filter_length,
name='ae_startconv')
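        # Stack of non-causal dilated convolutions with residual
        # connections; the post-residual activation of every layer is
        # collected as a feature map.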
for num_layer in range(ae_num_layers):
dilation = 2**(num_layer % ae_num_stages)
d = tf.nn.relu(en)
d = masked.conv1d(
d,
causal=False,
num_filters=ae_width,
filter_length=ae_filter_length,
dilation=dilation,
name='ae_dilatedconv_%d' % (num_layer + 1))
d = tf.nn.relu(d)
en += masked.conv1d(
d,
num_filters=ae_width,
filter_length=1,
name='ae_res_%d' % (num_layer + 1))
layers.append(en)
en = masked.conv1d(
en,
num_filters=ae_bottleneck_width,
filter_length=1,
name='ae_bottleneck')
en = masked.pool1d(en, ae_hop_length, name='ae_pool', mode='avg')
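        # Restore pretrained NSynth WaveNet weights; the checkpoint is
        # expected in the working directory.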
saver = tf.train.Saver()
saver.restore(sess, './model.ckpt-200000')
content_features = sess.run(layers, feed_dict={x: content_tf})
styles = sess.run(layers, feed_dict={x: style_tf})
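        # Summarize style as per-layer Gram matrices of the activations,
        # normalized by the total number of time steps.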
for i, style_feature in enumerate(styles):
n_features = np.prod(layers[i].shape.as_list()[-1])
features = np.reshape(style_feature, (-1, n_features))
style_gram = np.matmul(features.T, features) / (n_samples *
n_frames)
style_features.append(style_gram)
return content_features, style_features
def compute_wavenet_encoder_stylization(n_samples,
n_frames,
content_features,
style_features,
alpha=1e-4,
learning_rate=1e-3,
iterations=100):
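    """Optimize a signal to match encoder content and style features.

    Rebuilds the encoder graph, restores the pretrained weights, freezes
    them into graph constants, and then runs Gatys-style optimization: a
    randomly initialized signal is updated with Adam so that its
    activations match the content features and its Gram matrices match
    the style features. `alpha` weights the content loss against the
    style loss.

    Returns:
        The optimized signal, inverse-mu-law decoded to audio.
    """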
ae_style_layers = [1, 5]
ae_num_layers = 30
ae_num_stages = 10
ae_filter_length = 3
ae_width = 128
layers = []
with tf.Graph().as_default() as g, g.device('/cpu:0'), tf.Session() as sess:
x = tf.placeholder(
name="x", shape=(n_frames, n_samples, 1), dtype=tf.float32)
en = masked.conv1d(
x,
causal=False,
num_filters=ae_width,
filter_length=ae_filter_length,
name='ae_startconv')
for num_layer in range(ae_num_layers):
dilation = 2**(num_layer % ae_num_stages)
d = tf.nn.relu(en)
d = masked.conv1d(
d,
causal=False,
num_filters=ae_width,
filter_length=ae_filter_length,
dilation=dilation,
name='ae_dilatedconv_%d' % (num_layer + 1))
d = tf.nn.relu(d)
en += masked.conv1d(
d,
num_filters=ae_width,
filter_length=1,
name='ae_res_%d' % (num_layer + 1))
layer_i = tf.identity(en, name='layer_{}'.format(num_layer))
layers.append(layer_i)
        saver = tf.train.Saver()
        # Initialize before restoring so the random initializer does not
        # overwrite the pretrained weights.
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, './model.ckpt-200000')
frozen_graph_def = tf.graph_util.convert_variables_to_constants(
sess, sess.graph_def, [en.name.replace(':0', '')] +
['layer_{}'.format(i) for i in range(ae_num_layers)])
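    # Import the frozen graph with a trainable random signal as its input;
    # optimization updates only the signal, never the network weights.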
with tf.Graph().as_default() as g, g.device('/cpu:0'), tf.Session() as sess:
x = tf.Variable(
np.random.randn(n_frames, n_samples, 1).astype(np.float32))
tf.import_graph_def(frozen_graph_def, input_map={'x:0': x})
content_loss = np.float32(0.0)
style_loss = np.float32(0.0)
for num_layer in ae_style_layers:
layer_i = g.get_tensor_by_name(name='import/layer_%d:0' %
(num_layer))
content_loss = content_loss + alpha * 2 * tf.nn.l2_loss(
layer_i - content_features[num_layer])
n_features = layer_i.shape.as_list()[-1]
features = tf.reshape(layer_i, (-1, n_features))
gram = tf.matmul(tf.transpose(features), features) / (n_frames *
n_samples)
style_loss = style_loss + 2 * tf.nn.l2_loss(gram - style_features[
num_layer])
loss = content_loss + style_loss
# Optimization
print('Started optimization.')
opt = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)
var_list = tf.trainable_variables()
print(var_list)
        sess.run(tf.global_variables_initializer())
for i in range(iterations):
            s_loss, c_loss, _ = sess.run([style_loss, content_loss, opt])
            print(i, '- Style:', s_loss, 'Content:', c_loss, end='\r')
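        # Rescale the optimized signal to mu-law range and decode it back
        # to an audio waveform.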
result = x.eval()
        result = inv_mu_law_numpy(result[..., 0] / np.abs(result).max() * 128.0)
return result
def compute_wavenet_decoder_features(content, style):
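    """Compute WaveNet-decoder content and style features.

    Decoder counterpart of compute_wavenet_encoder_features: features are
    taken from the accumulated skip-connection outputs of the WaveNet
    decoder's gated residual stack rather than from the encoder.
    """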
num_stages = 10
num_layers = 30
filter_length = 3
width = 512
skip_width = 256
n_frames = content.shape[0]
n_samples = content.shape[1]
content_tf = np.ascontiguousarray(content)
style_tf = np.ascontiguousarray(style)
g = tf.Graph()
content_features = []
style_features = []
layers = []
with g.as_default(), g.device('/cpu:0'), tf.Session() as sess:
x = tf.placeholder('float32', [n_frames, n_samples], name="x")
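        # Encode the input with 8-bit mu-law and scale it into [-1, 1].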
x_quantized = mu_law(x)
x_scaled = tf.cast(x_quantized, tf.float32) / 128.0
x_scaled = tf.expand_dims(x_scaled, 2)
layer = x_scaled
layer = masked.conv1d(
layer, num_filters=width, filter_length=filter_length, name='startconv')
# Set up skip connections.
s = masked.conv1d(
layer, num_filters=skip_width, filter_length=1, name='skip_start')
# Residual blocks with skip connections.
for i in range(num_layers):
dilation = 2**(i % num_stages)
d = masked.conv1d(
layer,
num_filters=2 * width,
filter_length=filter_length,
dilation=dilation,
name='dilatedconv_%d' % (i + 1))
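            # Gated activation: split the channels into a sigmoid gate and
            # a tanh filter and multiply them, as in WaveNet.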
assert d.get_shape().as_list()[2] % 2 == 0
m = d.get_shape().as_list()[2] // 2
d_sigmoid = tf.sigmoid(d[:, :, :m])
d_tanh = tf.tanh(d[:, :, m:])
d = d_sigmoid * d_tanh
layer += masked.conv1d(
d, num_filters=width, filter_length=1, name='res_%d' % (i + 1))
s += masked.conv1d(
d,
num_filters=skip_width,
filter_length=1,
name='skip_%d' % (i + 1))
layers.append(s)
saver = tf.train.Saver()
saver.restore(sess, './model.ckpt-200000')
content_features = sess.run(layers, feed_dict={x: content_tf})
styles = sess.run(layers, feed_dict={x: style_tf})
for i, style_feature in enumerate(styles):
n_features = np.prod(layers[i].shape.as_list()[-1])
features = np.reshape(style_feature, (-1, n_features))
style_gram = np.matmul(features.T, features) / (n_samples *
n_frames)
style_features.append(style_gram)
return content_features, style_features
def compute_wavenet_decoder_stylization(n_samples,
n_frames,
content_features,
style_features,
alpha=1e-4,
learning_rate=1e-3,
iterations=100):
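    """Optimize a signal to match decoder content and style features.

    Decoder counterpart of compute_wavenet_encoder_stylization; the frozen
    feature extractor is the decoder's gated residual stack with skip
    connections.
    """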
style_layers = [1, 5]
num_stages = 10
num_layers = 30
filter_length = 3
width = 512
skip_width = 256
layers = []
with tf.Graph().as_default() as g, g.device('/cpu:0'), tf.Session() as sess:
x = tf.placeholder(
name="x", shape=(n_frames, n_samples, 1), dtype=tf.float32)
layer = x
layer = masked.conv1d(
layer, num_filters=width, filter_length=filter_length, name='startconv')
# Set up skip connections.
s = masked.conv1d(
layer, num_filters=skip_width, filter_length=1, name='skip_start')
# Residual blocks with skip connections.
for i in range(num_layers):
dilation = 2**(i % num_stages)
d = masked.conv1d(
layer,
num_filters=2 * width,
filter_length=filter_length,
dilation=dilation,
name='dilatedconv_%d' % (i + 1))
assert d.get_shape().as_list()[2] % 2 == 0
m = d.get_shape().as_list()[2] // 2
d_sigmoid = tf.sigmoid(d[:, :, :m])
d_tanh = tf.tanh(d[:, :, m:])
d = d_sigmoid * d_tanh
layer += masked.conv1d(
d, num_filters=width, filter_length=1, name='res_%d' % (i + 1))
s += masked.conv1d(
d,
num_filters=skip_width,
filter_length=1,
name='skip_%d' % (i + 1))
            layer_i = tf.identity(s, name='layer_{}'.format(i))
layers.append(layer_i)
        saver = tf.train.Saver()
        # Initialize before restoring so the random initializer does not
        # overwrite the pretrained weights.
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, './model.ckpt-200000')
frozen_graph_def = tf.graph_util.convert_variables_to_constants(
sess, sess.graph_def, [s.name.replace(':0', '')] +
['layer_{}'.format(i) for i in range(num_layers)])
with tf.Graph().as_default() as g, g.device('/cpu:0'), tf.Session() as sess:
x = tf.Variable(
np.random.randn(n_frames, n_samples, 1).astype(np.float32))
tf.import_graph_def(frozen_graph_def, input_map={'x:0': x})
content_loss = np.float32(0.0)
style_loss = np.float32(0.0)
for num_layer in style_layers:
layer_i = g.get_tensor_by_name(name='import/layer_%d:0' %
(num_layer))
content_loss = content_loss + alpha * 2 * tf.nn.l2_loss(
layer_i - content_features[num_layer])
n_features = layer_i.shape.as_list()[-1]
features = tf.reshape(layer_i, (-1, n_features))
gram = tf.matmul(tf.transpose(features), features) / (n_frames *
n_samples)
style_loss = style_loss + 2 * tf.nn.l2_loss(gram - style_features[
num_layer])
loss = content_loss + style_loss
# Optimization
print('Started optimization.')
opt = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)
var_list = tf.trainable_variables()
print(var_list)
        sess.run(tf.global_variables_initializer())
for i in range(iterations):
            s_loss, c_loss, _ = sess.run([style_loss, content_loss, opt])
            print(i, '- Style:', s_loss, 'Content:', c_loss, end='\r')
result = x.eval()
        result = inv_mu_law_numpy(result[..., 0] / np.abs(result).max() * 128.0)
return result
def run(content_fname,
style_fname,
output_path,
model,
iterations=100,
sr=16000,
hop_size=512,
frame_size=2048,
alpha=1e-3):
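    """Style-transfer one content audio file onto one style audio file.

    Loads both files at sample rate `sr`, truncates them to a common
    length, chops them into overlapping frames, extracts features with
    the chosen model ('encoder' or 'decoder'), optimizes the stylized
    signal, reassembles it, applies a limiter, and writes the result
    under `output_path`.
    """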
content, fs = librosa.load(content_fname, sr=sr)
style, fs = librosa.load(style_fname, sr=sr)
    n_samples = (min(content.shape[0], style.shape[0]) // hop_size) * hop_size
content = utils.chop(content[:n_samples], hop_size, frame_size)
style = utils.chop(style[:n_samples], hop_size, frame_size)
if model == 'encoder':
content_features, style_features = compute_wavenet_encoder_features(
content=content, style=style)
result = compute_wavenet_encoder_stylization(
n_frames=content_features[0].shape[0],
n_samples=frame_size,
alpha=alpha,
content_features=content_features,
style_features=style_features,
iterations=iterations)
elif model == 'decoder':
content_features, style_features = compute_wavenet_decoder_features(
content=content, style=style)
result = compute_wavenet_decoder_stylization(
n_frames=content_features[0].shape[0],
n_samples=frame_size,
alpha=alpha,
content_features=content_features,
style_features=style_features,
iterations=iterations)
else:
raise ValueError('Unsupported model type: {}.'.format(model))
x = utils.unchop(result, hop_size, frame_size)
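    # Save the raw result before peak limiting, for comparison.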
librosa.output.write_wav('prelimiter.wav', x, sr)
limited = utils.limiter(x)
output_fname = '{}/{}+{}.wav'.format(output_path,
content_fname.split('/')[-1],
style_fname.split('/')[-1])
librosa.output.write_wav(output_fname, limited, sr=sr)
def batch(content_path, style_path, output_path, model):
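    """Style-transfer every content .wav onto every style .wav.

    Pairs each .wav under `content_path` with each .wav under
    `style_path`, skipping pairs whose output file already exists.
    """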
content_files = glob.glob('{}/*.wav'.format(content_path))
style_files = glob.glob('{}/*.wav'.format(style_path))
for content_fname in content_files:
for style_fname in style_files:
output_fname = '{}/{}+{}.wav'.format(output_path,
content_fname.split('/')[-1],
style_fname.split('/')[-1])
if os.path.exists(output_fname):
continue
            run(content_fname, style_fname, output_path, model)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'-s', '--style', help='style file(s) location', required=True)
parser.add_argument(
'-c', '--content', help='content file(s) location', required=True)
parser.add_argument('-o', '--output', help='output path', required=True)
    parser.add_argument(
        '-m',
        '--model',
        help='model type: encoder (default) or decoder',
        default='encoder')
    parser.add_argument(
        '-t',
        '--type',
        help='run mode: single (content/style point to files) or '
        'batch (content/style point to directories of .wav files)',
        default='single')
args = vars(parser.parse_args())
    if args['type'] == 'single':
run(args['content'], args['style'], args['output'], args['model'])
else:
batch(args['content'], args['style'], args['output'], args['model'])