Spaces:
Running
on
Zero
Running
on
Zero
| from box import Box | |
| from pydub import AudioSegment | |
| from typing import List, Union | |
| from scipy.io.wavfile import write | |
| import io | |
| from modules.api.utils import calc_spk_style | |
| from modules.ssml_parser.SSMLParser import SSMLSegment, SSMLBreak, SSMLContext | |
| from modules.utils import rng | |
| from modules.utils.audio import time_stretch, pitch_shift | |
| from modules import generate_audio | |
| from modules.normalization import text_normalize | |
| import logging | |
| import json | |
| from modules.speaker import Speaker, speaker_mgr | |
| logger = logging.getLogger(__name__) | |
def audio_data_to_segment(audio_data, sr):
    """Wrap raw sample data in a pydub ``AudioSegment``.

    The samples are serialized as WAV into an in-memory buffer with
    ``scipy.io.wavfile.write`` and the buffer is then decoded by pydub.

    Args:
        audio_data: Sample array, forwarded as ``data`` to
            ``scipy.io.wavfile.write``.
        sr: Sample rate, forwarded as ``rate`` to ``scipy.io.wavfile.write``.

    Returns:
        The ``AudioSegment`` decoded from the in-memory WAV data.
    """
    wav_buffer = io.BytesIO()
    write(wav_buffer, rate=sr, data=audio_data)
    # Rewind so pydub reads the WAV from its header rather than from EOF.
    wav_buffer.seek(0)
    return AudioSegment.from_file(wav_buffer, format="wav")
def combine_audio_segments(audio_segments: list[AudioSegment]) -> AudioSegment:
    """Concatenate audio segments in list order.

    Args:
        audio_segments: Segments to join, first to last.

    Returns:
        A single segment made by adding each item, in order, onto an empty
        ``AudioSegment``; the empty segment itself when the list is empty.
    """
    # ``sum`` folds with ``+`` starting from the empty segment, which is the
    # same left-to-right accumulation as an explicit loop.
    return sum(audio_segments, AudioSegment.empty())
def apply_prosody(
    audio_segment: AudioSegment, rate: float, volume: float, pitch: float
) -> AudioSegment:
    """Apply rate, volume and pitch adjustments to a segment, in that order.

    Each adjustment is skipped when its argument equals the neutral value
    (``1`` for ``rate``, ``0`` for ``volume`` and ``pitch``).

    Args:
        audio_segment: Segment to adjust.
        rate: Factor handed to ``time_stretch``.
        volume: Amount added to the segment with ``+``.
        pitch: Amount handed to ``pitch_shift``.

    Returns:
        The adjusted segment, or the input segment when nothing applied.
    """
    adjusted = audio_segment

    change_rate = rate != 1
    if change_rate:
        adjusted = time_stretch(adjusted, rate)

    change_volume = volume != 0
    if change_volume:
        adjusted += volume

    change_pitch = pitch != 0
    if change_pitch:
        adjusted = pitch_shift(adjusted, pitch)

    return adjusted
def to_number(value, t, default=0):
    """Convert ``value`` with the callable ``t``, or return ``default``.

    Args:
        value: Object to convert (for example a string attribute or ``None``).
        t: Conversion callable such as ``int`` or ``float``.
        default: Value returned when the conversion fails.

    Returns:
        ``t(value)`` on success, otherwise ``default``.
    """
    try:
        return t(value)
    # OverflowError covers conversions such as int(float("inf")), which
    # would otherwise escape this "convert or fall back" helper.
    except (ValueError, TypeError, OverflowError):
        return default
class TTSAudioSegment(Box):
    """Box holding the generation parameters for one segment.

    The annotations below only declare the expected keys; no class-level
    values are assigned here, so instances contain whatever is passed to
    the constructor.
    """

    # Text to synthesize.
    text: str

    # Sampling parameters.
    temperature: float
    top_P: float
    top_K: int

    # Speaker and inference seed.
    spk: int
    infer_seed: int

    # Prompt / prefix strings.
    prompt1: str
    prompt2: str
    prefix: str

    # Segment kind marker ("voice" or "break" in this module).
    _type: str

    def __init__(self, *args, **kwargs):
        # Plain pass-through to Box; kept as an explicit constructor.
        super().__init__(*args, **kwargs)
