|
from box import Box |
|
from pydub import AudioSegment |
|
from typing import List, Union |
|
from scipy.io.wavfile import write |
|
import io |
|
from modules.api.utils import calc_spk_style |
|
from modules.ssml_parser.SSMLParser import SSMLSegment, SSMLBreak, SSMLContext |
|
from modules.utils import rng |
|
from modules.utils.audio import time_stretch, pitch_shift |
|
from modules import generate_audio |
|
from modules.normalization import text_normalize |
|
import logging |
|
import json |
|
|
|
from modules.speaker import Speaker, speaker_mgr |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
def audio_data_to_segment(audio_data, sr):
    """Convert a raw sample array into a pydub ``AudioSegment``.

    The samples are serialised to an in-memory WAV file with
    ``scipy.io.wavfile.write`` and then decoded by pydub.

    Args:
        audio_data: sample array accepted by ``scipy.io.wavfile.write``.
        sr: sample rate passed to ``scipy.io.wavfile.write``.

    Returns:
        The decoded ``AudioSegment``.
    """
    wav_buffer = io.BytesIO()
    write(wav_buffer, rate=sr, data=audio_data)
    # Rewind so pydub starts reading at the WAV header, not at the end of the data.
    wav_buffer.seek(0)
    return AudioSegment.from_file(wav_buffer, format="wav")
|
|
|
|
|
def combine_audio_segments(audio_segments: list[AudioSegment]) -> AudioSegment:
    """Concatenate ``audio_segments`` in order into a single ``AudioSegment``.

    An empty input yields ``AudioSegment.empty()``.
    """
    # sum() folds left with `+`, starting from an empty segment, which is the
    # same sequence of additions as an explicit accumulate loop.
    return sum(audio_segments, AudioSegment.empty())
|
|
|
|
|
def apply_prosody(
    audio_segment: AudioSegment, rate: float, volume: float, pitch: float
) -> AudioSegment:
    """Apply tempo, gain and pitch adjustments to ``audio_segment``.

    Each adjustment is skipped when its argument is the neutral value
    (``rate == 1``, ``volume == 0``, ``pitch == 0``).

    Args:
        audio_segment: the audio to adjust.
        rate: factor passed to ``time_stretch``.
        volume: amount added to the segment; pydub treats ``segment + number``
            as a gain change in dB.
        pitch: amount passed to ``pitch_shift``.

    Returns:
        The adjusted audio (the input object itself when nothing applies).
    """
    shaped = audio_segment
    # Order matters and is fixed: tempo, then gain, then pitch.
    if rate != 1:
        shaped = time_stretch(shaped, rate)
    if volume != 0:
        shaped += volume
    if pitch != 0:
        shaped = pitch_shift(shaped, pitch)
    return shaped
|
|
|
|
|
def to_number(value, t, default=0):
    """Convert ``value`` with the callable ``t``, falling back to ``default``.

    Args:
        value: the raw value to convert (for example an attribute string or ``None``).
        t: conversion callable such as ``int`` or ``float``.
        default: returned when ``t(value)`` raises ``ValueError`` or ``TypeError``.

    Returns:
        ``t(value)`` on success, otherwise ``default``.
    """
    try:
        converted = t(value)
    except (ValueError, TypeError):
        converted = default
    return converted
|
|
|
|
|
class TTSAudioSegment(Box):
    """Box-backed record of the generation parameters for one segment.

    The class-level annotations are type hints only: they name the keys this
    module sets, but do not create attributes or default values. Construction
    is inherited unchanged from ``Box``.
    """

    # Text to synthesize (possibly normalized).
    text: str
    # Sampling parameters forwarded to the generator.
    temperature: float
    top_P: float
    top_K: int
    # Speaker; may also hold a Speaker object rather than an int.
    spk: int
    infer_seed: int
    prompt1: str
    prompt2: str
    prefix: str

    # "voice" or "break" for values built by this module.
    _type: str
