File size: 2,652 Bytes
36fb9b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import BertModel, BertConfig, BertTokenizer


class CharEmbedding(nn.Module):
    """BERT encoder followed by a linear projection to 256-d token vectors.

    Only the architecture is built from the config in ``model_dir``
    (``BertModel(config)``, not ``from_pretrained``), so the BERT weights are
    randomly initialised here; ``TTSProsody`` loads trained weights on top
    via ``load_state_dict``.
    """

    def __init__(self, model_dir):
        super().__init__()
        self.tokenizer = BertTokenizer.from_pretrained(model_dir)
        config = BertConfig.from_pretrained(model_dir)
        self.bert_config = config
        self.hidden_size = config.hidden_size
        self.bert = BertModel(config)
        self.proj = nn.Linear(config.hidden_size, 256)
        # Not used by forward(); possibly kept so checkpoint keys line up — TODO confirm.
        self.linear = nn.Linear(256, 3)

    def text2Token(self, text):
        """Tokenize ``text`` and return the vocabulary ids of its tokens.

        No [CLS]/[SEP] markers are added, so the ids map one-to-one onto the
        tokenizer's tokens.
        """
        tokenizer = self.tokenizer
        return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))

    def forward(self, inputs_ids, inputs_masks, tokens_type_ids):
        """Run BERT and project every position of its first output to 256 dims."""
        bert_outputs = self.bert(
            input_ids=inputs_ids,
            attention_mask=inputs_masks,
            token_type_ids=tokens_type_ids,
        )
        hidden_states = bert_outputs[0]  # per-token hidden states of the last layer
        return self.proj(hidden_states)


class TTSProsody(object):
    """Inference wrapper that turns text into per-token prosody embeddings.

    Builds a ``CharEmbedding`` from ``path``, loads ``<path>/prosody_model.pt``
    into it, and keeps it in eval mode on ``device``.
    """

    def __init__(self, path, device):
        self.device = device
        self.char_model = CharEmbedding(path)
        # NOTE(review): torch.load unpickles the file, so only load checkpoints
        # from trusted sources; on torch>=1.13 `weights_only=True` is safer.
        state_dict = torch.load(
            os.path.join(path, 'prosody_model.pt'),
            map_location="cpu"
        )
        # strict=False: checkpoint keys that are missing from / unexpected by the
        # model are not reported as errors.
        self.char_model.load_state_dict(state_dict, strict=False)
        self.char_model.eval()
        self.char_model.to(self.device)

    def get_char_embeds(self, text):
        """Encode ``text`` and return one projected embedding per token.

        Returns:
            A CPU tensor with one row per tokenizer token of ``text``
            (no [CLS]/[SEP] are added). With ``CharEmbedding`` the rows are
            256-dimensional.
        """
        token_ids = self.char_model.text2Token(text)
        seq_len = len(token_ids)
        # Batch of one sequence: attend to every token, all in segment 0.
        input_ids = torch.LongTensor([token_ids]).to(self.device)
        input_masks = torch.LongTensor([[1] * seq_len]).to(self.device)
        type_ids = torch.LongTensor([[0] * seq_len]).to(self.device)

        with torch.no_grad():
            char_embeds = self.char_model(
                input_ids, input_masks, type_ids).squeeze(0).cpu()
        return char_embeds

    def expand_for_phone(self, char_embeds, length):  # length of phones for char
        """Repeat each character embedding once per phone of that character.

        Args:
            char_embeds: tensor of shape (num_chars, dim), e.g. the output of
                ``get_char_embeds``; must live on the CPU for ``.numpy()``.
            length: sequence of non-negative ints; ``length[i]`` is the number
                of phones produced by character ``i``.

        Returns:
            numpy array of shape (sum(length), dim).

        Raises:
            ValueError: if ``len(length)`` differs from ``char_embeds.size(0)``.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if char_embeds.size(0) != len(length):
            raise ValueError(
                f"expand_for_phone: got {char_embeds.size(0)} char embeddings "
                f"but {len(length)} phone lengths"
            )
        if len(length) == 0:
            # Nothing to expand; return an empty (0, dim) array rather than failing.
            return char_embeds[:0].numpy()
        repeats = torch.as_tensor(length, dtype=torch.long,
                                  device=char_embeds.device)
        # Row i is repeated length[i] times (zero-length rows are dropped).
        expand_embeds = char_embeds.repeat_interleave(repeats, dim=0)
        return expand_embeds.numpy()


if __name__ == "__main__":
    # Interactive smoke test: embed each line typed on stdin and print the
    # resulting embedding shape. The prompt means "Please enter text:".
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    prosody = TTSProsody('./bert/', device)
    while True:
        try:
            text = input("请输入文本:")
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C ends the session instead of dumping a traceback.
            print()
            break
        if not text.strip():
            # Blank input produces no tokens; nothing to embed.
            continue
        char_embeds = prosody.get_char_embeds(text)
        print(char_embeds.shape)