import copy
import io
import json
import logging
import re
from typing import Dict, List, Union

import numpy as np
from box import Box
from pydub import AudioSegment
from scipy.io import wavfile

from modules import generate_audio
from modules.api.utils import calc_spk_style
from modules.normalization import text_normalize
from modules.SentenceSplitter import SentenceSplitter
from modules.speaker import Speaker
from modules.ssml_parser.SSMLParser import SSMLBreak, SSMLContext, SSMLSegment
from modules.utils import rng
from modules.utils.audio import apply_prosody_to_audio_segment

logger = logging.getLogger(__name__)


def audio_data_to_segment_slow(audio_data, sr):
    """Convert a waveform to an AudioSegment via an in-memory WAV round-trip."""
    byte_io = io.BytesIO()
    wavfile.write(byte_io, rate=sr, data=audio_data)
    byte_io.seek(0)
    return AudioSegment.from_file(byte_io, format="wav")


def clip_audio(audio_data: np.ndarray, threshold: float = 0.99):
    """Clip samples to [-threshold, threshold] to avoid overflow on int16 conversion."""
    audio_data = np.clip(audio_data, -threshold, threshold)
    return audio_data


def normalize_audio(audio_data: np.ndarray, norm_factor: float = 0.8):
    """Scale the waveform so its peak amplitude equals norm_factor."""
    max_amplitude = np.max(np.abs(audio_data))
    if max_amplitude > 0:
        audio_data = audio_data / max_amplitude * norm_factor
    return audio_data


def audio_data_to_segment(audio_data: np.ndarray, sr: int):
    """
    Build an AudioSegment directly from raw int16 PCM bytes, skipping the WAV
    round-trip used by audio_data_to_segment_slow.

    optimize: https://github.com/lenML/ChatTTS-Forge/issues/57
    """
    audio_data = normalize_audio(audio_data)
    audio_data = clip_audio(audio_data)
    audio_data = (audio_data * 32767).astype(np.int16)
    audio_segment = AudioSegment(
        audio_data.tobytes(),
        frame_rate=sr,
        sample_width=audio_data.dtype.itemsize,
        channels=1,
    )
    return audio_segment


def combine_audio_segments(audio_segments: list[AudioSegment]) -> AudioSegment:
    combined_audio = AudioSegment.empty()
    for segment in audio_segments:
        combined_audio += segment
    return combined_audio
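
# Hypothetical variant, not part of the original module: pydub can also join
# clips with a short crossfade to soften seams between batched segments. The
# function name and the `crossfade_ms` parameter are assumptions for
# illustration; pydub requires the crossfade to be no longer than either
# neighboring clip, so it is clamped below.
def combine_audio_segments_crossfade(
    audio_segments: list[AudioSegment], crossfade_ms: int = 20
) -> AudioSegment:
    combined_audio = AudioSegment.empty()
    for i, segment in enumerate(audio_segments):
        if i == 0:
            combined_audio = segment
            continue
        fade = min(crossfade_ms, len(combined_audio), len(segment))
        combined_audio = combined_audio.append(segment, crossfade=fade)
    return combined_audio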
def to_number(value, t, default=0):
    """Coerce value with converter t, falling back to default on failure."""
    try:
        number = t(value)
        return number
    except (ValueError, TypeError):
        return default


class TTSAudioSegment(Box):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._type = kwargs.get("_type", "voice")
        self.text = kwargs.get("text", "")
        self.temperature = kwargs.get("temperature", 0.3)
        self.top_P = kwargs.get("top_P", 0.5)
        self.top_K = kwargs.get("top_K", 20)
        self.spk = kwargs.get("spk", -1)
        self.infer_seed = kwargs.get("infer_seed", -1)
        self.prompt1 = kwargs.get("prompt1", "")
        self.prompt2 = kwargs.get("prompt2", "")
        self.prefix = kwargs.get("prefix", "")


class SynthesizeSegments:
    def __init__(self, batch_size: int = 8, eos="", spliter_thr=100):
        self.batch_size = batch_size
        self.batch_default_spk_seed = rng.np_rng()
        self.batch_default_infer_seed = rng.np_rng()
        self.eos = eos
        self.spliter_thr = spliter_thr

    def segment_to_generate_params(
        self, segment: Union[SSMLSegment, SSMLBreak]
    ) -> TTSAudioSegment:
        if isinstance(segment, SSMLBreak):
            return TTSAudioSegment(_type="break")

        if segment.get("params", None) is not None:
            params = segment.get("params")
            text = segment.get("text", None) or segment.text or ""
            return TTSAudioSegment(**params, text=text)

        text = segment.get("text", None) or segment.text or ""
        is_end = segment.get("is_end", False)

        text = str(text).strip()

        attrs = segment.attrs
        spk = attrs.spk
        style = attrs.style

        ss_params = calc_spk_style(spk, style)

        if "spk" in ss_params:
            spk = ss_params["spk"]

        seed = to_number(attrs.seed, int, ss_params.get("seed") or -1)
        top_k = to_number(attrs.top_k, int, None)
        top_p = to_number(attrs.top_p, float, None)
        temp = to_number(attrs.temp, float, None)

        prompt1 = attrs.prompt1 or ss_params.get("prompt1")
        prompt2 = attrs.prompt2 or ss_params.get("prompt2")
        prefix = attrs.prefix or ss_params.get("prefix")
        disable_normalize = attrs.get("normalize", "") == "False"

        seg = TTSAudioSegment(
            _type="voice",
            text=text,
            temperature=temp if temp is not None else 0.3,
            top_P=top_p if top_p is not None else 0.5,
            top_K=top_k if top_k is not None else 20,
            spk=spk if spk else -1,
            infer_seed=seed if seed else -1,
            prompt1=prompt1 if prompt1 else "",
            prompt2=prompt2 if prompt2 else "",
            prefix=prefix if prefix else "",
        )

        if not disable_normalize:
            seg.text = text_normalize(text, is_end=is_end)

        # NOTE: fall back to the per-batch default seeds so results stay
        # consistent across segments even when no spk/seed was set
        if seg.spk == -1:
            seg.spk = self.batch_default_spk_seed
        if seg.infer_seed == -1:
            seg.infer_seed = self.batch_default_infer_seed

        return seg

    def process_break_segments(
        self,
        src_segments: List[Union[SSMLSegment, SSMLBreak]],
        bucket_segments: List[SSMLBreak],
        audio_segments: List[AudioSegment],
    ):
        for segment in bucket_segments:
            index = src_segments.index(segment)
            audio_segments[index] = AudioSegment.silent(
                duration=int(segment.attrs.duration)
            )

    def process_voice_segments(
        self,
        src_segments: List[Union[SSMLSegment, SSMLBreak]],
        bucket: List[SSMLSegment],
        audio_segments: List[AudioSegment],
    ):
        for i in range(0, len(bucket), self.batch_size):
            batch = bucket[i : i + self.batch_size]
            param_arr = [self.segment_to_generate_params(segment) for segment in batch]

            def append_eos(text: str):
                text = text.strip()
                eos_arr = ["[uv_break]", "[v_break]", "[lbreak]", "[llbreak]"]
                has_eos = False
                for eos in eos_arr:
                    if eos in text:
                        has_eos = True
                        break
                if not has_eos:
                    text += self.eos
                return text

            # append the end_of_text marker unless the text already carries one
            texts = [append_eos(params.text) for params in param_arr]

            # all segments in a bucket share generation params, so the first
            # entry is representative for the whole batch
            params = param_arr[0]
            audio_datas = generate_audio.generate_audio_batch(
                texts=texts,
                temperature=params.temperature,
                top_P=params.top_P,
                top_K=params.top_K,
                spk=params.spk,
                infer_seed=params.infer_seed,
                prompt1=params.prompt1,
                prompt2=params.prompt2,
                prefix=params.prefix,
            )
            for idx, segment in enumerate(batch):
                sr, audio_data = audio_datas[idx]
                rate = float(segment.get("rate", "1.0"))
                volume = float(segment.get("volume", "0"))
                pitch = float(segment.get("pitch", "0"))

                audio_segment = audio_data_to_segment(audio_data, sr)
                audio_segment = apply_prosody_to_audio_segment(
                    audio_segment, rate=rate, volume=volume, pitch=pitch
                )

                # compare by Box object
                original_index = src_segments.index(segment)
                audio_segments[original_index] = audio_segment

    def bucket_segments(
        self, segments: List[Union[SSMLSegment, SSMLBreak]]
    ) -> Dict[str, List[Union[SSMLSegment, SSMLBreak]]]:
        """Group segments by generation params; breaks go to the "" bucket."""
        buckets = {"": []}
        for segment in segments:
            if isinstance(segment, SSMLBreak):
                buckets[""].append(segment)
                continue

            params = self.segment_to_generate_params(segment)
            if isinstance(params.spk, Speaker):
                params.spk = str(params.spk.id)
            key = json.dumps(
                {k: v for k, v in params.items() if k != "text"}, sort_keys=True
            )
            if key not in buckets:
                buckets[key] = []
            buckets[key].append(segment)

        return buckets

    def split_segments(self, segments: List[Union[SSMLSegment, SSMLBreak]]):
        """
        Run each segment's text through the sentence spliter, producing
        multiple smaller segments.
        """
        spliter = SentenceSplitter(threshold=self.spliter_thr)
        ret_segments: List[Union[SSMLSegment, SSMLBreak]] = []

        for segment in segments:
            if isinstance(segment, SSMLBreak):
                ret_segments.append(segment)
                continue

            text = segment.text
            if not text:
                continue

            sentences = spliter.parse(text)
            for sentence in sentences:
                seg = SSMLSegment(
                    text=sentence,
                    attrs=segment.attrs.copy(),
                    params=copy.copy(segment.params),
                )
                ret_segments.append(seg)
                setattr(seg, "_idx", len(ret_segments) - 1)

        def is_none_speak_segment(segment: SSMLSegment):
            # a segment is non-speech if nothing remains after stripping
            # control tokens such as [uv_break]
            text = segment.text.strip()
            regexp = r"\[[^\]]+?\]"
            text = re.sub(regexp, "", text)
            text = text.strip()
            if not text:
                return True
            return False

        # merge non-speech segments into the preceding speak segment;
        # skip breaks on either side, since SSMLBreak carries no text
        for i in range(1, len(ret_segments)):
            seg = ret_segments[i]
            prev = ret_segments[i - 1]
            if isinstance(seg, SSMLBreak) or isinstance(prev, SSMLBreak):
                continue
            if is_none_speak_segment(seg):
                prev.text += seg.text
                seg.text = ""

        # remove voice segments that ended up empty (keep breaks)
        ret_segments = [
            seg
            for seg in ret_segments
            if isinstance(seg, SSMLBreak) or seg.text.strip()
        ]

        return ret_segments

    def synthesize_segments(
        self, segments: List[Union[SSMLSegment, SSMLBreak]]
    ) -> List[AudioSegment]:
        segments = self.split_segments(segments)
        audio_segments = [None] * len(segments)
        buckets = self.bucket_segments(segments)

        # breaks live in the "" bucket and are rendered as silence
        break_segments = buckets.pop("")
        self.process_break_segments(segments, break_segments, audio_segments)

        for bucket in buckets.values():
            self.process_voice_segments(segments, bucket, audio_segments)

        return audio_segments
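
# Illustrative self-check, an addition rather than part of the original module:
# feeds a synthetic tone through audio_data_to_segment to show the expected
# float-in / int16-out contract. The 24 kHz rate and 440 Hz tone are arbitrary
# test values, not defaults used anywhere in this file.
def _demo_audio_data_to_segment():
    sr = 24000
    t = np.linspace(0.0, 0.5, sr // 2, endpoint=False)
    audio_data = 0.5 * np.sin(2 * np.pi * 440.0 * t)
    seg = audio_data_to_segment(audio_data, sr)
    # pydub reports duration in milliseconds: 12000 frames / 24000 Hz = 500 ms
    assert len(seg) == 500
    assert seg.sample_width == 2 and seg.channels == 1
    return seg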
r"\[[^\]]+?\]" text = re.sub(regexp, "", text) text = text.strip() if not text: return True return False # 将 none_speak 合并到前一个 speak segment for i in range(1, len(ret_segments)): if is_none_speak_segment(ret_segments[i]): ret_segments[i - 1].text += ret_segments[i].text ret_segments[i].text = "" # 移除空的 segment ret_segments = [seg for seg in ret_segments if seg.text.strip()] return ret_segments def synthesize_segments( self, segments: List[Union[SSMLSegment, SSMLBreak]] ) -> List[AudioSegment]: segments = self.split_segments(segments) audio_segments = [None] * len(segments) buckets = self.bucket_segments(segments) break_segments = buckets.pop("") self.process_break_segments(segments, break_segments, audio_segments) buckets = list(buckets.values()) for bucket in buckets: self.process_voice_segments(segments, bucket, audio_segments) return audio_segments # 示例使用 if __name__ == "__main__": ctx1 = SSMLContext() ctx1.spk = 1 ctx1.seed = 42 ctx1.temp = 0.1 ctx2 = SSMLContext() ctx2.spk = 2 ctx2.seed = 42 ctx2.temp = 0.1 ssml_segments = [ SSMLSegment(text="大🍌,一条大🍌,嘿,你的感觉真的很奇妙", attrs=ctx1.copy()), SSMLBreak(duration_ms=1000), SSMLSegment(text="大🍉,一个大🍉,嘿,你的感觉真的很奇妙", attrs=ctx1.copy()), SSMLSegment(text="大🍊,一个大🍊,嘿,你的感觉真的很奇妙", attrs=ctx2.copy()), ] synthesizer = SynthesizeSegments(batch_size=2) audio_segments = synthesizer.synthesize_segments(ssml_segments) print(audio_segments) combined_audio = combine_audio_segments(audio_segments) combined_audio.export("output.wav", format="wav")