class SynthesizeSegments:
    """Synthesize a list of SSML segments into per-segment audio.

    Voice segments that share identical generation parameters are grouped
    into buckets and generated in batches of ``batch_size``; break segments
    become silence. Results are written back at each segment's position in
    the original list.
    """

    def __init__(self, batch_size: int = 8):
        """Store the batch size and draw the instance-wide default seeds.

        Args:
            batch_size: Maximum number of texts per generation call.
        """
        self.batch_size = batch_size
        # Drawn once per instance so every segment without an explicit
        # speaker / inference seed falls back to the same value.
        self.batch_default_spk_seed = rng.np_rng()
        self.batch_default_infer_seed = rng.np_rng()

    @staticmethod
    def _source_positions(src_segments, bucket) -> List[int]:
        """Return, for each bucket item, its position in ``src_segments``.

        Items are matched by object identity first. ``list.index`` compares
        with ``==``, so segments that compare equal (for example the same
        text repeated with identical attributes) would all resolve to the
        first occurrence and leave later output slots unfilled. When the
        same object occurs several times, positions are handed out in
        ascending order. Objects not found by identity fall back to
        ``src_segments.index`` (which raises ``ValueError`` when absent).
        """
        unclaimed = {}
        # Walk backwards so that ``pop()`` yields the smallest position first.
        for position, candidate in reversed(list(enumerate(src_segments))):
            unclaimed.setdefault(id(candidate), []).append(position)

        positions = []
        for segment in bucket:
            remaining = unclaimed.get(id(segment))
            if remaining:
                positions.append(remaining.pop())
            else:
                positions.append(src_segments.index(segment))
        return positions

    def segment_to_generate_params(
        self, segment: Union[SSMLSegment, SSMLBreak]
    ) -> TTSAudioSegment:
        """Build the generation parameters for one segment.

        Args:
            segment: A voice segment or a break.

        Returns:
            ``TTSAudioSegment(_type="break")`` for a break; the segment's
            own ``params`` (if present) wrapped in ``TTSAudioSegment``;
            otherwise parameters derived from the segment's text and attrs.
        """
        if isinstance(segment, SSMLBreak):
            return TTSAudioSegment(_type="break")

        # Pre-computed params on the segment take precedence over attrs.
        if segment.get("params", None) is not None:
            return TTSAudioSegment(**segment.get("params"))

        text = segment.get("text", "")
        is_end = segment.get("is_end", False)

        text = str(text).strip()

        attrs = segment.attrs
        spk = attrs.spk
        style = attrs.style

        # Speaker/style lookup may override the speaker and supply defaults
        # for seed, prompts and prefix.
        ss_params = calc_spk_style(spk, style)

        if "spk" in ss_params:
            spk = ss_params["spk"]

        seed = to_number(attrs.seed, int, ss_params.get("seed") or -1)
        top_k = to_number(attrs.top_k, int, None)
        top_p = to_number(attrs.top_p, float, None)
        temp = to_number(attrs.temp, float, None)

        prompt1 = attrs.prompt1 or ss_params.get("prompt1")
        prompt2 = attrs.prompt2 or ss_params.get("prompt2")
        prefix = attrs.prefix or ss_params.get("prefix")
        # Normalization is only skipped for the literal string "False".
        disable_normalize = attrs.get("normalize", "") == "False"

        seg = TTSAudioSegment(
            _type="voice",
            text=text,
            temperature=temp if temp is not None else 0.3,
            top_P=top_p if top_p is not None else 0.5,
            top_K=top_k if top_k is not None else 20,
            # NOTE: falsy values (including 0) are treated as "not set" for
            # both spk and infer_seed and are replaced by the defaults below.
            spk=spk if spk else -1,
            infer_seed=seed if seed else -1,
            prompt1=prompt1 if prompt1 else "",
            prompt2=prompt2 if prompt2 else "",
            prefix=prefix if prefix else "",
        )

        if not disable_normalize:
            seg.text = text_normalize(text, is_end=is_end)

        # NOTE: the per-batch default seeds keep results consistent across
        # segments even when no spk was set.
        if seg.spk == -1:
            seg.spk = self.batch_default_spk_seed
        if seg.infer_seed == -1:
            seg.infer_seed = self.batch_default_infer_seed

        return seg

    def process_break_segments(
        self,
        src_segments: List[SSMLBreak],
        bucket_segments: List[SSMLBreak],
        audio_segments: List[AudioSegment],
    ):
        """Fill the output slots of break segments with silence.

        Args:
            src_segments: The full, ordered segment list.
            bucket_segments: The break segments to process.
            audio_segments: Output list, mutated in place; indexed like
                ``src_segments``.
        """
        positions = self._source_positions(src_segments, bucket_segments)
        for position, segment in zip(positions, bucket_segments):
            audio_segments[position] = AudioSegment.silent(
                duration=int(segment.attrs.duration)
            )

    def process_voice_segments(
        self,
        src_segments: List[SSMLSegment],
        bucket: List[SSMLSegment],
        audio_segments: List[AudioSegment],
    ):
        """Generate audio for one bucket of voice segments, batch by batch.

        Every batch is generated with the parameters of its first segment;
        prosody (rate / volume / pitch) is then applied per segment.

        Args:
            src_segments: The full, ordered segment list.
            bucket: Voice segments to generate together.
            audio_segments: Output list, mutated in place; indexed like
                ``src_segments``.
        """
        positions = self._source_positions(src_segments, bucket)

        for start in range(0, len(bucket), self.batch_size):
            stop = start + self.batch_size
            batch = bucket[start:stop]
            batch_positions = positions[start:stop]

            param_arr = [self.segment_to_generate_params(segment) for segment in batch]
            texts = [params.text for params in param_arr]

            # Only the texts vary within a batch; the remaining parameters
            # are taken from the first segment.
            params = param_arr[0]
            audio_datas = generate_audio.generate_audio_batch(
                texts=texts,
                temperature=params.temperature,
                top_P=params.top_P,
                top_K=params.top_K,
                spk=params.spk,
                infer_seed=params.infer_seed,
                prompt1=params.prompt1,
                prompt2=params.prompt2,
                prefix=params.prefix,
            )

            for idx, segment in enumerate(batch):
                sr, audio_data = audio_datas[idx]

                rate = float(segment.get("rate", "1.0"))
                volume = float(segment.get("volume", "0"))
                pitch = float(segment.get("pitch", "0"))

                audio_segment = audio_data_to_segment(audio_data, sr)
                audio_segment = apply_prosody(audio_segment, rate, volume, pitch)
                audio_segments[batch_positions[idx]] = audio_segment

    def bucket_segments(
        self, segments: List[Union[SSMLSegment, SSMLBreak]]
    ) -> dict[str, List[Union[SSMLSegment, SSMLBreak]]]:
        """Group segments by their generation parameters.

        Args:
            segments: The full, ordered segment list.

        Returns:
            A dict mapping a bucket key to its segments. Breaks are stored
            under ``"<break>"`` (always present); voice segments are keyed
            by the JSON dump of their parameters excluding ``text``.
        """
        buckets = {"<break>": []}
        for segment in segments:
            if isinstance(segment, SSMLBreak):
                buckets["<break>"].append(segment)
                continue

            params = self.segment_to_generate_params(segment)

            # Use the speaker's id in the key instead of the Speaker object
            # before the parameters are dumped to JSON.
            if isinstance(params.spk, Speaker):
                params.spk = str(params.spk.id)

            key = json.dumps(
                {k: v for k, v in params.items() if k != "text"}, sort_keys=True
            )
            buckets.setdefault(key, []).append(segment)

        return buckets

    def synthesize_segments(
        self, segments: List[Union[SSMLSegment, SSMLBreak]]
    ) -> List[AudioSegment]:
        """Synthesize every segment and return audio in the input order.

        Args:
            segments: Voice and break segments, in playback order.

        Returns:
            A list with one entry per input segment, at the same index.
        """
        audio_segments = [None] * len(segments)
        buckets = self.bucket_segments(segments)

        break_segments = buckets.pop("<break>")
        self.process_break_segments(segments, break_segments, audio_segments)

        for bucket in buckets.values():
            self.process_voice_segments(segments, bucket, audio_segments)

        return audio_segments
# Example usage
if __name__ == "__main__":
    # Two contexts that differ only in the speaker.
    ctx1, ctx2 = SSMLContext(), SSMLContext()
    for ctx, speaker in ((ctx1, 1), (ctx2, 2)):
        ctx.spk = speaker
        ctx.seed = 42
        ctx.temp = 0.1

    ssml_segments = [
        SSMLSegment(text="大🍌,一条大🍌,嘿,你的感觉真的很奇妙", attrs=ctx1.copy()),
        SSMLBreak(duration_ms=1000),
        SSMLSegment(text="大🍉,一个大🍉,嘿,你的感觉真的很奇妙", attrs=ctx1.copy()),
        SSMLSegment(text="大🍊,一个大🍊,嘿,你的感觉真的很奇妙", attrs=ctx2.copy()),
    ]

    synthesizer = SynthesizeSegments(batch_size=2)
    audio_segments = synthesizer.synthesize_segments(ssml_segments)
    print(audio_segments)

    # Join the per-segment audio and write it out as a WAV file.
    combined_audio = combine_audio_segments(audio_segments)
    combined_audio.export("output.wav", format="wav")