import os

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import BertModel, BertConfig, BertTokenizer
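
# Character-level prosody embeddings for TTS: a BERT encoder with a 256-dim
# projection head, plus a wrapper that loads trained weights and stretches
# per-character vectors to per-phoneme length.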


class CharEmbedding(nn.Module):
    def __init__(self, model_dir):
        super().__init__()
        self.tokenizer = BertTokenizer.from_pretrained(model_dir)
        self.bert_config = BertConfig.from_pretrained(model_dir)
        self.hidden_size = self.bert_config.hidden_size
        self.bert = BertModel(self.bert_config)
        # Project BERT hidden states down to 256-dim character embeddings.
        self.proj = nn.Linear(self.hidden_size, 256)
        # Not used in forward(); kept so the trained checkpoint loads cleanly
        # (presumably a 3-class prosody-boundary head used during training).
        self.linear = nn.Linear(256, 3)

    def text2Token(self, text):
        # No [CLS]/[SEP]: the id sequence follows the raw characters of the text.
        token = self.tokenizer.tokenize(text)
        txtid = self.tokenizer.convert_tokens_to_ids(token)
        return txtid

    def forward(self, inputs_ids, inputs_masks, tokens_type_ids):
        # [0] selects the last hidden state: (batch, seq_len, hidden_size).
        out_seq = self.bert(input_ids=inputs_ids,
                            attention_mask=inputs_masks,
                            token_type_ids=tokens_type_ids)[0]
        out_seq = self.proj(out_seq)
        return out_seq


class TTSProsody(object):
    def __init__(self, path, device):
        self.device = device
        self.char_model = CharEmbedding(path)
        # prosody_model.pt is expected to hold the full CharEmbedding state
        # (the BertModel above is built from config only, without pretrained
        # weights); strict=False tolerates missing or extra keys.
        self.char_model.load_state_dict(
            torch.load(
                os.path.join(path, 'prosody_model.pt'),
                map_location="cpu"
            ),
            strict=False
        )
        self.char_model.eval()
        self.char_model.to(self.device)

    def get_char_embeds(self, text):
        # Single-sentence batch: every token is attended, one segment only.
        input_ids = self.char_model.text2Token(text)
        input_masks = [1] * len(input_ids)
        type_ids = [0] * len(input_ids)
        input_ids = torch.LongTensor([input_ids]).to(self.device)
        input_masks = torch.LongTensor([input_masks]).to(self.device)
        type_ids = torch.LongTensor([type_ids]).to(self.device)

        with torch.no_grad():
            char_embeds = self.char_model(
                input_ids, input_masks, type_ids).squeeze(0).cpu()
        return char_embeds

    def expand_for_phone(self, char_embeds, length):
        # Repeat each character embedding once per phoneme so the output
        # aligns with the phoneme sequence: shape (sum(length), 256).
        assert char_embeds.size(0) == len(length)
        expand_vecs = list()
        for vec, leng in zip(char_embeds, length):
            vec = vec.expand(leng, -1)
            expand_vecs.append(vec)
        expand_embeds = torch.cat(expand_vecs, 0)
        assert expand_embeds.size(0) == sum(length)
        return expand_embeds.numpy()
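
# Typical use (phoneme counts hypothetical): embed a sentence once at the
# character level, then stretch to phoneme resolution for the acoustic model:
#   embeds = prosody.get_char_embeds("你好")                # (2, 256) tensor
#   phone_level = prosody.expand_for_phone(embeds, [3, 3])  # (6, 256) ndarray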


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    prosody = TTSProsody('./bert/', device)
    while True:
        text = input("请输入文本:")  # "Please enter text:"
        prosody.get_char_embeds(text)