Yue_tacotron2 / waveglow /convert_model.py
Ritori's picture
Upload 72 files
a722365
raw
history blame contribute delete
No virus
3.07 kB
import sys
import copy
import torch
def _check_model_old_version(model):
if hasattr(model.WN[0], 'res_layers') or hasattr(model.WN[0], 'cond_layers'):
return True
else:
return False
def _update_model_res_skip(old_model, new_model):
for idx in range(0, len(new_model.WN)):
wavenet = new_model.WN[idx]
n_channels = wavenet.n_channels
n_layers = wavenet.n_layers
wavenet.res_skip_layers = torch.nn.ModuleList()
for i in range(0, n_layers):
if i < n_layers - 1:
res_skip_channels = 2*n_channels
else:
res_skip_channels = n_channels
res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1)
skip_layer = torch.nn.utils.remove_weight_norm(wavenet.skip_layers[i])
if i < n_layers - 1:
res_layer = torch.nn.utils.remove_weight_norm(wavenet.res_layers[i])
res_skip_layer.weight = torch.nn.Parameter(torch.cat([res_layer.weight, skip_layer.weight]))
res_skip_layer.bias = torch.nn.Parameter(torch.cat([res_layer.bias, skip_layer.bias]))
else:
res_skip_layer.weight = torch.nn.Parameter(skip_layer.weight)
res_skip_layer.bias = torch.nn.Parameter(skip_layer.bias)
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
wavenet.res_skip_layers.append(res_skip_layer)
del wavenet.res_layers
del wavenet.skip_layers
def _update_model_cond(old_model, new_model):
for idx in range(0, len(new_model.WN)):
wavenet = new_model.WN[idx]
n_channels = wavenet.n_channels
n_layers = wavenet.n_layers
n_mel_channels = wavenet.cond_layers[0].weight.shape[1]
cond_layer = torch.nn.Conv1d(n_mel_channels, 2*n_channels*n_layers, 1)
cond_layer_weight = []
cond_layer_bias = []
for i in range(0, n_layers):
_cond_layer = torch.nn.utils.remove_weight_norm(wavenet.cond_layers[i])
cond_layer_weight.append(_cond_layer.weight)
cond_layer_bias.append(_cond_layer.bias)
cond_layer.weight = torch.nn.Parameter(torch.cat(cond_layer_weight))
cond_layer.bias = torch.nn.Parameter(torch.cat(cond_layer_bias))
cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
wavenet.cond_layer = cond_layer
del wavenet.cond_layers
def update_model(old_model):
if not _check_model_old_version(old_model):
return old_model
new_model = copy.deepcopy(old_model)
if hasattr(old_model.WN[0], 'res_layers'):
_update_model_res_skip(old_model, new_model)
if hasattr(old_model.WN[0], 'cond_layers'):
_update_model_cond(old_model, new_model)
return new_model
if __name__ == '__main__':
old_model_path = sys.argv[1]
new_model_path = sys.argv[2]
model = torch.load(old_model_path, map_location='cpu')
model['model'] = update_model(model['model'])
torch.save(model, new_model_path)