# File size: 4,033 Bytes
# b725c5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import json
import numpy as np
from text import text_to_sequence
from text.text_token_collation import phoneIDCollation
from models.tts.base.tts_dataset import (
    TTSDataset,
    TTSCollator,
    TTSTestDataset,
    TTSTestCollator,
)


class VITSDataset(TTSDataset):
    """Training/validation dataset for VITS.

    Adds nothing on top of ``TTSDataset``: construction, item access and
    length are all forwarded to the parent implementation.
    """

    def __init__(self, cfg, dataset, is_valid):
        # ``is_valid`` is handed to the parent by keyword.
        super().__init__(cfg, dataset, is_valid=is_valid)

    def __getitem__(self, index):
        """Return the parent's feature entry for ``index`` unchanged."""
        return super().__getitem__(index)

    def __len__(self):
        """Return the number of items reported by the parent dataset."""
        num_items = super().__len__()
        return num_items


class VITSCollator(TTSCollator):
    """Batch collator for VITS training.

    Collation is delegated entirely to ``TTSCollator``; this subclass adds
    no processing of its own.
    """

    def __init__(self, cfg):
        super().__init__(cfg)

    def __call__(self, batch):
        """Collate ``batch`` with the parent collator and return its result."""
        return super().__call__(batch)


class VITSTestDataset(TTSTestDataset):
    """Inference-time dataset for VITS.

    Each item is a dict that may contain ``spk_id`` and ``phone_seq`` /
    ``phone_len``, depending on the ``cfg.preprocess`` switches ``use_spkid``,
    ``use_text`` and ``use_phone``. Token sequences are built once in
    ``__init__`` and looked up afterwards by the key ``"{Dataset}_{Uid}"``.

    NOTE(review): ``self.metadata``, ``self.cfg`` and ``self.metafile_path``
    are read here but never assigned in this class; presumably the parent
    initializer provides them -- confirm against ``TTSTestDataset``.
    """

    def __init__(self, args, cfg):
        """Load speaker maps and pre-compute every utterance's id sequence.

        Args:
            args: Run-time arguments. ``args.dataset`` names the sub-directory
                of ``cfg.preprocess.processed_dir`` that holds the speaker maps
                and phone files; it is only read when one of those is needed.
            cfg: Experiment configuration; ``cfg.preprocess`` selects which
                features are loaded.

        Raises:
            AssertionError: If a phone file does not contain exactly one line.
        """
        super().__init__(args, cfg)

        use_spkid = cfg.preprocess.use_spkid
        use_text = cfg.preprocess.use_text
        # Text input takes priority, so phone files are only read without it.
        reads_phone_files = (not use_text) and cfg.preprocess.use_phone

        # Both the speaker maps and the phone files live under this directory.
        # It used to be defined only inside the ``use_spkid`` branch, which
        # raised NameError for phone input without speaker ids.
        processed_data_dir = None
        if use_spkid or reads_phone_files:
            processed_data_dir = os.path.join(
                cfg.preprocess.processed_dir, args.dataset
            )

        if use_spkid:
            spk2id_path = os.path.join(processed_data_dir, cfg.preprocess.spk2id)
            with open(spk2id_path, "r") as f:
                self.spk2id = json.load(f)

            utt2spk_path = os.path.join(processed_data_dir, cfg.preprocess.utt2spk)
            self.utt2spk = {}
            with open(utt2spk_path, "r") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        # Tolerate blank lines instead of failing the unpack.
                        continue
                    utt, spk = line.split("\t")
                    self.utt2spk[utt] = spk

        if use_text or reads_phone_files:
            self.utt2seq = {}
            # One collator per dataset name instead of one per utterance.
            # NOTE(review): assumes the collator keeps no per-call state.
            phone_collators = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                if use_text:
                    text = utt_info["Text"]
                    sequence = text_to_sequence(text, cfg.preprocess.text_cleaners)
                else:
                    # Load the phoneme sequence from the utterance's phone file.
                    phone_path = os.path.join(
                        processed_data_dir, cfg.preprocess.phone_dir, uid + ".phone"
                    )
                    with open(phone_path, "r") as fin:
                        phones = fin.readlines()
                    assert len(phones) == 1, (
                        "expected exactly one line in {}, found {}".format(
                            phone_path, len(phones)
                        )
                    )
                    phones_seq = phones[0].strip().split(" ")

                    if dataset not in phone_collators:
                        phone_collators[dataset] = phoneIDCollation(
                            cfg, dataset=dataset
                        )
                    sequence = phone_collators[dataset].get_phone_id_sequence(
                        cfg, phones_seq
                    )

                self.utt2seq[utt] = sequence

    def __getitem__(self, index):
        """Return the feature dict of the utterance at position ``index``.

        Returns:
            dict: ``spk_id`` (one-element int32 array) when speaker ids are
            enabled; ``phone_seq`` (array built from the stored sequence) and
            ``phone_len`` (its length) when text or phone input is enabled.
        """
        utt_info = self.metadata[index]

        dataset = utt_info["Dataset"]
        uid = utt_info["Uid"]
        utt = "{}_{}".format(dataset, uid)

        single_feature = dict()

        if self.cfg.preprocess.use_spkid:
            single_feature["spk_id"] = np.array(
                [self.spk2id[self.utt2spk[utt]]], dtype=np.int32
            )

        if self.cfg.preprocess.use_phone or self.cfg.preprocess.use_text:
            single_feature["phone_seq"] = np.array(self.utt2seq[utt])
            single_feature["phone_len"] = len(self.utt2seq[utt])

        return single_feature

    def get_metadata(self):
        """Read and return the JSON content of ``self.metafile_path``."""
        with open(self.metafile_path, "r", encoding="utf-8") as f:
            metadata = json.load(f)
        return metadata

    def __len__(self):
        """Return the number of utterances in ``self.metadata``."""
        return len(self.metadata)


class VITSTestCollator(TTSTestCollator):
    """Batch collator for VITS inference.

    Collation itself is performed by ``TTSTestCollator``; this subclass only
    stores the configuration.
    """

    def __init__(self, cfg):
        # NOTE(review): the parent initializer is not invoked here; only the
        # configuration is kept. Confirm the parent needs no further setup.
        self.cfg = cfg

    def __call__(self, batch):
        """Collate ``batch`` with the parent collator and return its result."""
        collated = super().__call__(batch)
        return collated