File size: 5,358 Bytes
bed01bd
 
 
d5d0921
d2b7e94
d5d0921
 
 
 
d2b7e94
d5d0921
 
 
bed01bd
 
 
 
d5d0921
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bed01bd
 
 
 
 
 
 
d5d0921
bed01bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import logging
from typing import Generator

import numpy as np

from modules.api.impl.handler.AudioHandler import AudioHandler
from modules.api.impl.model.audio_model import AdjustConfig
from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
from modules.api.impl.model.enhancer_model import EnhancerConfig
from modules.Enhancer.ResembleEnhance import apply_audio_enhance_full
from modules.normalization import text_normalize
from modules.speaker import Speaker
from modules.synthesize_audio import synthesize_audio
from modules.synthesize_stream import synthesize_stream
from modules.utils.audio import apply_normalize, apply_prosody_to_audio_data

logger = logging.getLogger(__name__)


class TTSHandler(AudioHandler):
    """Render ``text_content`` to audio using ChatTTS settings.

    Pipeline (see ``enqueue`` / ``enqueue_stream``):
    text normalization -> synthesis -> optional enhancement (non-stream
    mode only) -> prosody adjustment (speed, pitch, volume) -> optional
    normalization via ``apply_normalize``.
    """

    def __init__(
        self,
        text_content: str,
        spk: Speaker,
        tts_config: ChatTTSConfig,
        infer_config: InferConfig,
        adjust_config: AdjustConfig,
        enhancer_config: EnhancerConfig,
    ):
        """Type-check and store the synthesis request.

        Args:
            text_content: Raw text to synthesize; it is passed through
                ``text_normalize`` when audio is generated.
            spk: Speaker forwarded to the synthesizer.
            tts_config: Sampling settings (temperature, top_p, top_k,
                prompts, prefix).
            infer_config: Inference settings (seed, batch size, splitter
                threshold, end-of-sentence marker).
            adjust_config: Post-processing settings (speed, pitch, volume,
                normalization/headroom).
            enhancer_config: Settings for the optional enhancement pass.

        Raises:
            AssertionError: if any argument has the wrong type.
        """
        # NOTE(review): these checks are ``assert`` statements and are
        # therefore stripped when Python runs with ``-O``. They are kept as
        # asserts because callers may rely on AssertionError; consider
        # switching to explicit TypeError raises.
        assert isinstance(text_content, str), "text_content should be str"
        assert isinstance(spk, Speaker), "spk should be Speaker"
        assert isinstance(
            tts_config, ChatTTSConfig
        ), "tts_config should be ChatTTSConfig"
        assert isinstance(
            infer_config, InferConfig
        ), "infer_config should be InferConfig"
        assert isinstance(
            adjust_config, AdjustConfig
        ), "adjust_config should be AdjustConfig"
        assert isinstance(
            enhancer_config, EnhancerConfig
        ), "enhancer_config should be EnhancerConfig"

        self.text_content = text_content
        self.spk = spk
        self.tts_config = tts_config
        self.infer_config = infer_config
        self.adjust_config = adjust_config
        self.enhancer_config = enhancer_config

        self.validate()

    @property
    def adjest_config(self) -> AdjustConfig:
        """Deprecated misspelled alias of ``adjust_config`` (backward compatibility)."""
        return self.adjust_config

    @adjest_config.setter
    def adjest_config(self, value: AdjustConfig) -> None:
        self.adjust_config = value

    def validate(self):
        """Hook for semantic parameter checks; currently a no-op."""
        # TODO params checker
        pass

    @staticmethod
    def _apply_adjust(
        audio_data: np.ndarray, sample_rate: int, adjust_config: AdjustConfig
    ) -> tuple[np.ndarray, int]:
        """Apply prosody changes and optional normalization to one audio buffer.

        Shared by the batch and stream paths so both post-process identically.

        Returns:
            ``(audio_data, sample_rate)`` after adjustment; the sample rate is
            whatever ``apply_normalize`` reports when normalization is enabled.
        """
        audio_data = apply_prosody_to_audio_data(
            audio_data=audio_data,
            rate=adjust_config.speed_rate,
            pitch=adjust_config.pitch,
            volume=adjust_config.volume_gain_db,
            sr=sample_rate,
        )

        if adjust_config.normalize:
            # apply_normalize is unpacked as (sr, audio) here.
            sample_rate, audio_data = apply_normalize(
                audio_data=audio_data,
                headroom=adjust_config.headroom,
                sr=sample_rate,
            )

        return audio_data, sample_rate

    def enqueue(self) -> tuple[np.ndarray, int]:
        """Synthesize the full text in one call and post-process it.

        Returns:
            ``(audio_data, sample_rate)`` for the complete utterance.
        """
        text = text_normalize(self.text_content)
        tts_config = self.tts_config
        infer_config = self.infer_config
        adjust_config = self.adjust_config
        enhancer_config = self.enhancer_config

        sample_rate, audio_data = synthesize_audio(
            text,
            spk=self.spk,
            temperature=tts_config.temperature,
            top_P=tts_config.top_p,
            top_K=tts_config.top_k,
            prompt1=tts_config.prompt1,
            prompt2=tts_config.prompt2,
            prefix=tts_config.prefix,
            infer_seed=infer_config.seed,
            batch_size=infer_config.batch_size,
            spliter_threshold=infer_config.spliter_threshold,
            end_of_sentence=infer_config.eos,
        )

        if enhancer_config.enabled:
            # The enhancer's result is unpacked as (audio, sr), the reverse
            # of synthesize_audio's (sr, audio) order above.
            audio_data, sample_rate = apply_audio_enhance_full(
                audio_data=audio_data,
                sr=sample_rate,
                nfe=enhancer_config.nfe,
                solver=enhancer_config.solver,
                lambd=enhancer_config.lambd,
                tau=enhancer_config.tau,
            )

        return self._apply_adjust(audio_data, sample_rate, adjust_config)

    def enqueue_stream(self) -> Generator[tuple[np.ndarray, int], None, None]:
        """Synthesize the text incrementally, yielding post-processed chunks.

        Enhancement is not applied in stream mode; a warning is logged if it
        was requested.

        Yields:
            ``(wav, sr)`` for each chunk produced by ``synthesize_stream``.
        """
        text = text_normalize(self.text_content)
        tts_config = self.tts_config
        infer_config = self.infer_config
        adjust_config = self.adjust_config
        enhancer_config = self.enhancer_config

        if enhancer_config.enabled:
            logger.warning(
                "enhancer_config is enabled, but it is not supported in stream mode"
            )

        # NOTE(review): unlike enqueue(), batch_size is not forwarded here;
        # confirm whether synthesize_stream accepts it.
        gen = synthesize_stream(
            text,
            spk=self.spk,
            temperature=tts_config.temperature,
            top_P=tts_config.top_p,
            top_K=tts_config.top_k,
            prompt1=tts_config.prompt1,
            prompt2=tts_config.prompt2,
            prefix=tts_config.prefix,
            infer_seed=infer_config.seed,
            spliter_threshold=infer_config.spliter_threshold,
            end_of_sentence=infer_config.eos,
        )

        # FIXME: Oddly, when the streamed chunks are concatenated, each chunk
        # starts with a short anomalous segment. The cause has not been found
        # yet -- possibly the decoder's chunk slicing drops or duplicates
        # samples?
        for sr, wav in gen:
            yield self._apply_adjust(wav, sr, adjust_config)