"""NIPS2017 "Time Domain Neural Audio Style Transfer" code repository
Parag K. Mital
"""
import tensorflow as tf
import librosa
import numpy as np
from scipy.signal.windows import hann  # `hann` moved to scipy.signal.windows in newer SciPy
from audio_style_transfer import utils
import argparse
import glob
import os
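# Hide all GPUs from TensorFlow so everything runs on the CPU.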
os.environ["CUDA_VISIBLE_DEVICES"]="-1"
import soundfile as sf
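# Slice a 1-D signal into overlapping Hann-windowed frames, advancing
# `hop_size` samples per frame; frames past the end are zero-padded.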
def chop(signal, hop_size=256, frame_size=512):
n_hops = len(signal) // hop_size
s = []
hann_win = hann(frame_size)
for hop_i in range(n_hops):
frame = signal[(hop_i * hop_size):(hop_i * hop_size + frame_size)]
frame = np.pad(frame, (0, frame_size - len(frame)), 'constant')
frame *= hann_win
s.append(frame)
s = np.array(s)
return s
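# Overlap-add windowed frames back into a single 1-D signal.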
def unchop(frames, hop_size=256, frame_size=512):
signal = np.zeros((frames.shape[0] * hop_size + frame_size,))
for hop_i, frame in enumerate(frames):
signal[(hop_i * hop_size):(hop_i * hop_size + frame_size)] += frame
return signal
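# NumPy reference for a real-valued DFT computed as two matrix products
# against cosine and sine bases; the same matmul construction is used
# inside the TensorFlow graph in compute_inputs below.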
def dft_np(signal, hop_size=256, fft_size=512):
s = chop(signal, hop_size, fft_size)
N = s.shape[-1]
k = np.reshape(
np.linspace(0.0, 2 * np.pi / N * (N // 2), N // 2), [1, N // 2])
x = np.reshape(np.linspace(0.0, N - 1, N), [N, 1])
freqs = np.dot(x, k)
real = np.dot(s, np.cos(freqs)) * (2.0 / N)
imag = np.dot(s, np.sin(freqs)) * (2.0 / N)
return real, imag
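# Inverse of dft_np: project real/imaginary parts back through the bases
# and overlap-add the reconstructed frames into a time-domain signal.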
def idft_np(re, im, hop_size=256, fft_size=512):
N = re.shape[1] * 2
k = np.reshape(
np.linspace(0.0, 2 * np.pi / N * (N // 2), N // 2), [N // 2, 1])
x = np.reshape(np.linspace(0.0, N - 1, N), [1, N])
freqs = np.dot(k, x)
signal = np.zeros((re.shape[0] * hop_size + fft_size,))
recon = np.dot(re, np.cos(freqs)) + np.dot(im, np.sin(freqs))
for hop_i, frame in enumerate(recon):
signal[(hop_i * hop_size):(hop_i * hop_size + fft_size)] += frame
return signal
def unwrap(x):
return np.unwrap(x).astype(np.float32)
def instance_norm(x, epsilon=1e-5):
"""Instance Normalization.
See Ulyanov, D., Vedaldi, A., & Lempitsky, V. (2016).
Instance Normalization: The Missing Ingredient for Fast Stylization,
Retrieved from http://arxiv.org/abs/1607.08022
Parameters
----------
x : TYPE
Description
epsilon : float, optional
Description
"""
with tf.compat.v1.variable_scope('instance_norm'):
mean, var = tf.nn.moments(x=x, axes=[1, 2], keepdims=True)
scale = tf.compat.v1.get_variable(
name='scale',
shape=[x.get_shape()[-1]],
initializer=tf.compat.v1.truncated_normal_initializer(mean=1.0, stddev=0.02))
offset = tf.compat.v1.get_variable(
name='offset',
shape=[x.get_shape()[-1]],
initializer=tf.compat.v1.constant_initializer(0.0))
        out = scale * tf.divide(x - mean, tf.sqrt(var + epsilon)) + offset
return out
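# Build the graph's input features from framed audio `x`: matmul-based DFT
# real/imaginary parts, magnitudes, phase, unwrapped phase (via py_func),
# and the per-bin unwrapped-phase difference, each shaped
# [1, 1, n_frames, n_fft // 2]; the features named in `input_features`
# are concatenated along axis 1 into the initial `net` tensor.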
def compute_inputs(x, freqs, n_fft, n_frames, input_features, norm=False):
if norm:
norm_fn = instance_norm
else:
def norm_fn(x):
return x
freqs_tf = tf.constant(freqs, name="freqs", dtype='float32')
inputs = {}
with tf.compat.v1.variable_scope('real'):
inputs['real'] = norm_fn(tf.reshape(
tf.matmul(x, tf.cos(freqs_tf)), [1, 1, n_frames, n_fft // 2]))
with tf.compat.v1.variable_scope('imag'):
inputs['imag'] = norm_fn(tf.reshape(
tf.matmul(x, tf.sin(freqs_tf)), [1, 1, n_frames, n_fft // 2]))
with tf.compat.v1.variable_scope('mags'):
inputs['mags'] = norm_fn(tf.reshape(
tf.sqrt(
tf.maximum(1e-15, inputs['real'] * inputs['real'] + inputs[
'imag'] * inputs['imag'])), [1, 1, n_frames, n_fft // 2]))
with tf.compat.v1.variable_scope('phase'):
inputs['phase'] = norm_fn(tf.atan2(inputs['imag'], inputs['real']))
with tf.compat.v1.variable_scope('unwrapped'):
inputs['unwrapped'] = tf.compat.v1.py_func(
unwrap, [inputs['phase']], tf.float32)
with tf.compat.v1.variable_scope('unwrapped_difference'):
inputs['unwrapped_difference'] = (tf.slice(
inputs['unwrapped'],
[0, 0, 0, 1], [-1, -1, -1, n_fft // 2 - 1]) -
tf.slice(
inputs['unwrapped'],
[0, 0, 0, 0], [-1, -1, -1, n_fft // 2 - 1]))
    if 'unwrapped_difference' in input_features:
        # Trim the other features by one frequency bin so they match the
        # width of the unwrapped phase difference.
        for k, v in inputs.items():
            if k != 'unwrapped_difference':
                inputs[k] = tf.slice(
                    v, [0, 0, 0, 0], [-1, -1, -1, n_fft // 2 - 1])
net = tf.concat([inputs[i] for i in input_features], 1)
return inputs, net
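# Extract targets for the transfer: run content and style audio through a
# fixed, randomly-initialized convnet and record content activations at
# every layer plus style Gram matrices (the input magnitudes serve as the
# first style layer). The random kernels are returned so that
# compute_stylization can rebuild the identical network.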
def compute_features(content,
style,
input_features,
norm=False,
stride=1,
n_layers=1,
n_filters=4096,
n_fft=1024,
k_h=1,
k_w=11):
n_frames = content.shape[0]
n_samples = content.shape[1]
content_tf = np.ascontiguousarray(content)
style_tf = np.ascontiguousarray(style)
g = tf.Graph()
kernels = []
content_features = []
style_features = []
config_proto = tf.compat.v1.ConfigProto()
with g.as_default(), g.device('/cpu:0'), tf.compat.v1.Session(config=config_proto) as sess:
x = tf.compat.v1.placeholder('float32', [n_frames, n_samples], name="x")
p = np.reshape(
np.linspace(0.0, n_samples - 1, n_samples), [n_samples, 1])
k = np.reshape(
np.linspace(0.0, 2 * np.pi / n_fft * (n_fft // 2), n_fft // 2),
[1, n_fft // 2])
freqs = np.dot(p, k)
inputs, net = compute_inputs(x, freqs, n_fft, n_frames, input_features, norm)
        sess.run(tf.compat.v1.global_variables_initializer())
content_feature = net.eval(feed_dict={x: content_tf})
content_features.append(content_feature)
style_feature = inputs['mags'].eval(feed_dict={x: style_tf})
features = np.reshape(style_feature, (-1, n_fft // 2))
style_gram = np.matmul(features.T, features) / (n_frames)
style_features.append(style_gram)
for layer_i in range(n_layers):
if layer_i == 0:
std = np.sqrt(2) * np.sqrt(2.0 / (
(n_fft / 2 + n_filters) * k_w))
kernel = np.random.randn(k_h, k_w, n_fft // 2, n_filters) * std
else:
std = np.sqrt(2) * np.sqrt(2.0 / (
(n_filters + n_filters) * k_w))
kernel = np.random.randn(1, k_w, n_filters, n_filters) * std
kernels.append(kernel)
kernel_tf = tf.constant(
kernel, name="kernel{}".format(layer_i), dtype='float32')
conv = tf.nn.conv2d(
input=net,
filters=kernel_tf,
strides=[1, stride, stride, 1],
padding="VALID",
name="conv{}".format(layer_i))
net = tf.nn.relu(conv)
content_feature = net.eval(feed_dict={x: content_tf})
content_features.append(content_feature)
style_feature = net.eval(feed_dict={x: style_tf})
features = np.reshape(style_feature, (-1, n_filters))
style_gram = np.matmul(features.T, features) / (n_frames)
style_features.append(style_gram)
return content_features, style_features, kernels, freqs
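# Synthesize audio frames by optimizing the variable `x` directly: rebuild
# the same random network from `kernels` and minimize the alpha-weighted
# distance to the content activations plus the distance between Gram
# matrices and the style Grams.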
def compute_stylization(kernels,
n_samples,
n_frames,
content_features,
style_gram,
freqs,
input_features,
norm=False,
stride=1,
n_layers=1,
n_fft=1024,
alpha=1e-4,
learning_rate=1e-3,
iterations=100,
optimizer='bfgs'):
result = None
with tf.Graph().as_default():
x = tf.Variable(
np.random.randn(n_frames, n_samples).astype(np.float32) * 1e-3,
name="x")
inputs, net = compute_inputs(x, freqs, n_fft, n_frames, input_features, norm)
content_loss = alpha * 2 * tf.nn.l2_loss(net - content_features[0])
feats = tf.reshape(inputs['mags'], (-1, n_fft // 2))
gram = tf.matmul(tf.transpose(a=feats), feats) / (n_frames)
style_loss = 2 * tf.nn.l2_loss(gram - style_gram[0])
for layer_i in range(n_layers):
kernel_tf = tf.constant(
kernels[layer_i],
name="kernel{}".format(layer_i),
dtype='float32')
conv = tf.nn.conv2d(
input=net,
filters=kernel_tf,
strides=[1, stride, stride, 1],
padding="VALID",
name="conv{}".format(layer_i))
net = tf.nn.relu(conv)
content_loss = content_loss + \
alpha * 2 * tf.nn.l2_loss(net - content_features[layer_i + 1])
            _, height, width, number = net.get_shape()
feats = tf.reshape(net, (-1, number))
gram = tf.matmul(tf.transpose(a=feats), feats) / (n_frames)
style_loss = style_loss + 2 * tf.nn.l2_loss(gram - style_gram[
layer_i + 1])
loss = content_loss + style_loss
        # The 'bfgs' and default code paths were verbatim duplicates (both
        # minimize with Adam), so they are collapsed into one path; the
        # `optimizer` argument is currently unused.
        opt = tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate).minimize(loss)
        # Optimization
        with tf.compat.v1.Session() as sess:
            sess.run(tf.compat.v1.global_variables_initializer())
            print('Started optimization.')
            for i in range(iterations):
                s, c, l, _ = sess.run([style_loss, content_loss, loss, opt])
                print('Style:', s, 'Content:', c, end='\r')
            result = x.eval()
return result
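# End-to-end stylization of one file pair: load both clips at `sr`, chop them
# into frames, truncate to the shorter clip, extract features, optimize, then
# overlap-add and write the result plus a peak-limited copy. A hypothetical
# invocation: run('content.wav', 'style.wav', 'stylized.wav').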
def run(content_fname,
style_fname,
output_fname,
norm=False,
input_features=['real', 'imag', 'mags'],
n_fft=4096,
n_layers=1,
n_filters=4096,
hop_length=256,
alpha=0.05,
k_w=15,
k_h=3,
optimizer='bfgs',
stride=1,
iterations=300,
sr=22050):
frame_size = n_fft // 2
audio, fs = librosa.load(content_fname, sr=sr)
content = chop(audio, hop_size=hop_length, frame_size=frame_size)
audio, fs = librosa.load(style_fname, sr=sr)
style = chop(audio, hop_size=hop_length, frame_size=frame_size)
    print("loaded")
n_frames = min(content.shape[0], style.shape[0])
n_samples = min(content.shape[1], style.shape[1])
content = content[:n_frames, :n_samples]
style = style[:n_frames, :n_samples]
content_features, style_gram, kernels, freqs = compute_features(
content=content,
style=style,
input_features=input_features,
norm=norm,
stride=stride,
n_fft=n_fft,
n_layers=n_layers,
n_filters=n_filters,
k_w=k_w,
k_h=k_h)
result = compute_stylization(
kernels=kernels,
freqs=freqs,
input_features=input_features,
norm=norm,
n_samples=n_samples,
n_frames=n_frames,
n_fft=n_fft,
content_features=content_features,
style_gram=style_gram,
stride=stride,
n_layers=n_layers,
alpha=alpha,
optimizer=optimizer,
iterations=iterations)
s = unchop(result, hop_size=hop_length, frame_size=frame_size)
    sf.write(output_fname, s, sr)  # soundfile.write has no `sr` keyword; pass the rate positionally
s = utils.limiter(s)
    sf.write(output_fname + '.limiter.wav', s, sr)
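# Stylize every (content, style) pair of .wav files across two directories,
# skipping any pair whose output file already exists.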
def batch(content_path, style_path, output_path):
    content_files = glob.glob('{}/*.wav'.format(content_path))
    style_files = glob.glob('{}/*.wav'.format(style_path))
    for content_fname in content_files:
        for style_fname in style_files:
            output_fname = '{}/{}+{}.wav'.format(output_path,
                                                 content_fname.split('/')[-1],
                                                 style_fname.split('/')[-1])
            if os.path.exists(output_fname):
                continue
            run(content_fname, style_fname, output_fname)
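# Example invocation (hypothetical paths):
#   python timedomain.py -c content.wav -s style.wav -o stylized.wav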
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-s', '--style', help='style file', required=True)
parser.add_argument('-c', '--content', help='content file', required=True)
parser.add_argument('-o', '--output', help='output file', required=True)
parser.add_argument(
'-m',
'--mode',
        help='run mode: a [single] content/style pair, or batch over directories',
default='single')
args = vars(parser.parse_args())
if args['mode'] == 'single':
run(args['content'], args['style'], args['output'])
else:
batch(args['content'], args['style'], args['output'])