|
|
|
|
|
class SynthesizeSegments:
    """Synthesize a list of SSML segments into per-segment ``AudioSegment`` objects.

    Voice segments are grouped ("bucketed") by their non-text generation
    parameters. Each bucket is sent to ``generate_audio.generate_audio_batch``
    in batches of at most ``batch_size`` texts that share one parameter set.
    Break segments become silence. The returned list is aligned with the input
    list: element ``i`` is the audio for ``segments[i]``.
    """

    def __init__(self, batch_size: int = 8):
        """
        Args:
            batch_size: maximum number of texts passed to one
                ``generate_audio_batch`` call.
        """
        self.batch_size = batch_size
        # Fallback speaker/inference seeds from rng.np_rng() (presumably
        # random), drawn once per instance so every segment that leaves them
        # unset shares the same values and can be batched together.
        self.batch_default_spk_seed = rng.np_rng()
        self.batch_default_infer_seed = rng.np_rng()

    def segment_to_generate_params(
        self, segment: Union[SSMLSegment, SSMLBreak]
    ) -> TTSAudioSegment:
        """Build the generation parameters for one segment.

        - ``SSMLBreak`` -> ``TTSAudioSegment(_type="break")``.
        - A segment with a non-None ``"params"`` entry -> those params as-is
          (no normalization, no speaker/seed defaults applied).
        - Otherwise parameters are resolved from ``segment.attrs``:

          - The speaker is taken from ``calc_spk_style``'s result when it
            contains ``"spk"``, else from ``attrs.spk``.
          - ``attrs.seed``/``prompt1``/``prompt2``/``prefix`` override the
            style's values.
          - Sampling defaults are temperature 0.3, top_P 0.5 and top_K 20.
          - The text is normalized unless ``attrs.normalize == "False"``.
          - A falsy speaker or seed (including 0) is treated as unset and
            replaced by the per-instance batch defaults.
        """
        if isinstance(segment, SSMLBreak):
            return TTSAudioSegment(_type="break")

        if segment.get("params", None) is not None:
            return TTSAudioSegment(**segment.get("params"))

        text = segment.get("text", "")
        is_end = segment.get("is_end", False)

        text = str(text).strip()

        attrs = segment.attrs
        spk = attrs.spk
        style = attrs.style

        ss_params = calc_spk_style(spk, style)

        if "spk" in ss_params:
            spk = ss_params["spk"]

        seed = to_number(attrs.seed, int, ss_params.get("seed") or -1)
        top_k = to_number(attrs.top_k, int, None)
        top_p = to_number(attrs.top_p, float, None)
        temp = to_number(attrs.temp, float, None)

        prompt1 = attrs.prompt1 or ss_params.get("prompt1")
        prompt2 = attrs.prompt2 or ss_params.get("prompt2")
        prefix = attrs.prefix or ss_params.get("prefix")
        disable_normalize = attrs.get("normalize", "") == "False"

        seg = TTSAudioSegment(
            _type="voice",
            text=text,
            temperature=temp if temp is not None else 0.3,
            top_P=top_p if top_p is not None else 0.5,
            top_K=top_k if top_k is not None else 20,
            spk=spk if spk else -1,
            infer_seed=seed if seed else -1,
            prompt1=prompt1 if prompt1 else "",
            prompt2=prompt2 if prompt2 else "",
            prefix=prefix if prefix else "",
        )

        if not disable_normalize:
            seg.text = text_normalize(text, is_end=is_end)

        # -1 marks "unset": fall back to the seeds shared by this instance.
        if seg.spk == -1:
            seg.spk = self.batch_default_spk_seed
        if seg.infer_seed == -1:
            seg.infer_seed = self.batch_default_infer_seed

        return seg

    @staticmethod
    def _find_slot_index(src_segments, segment, audio_segments) -> int:
        """Return the position in ``src_segments`` where ``segment``'s audio belongs.

        ``list.index`` compares with ``==``. Two equal but distinct segments
        (e.g. the same sentence repeated with identical attributes) would
        therefore both resolve to the first position, leaving the later output
        slot ``None``. This helper prefers an identity match whose slot is
        still empty, then an equality match whose slot is still empty, so that
        every occurrence gets its own slot. As a last resort it falls back to
        ``list.index``, which keeps the original behavior (including
        ``ValueError`` for a segment that is not in ``src_segments``).
        """
        for i, candidate in enumerate(src_segments):
            if candidate is segment and audio_segments[i] is None:
                return i
        # Callers may pass copies of the source segments; match by equality.
        for i, candidate in enumerate(src_segments):
            if candidate == segment and audio_segments[i] is None:
                return i
        return src_segments.index(segment)

    def process_break_segments(
        self,
        src_segments: List[SSMLBreak],
        bucket_segments: List[SSMLBreak],
        audio_segments: List[AudioSegment],
    ):
        """Write silence into ``audio_segments`` for each break in ``bucket_segments``.

        Args:
            src_segments: the full, ordered segment list used to locate each break.
            bucket_segments: the break segments to render.
            audio_segments: output list aligned with ``src_segments``; filled in place.
        """
        for segment in bucket_segments:
            index = self._find_slot_index(src_segments, segment, audio_segments)
            audio_segments[index] = AudioSegment.silent(
                duration=int(segment.attrs.duration)
            )

    def process_voice_segments(
        self,
        src_segments: List[SSMLSegment],
        bucket: List[SSMLSegment],
        audio_segments: List[AudioSegment],
    ):
        """Synthesize ``bucket`` in batches and store the results in ``audio_segments``.

        Args:
            src_segments: the full, ordered segment list used to locate each segment.
            bucket: voice segments sharing one set of non-text parameters.
            audio_segments: output list aligned with ``src_segments``; filled in place.
        """
        for i in range(0, len(bucket), self.batch_size):
            batch = bucket[i : i + self.batch_size]
            # NOTE: params are recomputed here (they were also computed in
            # bucket_segments), so text normalization runs twice per segment.
            param_arr = [self.segment_to_generate_params(segment) for segment in batch]
            texts = [params.text for params in param_arr]

            # Buckets built by bucket_segments share every non-text parameter,
            # so the first entry's settings apply to the whole batch.
            params = param_arr[0]
            audio_datas = generate_audio.generate_audio_batch(
                texts=texts,
                temperature=params.temperature,
                top_P=params.top_P,
                top_K=params.top_K,
                spk=params.spk,
                infer_seed=params.infer_seed,
                prompt1=params.prompt1,
                prompt2=params.prompt2,
                prefix=params.prefix,
            )
            for idx, segment in enumerate(batch):
                sr, audio_data = audio_datas[idx]
                # NOTE(review): prosody is read from the segment itself rather
                # than from segment.attrs — confirm this is where the SSML
                # parser stores rate/volume/pitch.
                rate = float(segment.get("rate", "1.0"))
                volume = float(segment.get("volume", "0"))
                pitch = float(segment.get("pitch", "0"))

                audio_segment = audio_data_to_segment(audio_data, sr)
                audio_segment = apply_prosody(audio_segment, rate, volume, pitch)
                original_index = self._find_slot_index(
                    src_segments, segment, audio_segments
                )
                audio_segments[original_index] = audio_segment

    def bucket_segments(
        self, segments: List[Union[SSMLSegment, SSMLBreak]]
    ) -> dict[str, List[Union[SSMLSegment, SSMLBreak]]]:
        """Group segments by their non-text generation parameters.

        Returns:
            A dict mapping a JSON key (the sorted non-text params) to the
            segments that share them, in input order. All breaks are collected
            under the ``"<break>"`` key, which is always present.
        """
        buckets = {"<break>": []}
        for segment in segments:
            if isinstance(segment, SSMLBreak):
                buckets["<break>"].append(segment)
                continue

            params = self.segment_to_generate_params(segment)

            # Speaker objects are replaced by their id so the key is JSON-serializable.
            if isinstance(params.spk, Speaker):
                params.spk = str(params.spk.id)

            key = json.dumps(
                {k: v for k, v in params.items() if k != "text"}, sort_keys=True
            )
            if key not in buckets:
                buckets[key] = []
            buckets[key].append(segment)

        return buckets

    def synthesize_segments(
        self, segments: List[Union[SSMLSegment, SSMLBreak]]
    ) -> List[AudioSegment]:
        """Render every segment and return the audio in input order.

        Breaks are rendered first, then each voice bucket.
        """
        audio_segments = [None] * len(segments)
        buckets = self.bucket_segments(segments)

        break_segments = buckets.pop("<break>")
        self.process_break_segments(segments, break_segments, audio_segments)

        buckets = list(buckets.values())

        for bucket in buckets:
            self.process_voice_segments(segments, bucket, audio_segments)

        return audio_segments
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Demo: three lines for two speakers that share a seed and temperature,
    # with a break (duration_ms=1000) after the first line. The combined
    # result is written to output.wav.

    def _demo_context(spk):
        """Build an SSMLContext for speaker ``spk`` with the demo seed/temperature."""
        ctx = SSMLContext()
        ctx.spk = spk
        ctx.seed = 42
        ctx.temp = 0.1
        return ctx

    ctx1 = _demo_context(1)
    ctx2 = _demo_context(2)

    ssml_segments = [
        SSMLSegment(text="大🍌,一条大🍌,嘿,你的感觉真的很奇妙", attrs=ctx1.copy()),
        SSMLBreak(duration_ms=1000),
        SSMLSegment(text="大🍉,一个大🍉,嘿,你的感觉真的很奇妙", attrs=ctx1.copy()),
        SSMLSegment(text="大🍊,一个大🍊,嘿,你的感觉真的很奇妙", attrs=ctx2.copy()),
    ]

    synthesizer = SynthesizeSegments(batch_size=2)
    audio_segments = synthesizer.synthesize_segments(ssml_segments)
    print(audio_segments)

    combine_audio_segments(audio_segments).export("output.wav", format="wav")
|
|