KevinGeng's picture
Upload model.py
6d26a9c
raw
history blame
6.93 kB
import torch
import torch.nn as nn
import fairseq
import os
import hydra
def load_ssl_model(cp_path):
ssl_model_type = cp_path.split("/")[-1]
wavlm = "WavLM" in ssl_model_type
if wavlm:
checkpoint = torch.load(cp_path)
cfg = WavLMConfig(checkpoint['cfg'])
ssl_model = WavLM(cfg)
ssl_model.load_state_dict(checkpoint['model'])
if 'Large' in ssl_model_type:
SSL_OUT_DIM = 1024
else:
SSL_OUT_DIM = 768
else:
if ssl_model_type == "wav2vec_small.pt":
SSL_OUT_DIM = 768
elif ssl_model_type in ["w2v_large_lv_fsh_swbd_cv.pt", "xlsr_53_56k.pt"]:
SSL_OUT_DIM = 1024
else:
print("*** ERROR *** SSL model type " + ssl_model_type + " not supported.")
exit()
model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
[cp_path]
)
ssl_model = model[0]
ssl_model.remove_pretraining_modules()
return SSL_model(ssl_model, SSL_OUT_DIM, wavlm)
class SSL_model(nn.Module):
def __init__(self,ssl_model,ssl_out_dim,wavlm) -> None:
super(SSL_model,self).__init__()
self.ssl_model, self.ssl_out_dim = ssl_model, ssl_out_dim
self.WavLM = wavlm
def forward(self,batch):
wav = batch['wav']
wav = wav.squeeze(1) # [batches, audio_len]
if self.WavLM:
x = self.ssl_model.extract_features(wav)[0]
else:
res = self.ssl_model(wav, mask=False, features_only=True)
x = res["x"]
return {"ssl-feature":x}
def get_output_dim(self):
return self.ssl_out_dim
class PhonemeEncoder(nn.Module):
'''
PhonemeEncoder consists of an embedding layer, an LSTM layer, and a linear layer.
Args:
vocab_size: the size of the vocabulary
hidden_dim: the size of the hidden state of the LSTM
emb_dim: the size of the embedding layer
out_dim: the size of the output of the linear layer
n_lstm_layers: the number of LSTM layers
'''
def __init__(self, vocab_size, hidden_dim, emb_dim, out_dim,n_lstm_layers,with_reference=True) -> None:
super().__init__()
self.with_reference = with_reference
self.embedding = nn.Embedding(vocab_size, emb_dim)
self.encoder = nn.LSTM(emb_dim, hidden_dim,
num_layers=n_lstm_layers, dropout=0.1, bidirectional=True)
self.linear = nn.Sequential(
nn.Linear(hidden_dim + hidden_dim*self.with_reference, out_dim),
nn.ReLU()
)
self.out_dim = out_dim
def forward(self,batch):
seq = batch['phonemes']
lens = batch['phoneme_lens']
reference_seq = batch['reference']
reference_lens = batch['reference_lens']
emb = self.embedding(seq)
emb = torch.nn.utils.rnn.pack_padded_sequence(
emb, lens, batch_first=True, enforce_sorted=False)
_, (ht, _) = self.encoder(emb)
feature = ht[-1] + ht[0]
if self.with_reference:
if reference_seq==None or reference_lens ==None:
raise ValueError("reference_batch and reference_lens should not be None when with_reference is True")
reference_emb = self.embedding(reference_seq)
reference_emb = torch.nn.utils.rnn.pack_padded_sequence(
reference_emb, reference_lens, batch_first=True, enforce_sorted=False)
_, (ht_ref, _) = self.encoder(emb)
reference_feature = ht_ref[-1] + ht_ref[0]
feature = self.linear(torch.cat([feature,reference_feature],1))
else:
feature = self.linear(feature)
return {"phoneme-feature": feature}
def get_output_dim(self):
return self.out_dim
class DomainEmbedding(nn.Module):
def __init__(self,n_domains,domain_dim) -> None:
super().__init__()
self.embedding = nn.Embedding(n_domains,domain_dim)
self.output_dim = domain_dim
def forward(self, batch):
return {"domain-feature": self.embedding(batch['domains'])}
def get_output_dim(self):
return self.output_dim
class LDConditioner(nn.Module):
'''
Conditions ssl output by listener embedding
'''
def __init__(self,input_dim, judge_dim, num_judges=None):
super().__init__()
self.input_dim = input_dim
self.judge_dim = judge_dim
self.num_judges = num_judges
assert num_judges !=None
self.judge_embedding = nn.Embedding(num_judges, self.judge_dim)
# concat [self.output_layer, phoneme features]
self.decoder_rnn = nn.LSTM(
input_size = self.input_dim + self.judge_dim,
hidden_size = 512,
num_layers = 1,
batch_first = True,
bidirectional = True
) # linear?
self.out_dim = self.decoder_rnn.hidden_size*2
def get_output_dim(self):
return self.out_dim
def forward(self, x, batch):
judge_ids = batch['judge_id']
if 'phoneme-feature' in x.keys():
concatenated_feature = torch.cat((x['ssl-feature'], x['phoneme-feature'].unsqueeze(1).expand(-1,x['ssl-feature'].size(1) ,-1)),dim=2)
else:
concatenated_feature = x['ssl-feature']
if 'domain-feature' in x.keys():
concatenated_feature = torch.cat(
(
concatenated_feature,
x['domain-feature']
.unsqueeze(1)
.expand(-1, concatenated_feature.size(1), -1),
),
dim=2,
)
if judge_ids != None:
concatenated_feature = torch.cat(
(
concatenated_feature,
self.judge_embedding(judge_ids)
.unsqueeze(1)
.expand(-1, concatenated_feature.size(1), -1),
),
dim=2,
)
decoder_output, (h, c) = self.decoder_rnn(concatenated_feature)
return decoder_output
class Projection(nn.Module):
def __init__(self, input_dim, hidden_dim, activation, range_clipping=False):
super(Projection, self).__init__()
self.range_clipping = range_clipping
output_dim = 1
if range_clipping:
self.proj = nn.Tanh()
self.net = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
activation,
nn.Dropout(0.3),
nn.Linear(hidden_dim, output_dim),
)
self.output_dim = output_dim
def forward(self, x, batch):
output = self.net(x)
# range clipping
if self.range_clipping:
return self.proj(output) * 2.0 + 3
else:
return output
def get_output_dim(self):
return self.output_dim