diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..55ccf26eb8256582c42cee1d0a4e6f7c852f34a0 --- /dev/null +++ b/app.py @@ -0,0 +1,567 @@ +import re +import gradio as gr +from tqdm import tqdm +from argparse import ArgumentParser +from typing import Literal, List, Tuple +import sys +import importlib.util +from datetime import datetime +import spaces +import torch +import numpy as np # 确保导入 numpy +import random # 确保导入 random +import s3tokenizer + +from soulxpodcast.models.soulxpodcast import SoulXPodcast +from soulxpodcast.config import Config, SoulXPodcastLLMConfig, SamplingParams +from soulxpodcast.utils.dataloader import ( + PodcastInferHandler, + SPK_DICT, TEXT_START, TEXT_END, AUDIO_START, TASK_PODCAST +) + +# ================================================ +# 示例音频路径 +# ================================================ +S1_PROMPT_WAV = "assets/audios/female_mandarin.wav" # 示例路径 +S2_PROMPT_WAV = "assets/audios/male_mandarin.wav" # 示例路径 + + +# ================================================ +# 示例数据 (gr.Examples) +# ================================================ +EXAMPLES_LIST = [ + # 示例 1:清空所有 + [ + None, "", "", None, "", "", "" + ], + # 示例 2:普通播客 + [ + S1_PROMPT_WAV, + "喜欢攀岩、徒步、滑雪的语言爱好者,以及过两天要带着全部家当去景德镇做陶瓷的白日梦想家。", + "", + S2_PROMPT_WAV, + "呃,还有一个就是要跟大家纠正一点,就是我们在看电影的时候,尤其是游戏玩家,看电影的时候,在看到那个到西北那边的这个陕北民谣,嗯,这个可能在想,哎,是不是他是受到了黑神话的启发?", + "", + "[S1] 哈喽,AI时代的冲浪先锋们!欢迎收听《AI生活进行时》。啊,一个充满了未来感,然后,还有一点点,<|laughter|>神经质的播客节目,我是主持人小希。\n[S2] 哎,大家好呀!我是能唠,爱唠,天天都想唠的唠嗑!\n[S1] 最近活得特别赛博朋克哈!以前老是觉得AI是科幻片儿里的,<|sigh|> 现在,现在连我妈都用AI写广场舞文案了。\n[S2] 这个例子很生动啊。是的,特别是生成式AI哈,感觉都要炸了! 诶,那我们今天就聊聊AI是怎么走进我们的生活的哈!", + ], + # 示例 3:四川播客 + [ + S1_PROMPT_WAV, + "喜欢攀岩、徒步、滑雪的语言爱好者,以及过两天要带着全部家当去景德镇做陶瓷的白日梦想家。", + "<|Sichuan|>要得要得!前头几个耍洋盘,我后脚就背起铺盖卷去景德镇耍泥巴,巴适得喊老天爷!", + S2_PROMPT_WAV, + "呃,还有一个就是要跟大家纠正一点,就是我们在看电影的时候,尤其是游戏玩家,看电影的时候,在看到那个到西北那边的这个陕北民谣,嗯,这个可能在想,哎,是不是他是受到了黑神话的启发?", + "<|Sichuan|>哎哟喂,这个搞反了噻!黑神话里头唱曲子的王二浪早八百年就在黄土高坡吼秦腔喽,游戏组专门跑切录的原汤原水,听得人汗毛儿都立起来!", + "[S1] <|Sichuan|>各位《巴适得板》的听众些,大家好噻!我是你们主持人晶晶。今儿天气硬是巴适,不晓得大家是在赶路嘛,还是茶都泡起咯,准备跟我们好生摆一哈龙门阵喃?\n[S2] <|Sichuan|>晶晶好哦,大家安逸噻!我是李老倌。你刚开口就川味十足,摆龙门阵几个字一甩出来,我鼻子头都闻到茶香跟火锅香咯!\n[S1] <|Sichuan|>就是得嘛!李老倌,我前些天带个外地朋友切人民公园鹤鸣茶社坐了一哈。他硬是搞不醒豁,为啥子我们一堆人围到杯茶就可以吹一下午壳子,从隔壁子王嬢嬢娃儿耍朋友,扯到美国大选,中间还掺几盘斗地主。他说我们四川人简直是把摸鱼刻进骨子里头咯!\n[S2] <|Sichuan|>你那个朋友说得倒是有点儿趣,但他莫看到精髓噻。摆龙门阵哪是摸鱼嘛,这是我们川渝人特有的交际方式,更是一种活法。外省人天天说的松弛感,根根儿就在这龙门阵里头。今天我们就要好生摆一哈,为啥子四川人活得这么舒坦。就先从茶馆这个老窝子说起,看它咋个成了我们四川人的魂儿!", + ], + # 示例 4:粤语播客 + [ + S1_PROMPT_WAV, + "喜欢攀岩、徒步、滑雪的语言爱好者,以及过两天要带着全部家当去景德镇做陶瓷的白日梦想家。", + "<|Yue|>真係冇讲错啊!攀山滑雪嘅语言专家几巴闭,都唔及我听日拖成副身家去景德镇玩泥巴,呢铺真系发哂白日梦咯!", + S2_PROMPT_WAV, + "呃,还有一个就是要跟大家纠正一点,就是我们在看电影的时候,尤其是游戏玩家,看电影的时候,在看到那个到西北那边的这个陕北民谣,嗯,这个可能在想,哎,是不是他是受到了黑神话的启发?", + "<|Yue|>咪搞错啊!陕北民谣响度唱咗几十年,黑神话边有咁大面啊?你估佢哋抄游戏咩!", + "[S1] <|Yue|>哈囉大家好啊,歡迎收聽我哋嘅節目。喂,我今日想問你樣嘢啊,你覺唔覺得,嗯,而家揸電動車,最煩,最煩嘅一樣嘢係咩啊?\n[S2] <|Yue|>梗係充電啦。大佬啊,搵個位都已經好煩,搵到個位仲要喺度等,你話快極都要半個鐘一個鐘,真係,有時諗起都覺得好冇癮。\n[S1] <|Yue|>係咪先。如果我而家同你講,充電可以快到同入油差唔多時間,你信唔信先?喂你平時喺油站入滿一缸油,要幾耐啊?五六分鐘?\n[S2] <|Yue|>差唔多啦,七八分鐘,點都走得啦。電車喎,可以做到咁快?你咪玩啦。", + ], + # 示例 5:河南播客 + [ + S1_PROMPT_WAV, + "喜欢攀岩、徒步、滑雪的语言爱好者,以及过两天要带着全部家当去景德镇做陶瓷的白日梦想家。", + "<|Henan|>俺这不是怕恁路上不得劲儿嘛!那景德镇瓷泥可娇贵着哩,得先拿咱河南人这实诚劲儿给它揉透喽。", + S2_PROMPT_WAV, + "呃,还有一个就是要跟大家纠正一点,就是我们在看电影的时候,尤其是游戏玩家,看电影的时候,在看到那个到西北那边的这个陕北民谣,嗯,这个可能在想,哎,是不是他是受到了黑神话的启发?", + "<|Henan|>恁这想法真闹挺!陕北民谣比黑神话早几百年都有了,咱可不兴这弄颠倒啊,中不?恁这想法真闹挺!那陕北民谣在黄土高坡响了几百年,咋能说是跟黑神话学的咧?咱得把这事儿捋直喽,中不中!", + "[S1] <|Henan|>哎,大家好啊,欢迎收听咱这一期嘞《瞎聊呗,就这么说》,我是恁嘞老朋友,燕子。\n[S2] <|Henan|>大家好,我是老张。燕子啊,今儿瞅瞅你这个劲儿,咋着,是有啥可得劲嘞事儿想跟咱唠唠?\n[S1] 
<|Henan|>哎哟,老张,你咋恁懂我嘞!我跟你说啊,最近我刷手机,老是刷住些可逗嘞方言视频,特别是咱河南话,咦~我哩个乖乖,一听我都憋不住笑,咋说嘞,得劲儿哩很,跟回到家一样。\n[S2] <|Henan|>你这回可算说到根儿上了!河南话,咱往大处说说,中原官话,它真嘞是有一股劲儿搁里头。它可不光是说话,它脊梁骨后头藏嘞,是咱一整套、鲜鲜活活嘞过法儿,一种活人嘞道理。\n[S1] <|Henan|>活人嘞道理?哎,这你这一说,我嘞兴致“腾”一下就上来啦!觉住咱这嗑儿,一下儿从搞笑视频蹿到文化顶上了。那你赶紧给我白话白话,这里头到底有啥道道儿?我特别想知道——为啥一提起咱河南人,好些人脑子里“蹦”出来嘞头一个词儿,就是实在?这个实在,骨子里到底是啥嘞?", + ], +] + + +# ================================================ +# SoulX-Podcast Model +# ================================================ +model: SoulXPodcast = None +dataset: PodcastInferHandler = None +def initiate_model(config: Config, enable_tn: bool=False): + global model + if model is None: + model = SoulXPodcast(config) + + global dataset + if dataset is None: + dataset = PodcastInferHandler(model.llm.tokenizer, None, config) + +# ================================================ +# Gradio +# ================================================ + +_i18n_key2lang_dict = dict( + # Speaker1 Prompt + spk1_prompt_audio_label=dict( + en="Speaker 1 Prompt Audio", + zh="说话人 1 参考语音", + ), + spk1_prompt_text_label=dict( + en="Speaker 1 Prompt Text", + zh="说话人 1 参考文本", + ), + spk1_prompt_text_placeholder=dict( + en="text of speaker 1 Prompt audio.", + zh="说话人 1 参考文本", + ), + spk1_prompt_cot_text_label=dict( + en="Speaker 1 Prompt COT Text", + zh="说话人 1 参考推理链文本", + ), + spk1_prompt_cot_text_placeholder=dict( + en="Dialect prompt cot text with prefix: <|Sichuan|>/<|Yue|>/<|Henan|> ", + zh="带前缀方言提示词思维链文本,前缀如下:<|Sichuan|>/<|Yue|>/<|Henan|>,如:<|Sichuan|>走嘛,切吃那家新开的麻辣烫,听别个说味道硬是霸道得很,好吃到不摆了,去晚了还得排队!", + ), + # Speaker2 Prompt + spk2_prompt_audio_label=dict( + en="Speaker 2 Prompt Audio", + zh="说话人 2 参考语音", + ), + spk2_prompt_text_label=dict( + en="Speaker 2 Prompt Text", + zh="说话人 2 参考文本", + ), + spk2_prompt_text_placeholder=dict( + en="[S2] text of speaker 2 prompt audio.", + zh="[S2] 说话人 2 参考文本", + ), + spk2_prompt_cot_text_label=dict( + en="Speaker 2 Prompt COT Text", + zh="说话人 2 参考推理链文本", + ), + spk2_prompt_cot_text_placeholder=dict( + en="Dialect prompt cot text with prefix: <|Sichuan|>/<|Yue|>/<|Henan|> ", + zh="带前缀方言提示词思维链文本,前缀如下:<|Sichuan|>/<|Yue|>/<|Henan|>,如:<|Sichuan|>走嘛,切吃那家新开的麻辣烫,听别个说味道硬是霸道得很,好吃到不摆了,去晚了还得排队!", + ), + # Dialogue input textbox + dialogue_text_input_label=dict( + en="Dialogue Text Input", + zh="合成文本输入", + ), + dialogue_text_input_placeholder=dict( + en="[S1]text[S2]text[S1]text...", + zh="[S1]文本[S2]文本[S1]文本...", + ), + # Generate button + generate_btn_label=dict( + en="Generate Audio", + zh="合成", + ), + # Generated audio + generated_audio_label=dict( + en="Generated Dialogue Audio", + zh="合成的对话音频", + ), + # Warining1: invalid text for prompt + warn_invalid_spk1_prompt_text=dict( + en='Invalid speaker 1 prompt text, should not be empty and strictly follow: "xxx"', + zh='说话人 1 参考文本不合规,不能为空,格式:"xxx"', + ), + # warn_invalid_spk1_prompt_cot_text=dict( + # en='Invalid speaker 1 prompt cot text, should not be empty and strictly follow: "[S1]xxx"', + # zh='说话人 1 参考文本不合规,格式:"[S1]xxx"', + # ), + warn_invalid_spk2_prompt_text=dict( + en='Invalid speaker 2 prompt text, should strictly follow: "[S2]xxx"', + zh='说话人 2 参考文本不合规,格式:"[S2]xxx"', + ), + # Warining2: invalid text for dialogue input + warn_invalid_dialogue_text=dict( + en='Invalid dialogue input text, should strictly follow: "[S1]xxx[S2]xxx..."', + zh='对话文本输入不合规,格式:"[S1]xxx[S2]xxx..."', + ), + # Warining3: incomplete prompt info + warn_incomplete_prompt=dict( + en="Please provide prompt audio and text for both speaker 1 and speaker 2", + zh="请提供说话人 1 与说话人 2 的参考语音与参考文本", + ), +) + + +global_lang: 
Literal["zh", "en"] = "zh"
+def i18n(key):
+    # (unchanged)
+    global global_lang
+    return _i18n_key2lang_dict[key][global_lang]
+
+def check_monologue_text(text: str, prefix: str = None) -> bool:
+    text = text.strip()
+    # Check speaker tags
+    if prefix is not None and (not text.startswith(prefix)):
+        return False
+    # Remove prefix
+    if prefix is not None:
+        text = text.removeprefix(prefix)
+        text = text.strip()
+    # If empty?
+    if len(text) == 0:
+        return False
+    return True
+
+def check_dialect_prompt_cot_text(text: str, prefix: str = None) -> bool:
+    text = text.strip()
+    # Check COT prefix tags
+    if prefix is not None and (not text.startswith(prefix)):
+        return False
+    # Remove the COT prefix before checking that any content remains
+    if prefix is not None:
+        text = text.removeprefix(prefix)
+    text = text.strip()
+    # If empty?
+    if len(text) == 0:
+        return False
+    return True
+
+def check_dialogue_text(text_list: List[str]) -> bool:
+    if len(text_list) == 0:
+        return False
+    for text in text_list:
+        if not (
+            check_monologue_text(text, "[S1]")
+            or check_monologue_text(text, "[S2]")
+            or check_monologue_text(text, "[S3]")
+            or check_monologue_text(text, "[S4]")
+        ):
+            return False
+    return True
+
+def process_single(target_text_list, prompt_wav_list, prompt_text_list, use_prompt_cot, prompt_cot_text_list):
+    spks, texts = [], []
+    # Split each "[Sx]text" utterance into a speaker index and its text
+    for target_text in target_text_list:
+        pattern = r'(\[S[1-9]\])(.+)'
+        match = re.match(pattern, target_text)
+        text, spk = match.group(2), int(match.group(1)[2])-1
+        spks.append(spk)
+        texts.append(text)
+
+    global dataset
+    dataitem = {"key": "001", "prompt_text": prompt_text_list, "prompt_wav": prompt_wav_list,
+                "text": texts, "spk": spks, }
+    if use_prompt_cot:
+        dataitem.update({
+            "prompt_cot_text": prompt_cot_text_list
+        })
+    dataset.update_datasource(
+        [
+            dataitem
+        ]
+    )
+
+    # assert one data only;
+    data = dataset[0]
+    prompt_mels_for_llm, prompt_mels_lens_for_llm = s3tokenizer.padding(data["log_mel"]) # [B, num_mels=128, T]
+    spk_emb_for_flow = torch.tensor(data["spk_emb"])
+    prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(data["mel"], batch_first=True, padding_value=0) # [B, T', num_mels=80]
+    prompt_mels_lens_for_flow = torch.tensor(data['mel_len'])
+    text_tokens_for_llm = data["text_tokens"]
+    prompt_text_tokens_for_llm = data["prompt_text_tokens"]
+    spk_ids = data["spks_list"]
+    sampling_params = SamplingParams(use_ras=True, win_size=25, tau_r=0.2)
+    infos = [data["info"]]
+    processed_data = {
+        "prompt_mels_for_llm": prompt_mels_for_llm,
+        "prompt_mels_lens_for_llm": prompt_mels_lens_for_llm,
+        "prompt_text_tokens_for_llm": prompt_text_tokens_for_llm,
+        "text_tokens_for_llm": text_tokens_for_llm,
+        "prompt_mels_for_flow_ori": prompt_mels_for_flow,
+        "prompt_mels_lens_for_flow": prompt_mels_lens_for_flow,
+        "spk_emb_for_flow": spk_emb_for_flow,
+        "sampling_params": sampling_params,
+        "spk_ids": spk_ids,
+        "infos": infos,
+        "use_prompt_cot": use_prompt_cot,
+    }
+    if use_prompt_cot:
+        processed_data.update({
+            "prompt_cot_text_tokens_for_llm": data["prompt_cot_text_tokens"],
+            "prompt_cot_prefix": data["prompt_cot_prefix"],
+        })
+    return processed_data
+
+@spaces.GPU
+def dialogue_synthesis_function(
+    target_text: str,
+    spk1_prompt_text: str | None = "",
+    spk1_prompt_audio: str | None = None,
+    spk1_prompt_cot_text: str | None = "",
+    spk2_prompt_text: str | None = "",
+    spk2_prompt_audio: str | None = None,
+    spk2_prompt_cot_text: str | None = "",
+    seed: int = 1988, # <-- seed parameter kept
+):
+    # ================== Set random seed ==================
+    seed = int(seed)
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+    #
================================================ + + # Check prompt info + target_text_list: List[str] = re.findall(r"(\[S[0-9]\][^\[\]]*)", target_text) + target_text_list = [text.strip() for text in target_text_list] + if not check_dialogue_text(target_text_list): + gr.Warning(message=i18n("warn_invalid_dialogue_text")) + return None + + # Go synthesis + progress_bar = gr.Progress(track_tqdm=True) + prompt_wav_list = [spk1_prompt_audio, spk2_prompt_audio] + prompt_text_list = [spk1_prompt_text, spk2_prompt_text] + use_prompt_cot = spk1_prompt_cot_text.strip()!="" or spk2_prompt_cot_text.strip()!="" + prompt_cot_text_list = [spk1_prompt_cot_text, spk2_prompt_cot_text] + data = process_single( + target_text_list, + prompt_wav_list, + prompt_text_list, + use_prompt_cot, + prompt_cot_text_list, + ) + + results_dict = model.forward_longform( + **data + ) + target_audio = None + for i in range(len(results_dict['generated_wavs'])): + if target_audio is None: + target_audio = results_dict['generated_wavs'][i] + else: + target_audio = torch.concat([target_audio, results_dict['generated_wavs'][i]], axis=1) + return (24000, target_audio.cpu().squeeze(0).numpy()) + + +def render_interface() -> gr.Blocks: + with gr.Blocks(title="SoulX-Podcast", theme=gr.themes.Default()) as page: + # ======================== UI ======================== + with gr.Row(): + lang_choice = gr.Radio( + choices=["中文", "English"], + value="中文", + label="Display Language/显示语言", + type="index", + interactive=True, + scale=3, + ) + seed_input = gr.Number( + label="Seed (种子)", + value=1988, + step=1, + interactive=True, + scale=1, + ) + + with gr.Row(): + # ==== Speaker1 Prompt ==== + with gr.Column(scale=1): + with gr.Group(visible=True) as spk1_prompt_group: + spk1_prompt_audio = gr.Audio( + label=i18n("spk1_prompt_audio_label"), + type="filepath", + editable=False, + interactive=True, + ) + spk1_prompt_text = gr.Textbox( + label=i18n("spk1_prompt_text_label"), + placeholder=i18n("spk1_prompt_text_placeholder"), + lines=3, + ) + spk1_prompt_cot_text = gr.Textbox( + label=i18n("spk1_prompt_cot_text_label"), + placeholder=i18n("spk1_prompt_cot_text_placeholder"), + value="", + lines=3, + ) + # ==== Speaker2 Prompt ==== + with gr.Column(scale=1, visible=True): + with gr.Group(visible=True) as spk2_prompt_group: + spk2_prompt_audio = gr.Audio( + label=i18n("spk2_prompt_audio_label"), + type="filepath", + editable=False, + interactive=True, + ) + spk2_prompt_text = gr.Textbox( + label=i18n("spk2_prompt_text_label"), + placeholder=i18n("spk2_prompt_text_placeholder"), + lines=3, + ) + spk2_prompt_cot_text = gr.Textbox( + label=i18n("spk2_prompt_cot_text_label"), + placeholder=i18n("spk2_prompt_cot_text_placeholder"), + value="", + lines=3, + ) + # ==== Text input ==== + with gr.Column(scale=2): + with gr.Row(): + dialogue_text_input = gr.Textbox( + label=i18n("dialogue_text_input_label"), + placeholder=i18n("dialogue_text_input_placeholder"), + lines=18, + ) + + # Generate button + with gr.Row(): + generate_btn = gr.Button( + value=i18n("generate_btn_label"), + variant="primary", + scale=3, + size="lg", + ) + + # Long output audio + generate_audio = gr.Audio( + label=i18n("generated_audio_label"), + interactive=False, + ) + + with gr.Row(): + inputs_for_examples = [ + spk1_prompt_audio, + spk1_prompt_text, + spk1_prompt_cot_text, + spk2_prompt_audio, + spk2_prompt_text, + spk2_prompt_cot_text, + dialogue_text_input, + ] + + example_headers = [ + "S1 音频", "S1 文本", "S1 COT", + "S2 音频", "S2 文本", "S2 COT", + "对话内容" + ] + + gr.Examples( 
+ examples=EXAMPLES_LIST, + inputs=inputs_for_examples, + label="播客模板示例 (点击加载)", + examples_per_page=5, + ) + + # ======================== Action ======================== + def _change_component_language(lang): + global global_lang + global_lang = ["zh", "en"][lang] + return [ + + # spk1_prompt_{audio,text,prompt_cot_text} + gr.update(label=i18n("spk1_prompt_audio_label")), + gr.update( + label=i18n("spk1_prompt_text_label"), + placeholder=i18n("spk1_prompt_text_placeholder"), + ), + gr.update( + label=i18n("spk1_prompt_cot_text_label"), + placeholder=i18n("spk1_prompt_cot_text_placeholder"), + ), + # spk2_prompt_{audio,text} + gr.update(label=i18n("spk2_prompt_audio_label")), + gr.update( + label=i18n("spk2_prompt_text_label"), + placeholder=i18n("spk2_prompt_text_placeholder"), + ), + gr.update( + label=i18n("spk2_prompt_cot_text_label"), + placeholder=i18n("spk2_prompt_cot_text_placeholder"), + ), + # dialogue_text_input + gr.update( + label=i18n("dialogue_text_input_label"), + placeholder=i18n("dialogue_text_input_placeholder"), + ), + # generate_btn + gr.update(value=i18n("generate_btn_label")), + # generate_audio + gr.update(label=i18n("generated_audio_label")), + ] + + lang_choice.change( + fn=_change_component_language, + inputs=[lang_choice], + outputs=[ + spk1_prompt_audio, + spk1_prompt_text, + spk1_prompt_cot_text, + spk2_prompt_audio, + spk2_prompt_text, + spk2_prompt_cot_text, + dialogue_text_input, + generate_btn, + generate_audio, + ], + ) + + # Generate button click Action + generate_btn.click( + fn=dialogue_synthesis_function, + inputs=[ + dialogue_text_input, + spk1_prompt_text, + spk1_prompt_audio, + spk1_prompt_cot_text, + spk2_prompt_text, + spk2_prompt_audio, + spk2_prompt_cot_text, + seed_input, + ], + outputs=[generate_audio], + ) + return page + + +# ================================================ +# Options +# ================================================ +def get_args(): + parser = ArgumentParser() + parser.add_argument('--model_path', + required=True, + type=str, + help='model path') + parser.add_argument('--llm_engine', + type=str, + default="hf", + help='model execute engine') + parser.add_argument('--fp16_flow', + action='store_true', + help='enable fp16 flow') + parser.add_argument('--seed', + type=int, + default=1988, + help='random seed for generation') + args = parser.parse_args() + return args + + +if __name__ == "__main__": + args = get_args() + + # Initiate model + hf_config = SoulXPodcastLLMConfig.from_initial_and_json( + initial_values={"fp16_flow": args.fp16_flow}, + json_file=f"{args.model_path}/soulxpodcast_config.json") + + # Compatible with the absence of a VLLM installation + llm_engine = args.llm_engine + if llm_engine == "vllm": + if not importlib.util.find_spec("vllm"): + llm_engine = "hf" + timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3] + tqdm.write(f"[{timestamp}] - [WARNING]: No install VLLM, switch to hf engine.") + config = Config(model=args.model_path, enforce_eager=True, llm_engine=llm_engine, + hf_config=hf_config) + + torch.manual_seed(args.seed) + np.random.seed(args.seed) + random.seed(args.seed) + + initiate_model(config) + print("[INFO] SoulX-Podcast loaded") + # UI + page = render_interface() + page.queue() + page.launch() + # page.launch(share=False) \ No newline at end of file diff --git a/gitattributes b/gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..a6344aac8c09253b3b630fb776ae94478aa0275b --- /dev/null +++ b/gitattributes @@ -0,0 +1,35 @@ +*.7z filter=lfs diff=lfs 
merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..da8131c2907f69157b7b5ed5f96ba594868cc5b2 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,14 @@ +librosa +numpy +scipy +s3tokenizer +diffusers +torch==2.7.1 +torchaudio==2.7.1 +triton>=3.0.0 +transformers==4.57.1 +accelerate==1.10.1 +onnxruntime +onnxruntime-gpu +einops +gradio \ No newline at end of file diff --git a/soulxpodcast/__init__.py b/soulxpodcast/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/soulxpodcast/__pycache__/__init__.cpython-311.pyc b/soulxpodcast/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4299881a5e0726c18bbb88e4e4f3a37ff4c32f93 Binary files /dev/null and b/soulxpodcast/__pycache__/__init__.cpython-311.pyc differ diff --git a/soulxpodcast/__pycache__/__init__.cpython-312.pyc b/soulxpodcast/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d673353dabd327344519d06921fc7af78f9f9cd8 Binary files /dev/null and b/soulxpodcast/__pycache__/__init__.cpython-312.pyc differ diff --git a/soulxpodcast/__pycache__/config.cpython-311.pyc b/soulxpodcast/__pycache__/config.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee058ed3782655b26dca70afa2b58393bb6447b8 Binary files /dev/null and b/soulxpodcast/__pycache__/config.cpython-311.pyc differ diff --git a/soulxpodcast/__pycache__/config.cpython-312.pyc b/soulxpodcast/__pycache__/config.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2f721c2dfcb928a67e8315fde18da2e063493a8 Binary files /dev/null and b/soulxpodcast/__pycache__/config.cpython-312.pyc differ diff --git a/soulxpodcast/cli/__pycache__/soulxpodcast.cpython-311.pyc b/soulxpodcast/cli/__pycache__/soulxpodcast.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..282b394236e8e8bb1b27c364a2b52ae238d847b5 Binary files /dev/null and b/soulxpodcast/cli/__pycache__/soulxpodcast.cpython-311.pyc differ diff --git a/soulxpodcast/cli/__pycache__/soulxpodcast.cpython-312.pyc b/soulxpodcast/cli/__pycache__/soulxpodcast.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af281b3eb1c73c7b9594ee248d243772af155e72 Binary files /dev/null and b/soulxpodcast/cli/__pycache__/soulxpodcast.cpython-312.pyc differ diff --git a/soulxpodcast/cli/engine_test.py b/soulxpodcast/cli/engine_test.py new file mode 100644 index 0000000000000000000000000000000000000000..518fda33af7dbed27d6d3b39a7a60672380e0272 --- /dev/null +++ b/soulxpodcast/cli/engine_test.py @@ -0,0 +1,74 @@ +import argparse +import json +import os +import random +import sys +from glob import glob +from copy import deepcopy +import numpy as np +import torch +from tqdm import tqdm +from dataclasses import fields, asdict + +from vllm import LLM +from vllm.inputs import TokensPrompt as TokensPrompt +from vllm import SamplingParams + +def set_all_random_seed(seed): + import random + import numpy as np + import os + random.seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + +def get_args(): + parser = argparse.ArgumentParser(description='FlashCosyVoice') + parser.add_argument('--model_path', + required=True, + type=str, + help='model path') + parser.add_argument('--seed', + type=int, + default=1986, + help='random seed for generation') + args = parser.parse_args() + return args + +def main(): + args = get_args() + + os.environ["VLLM_USE_V1"] = "0" + sampling_params = SamplingParams(temperature=0.6, repetition_penalty=1.25, top_k=100, top_p=0.9, max_tokens=3000, stop_token_ids=[153478], use_ras=True, win_size=25, tau_r=0.2) + + llm = LLM(model=args.model_path, enforce_eager=True, dtype="bfloat16", ) + input = [152477, 151674, 151712, 152596, 120432, 102115, 101417, 121407, 3837, 113562, 108896, 103554, 34187, 99744, 104283, 3837, 104074, 18830, 101417, 63109, 29826, 103924, 11319, 151713, 153477, 157342, 157987, 158071, 157264, 153912, 154998, 159371, 159391, 159562, 158917, 158870, 159680, 159672, 157485, 155058, 155207, 153630, 153846, 158058, 153814, 153841, 158204, 154543, 158928, 154055, 155761, 155026, 155298, 155061, 155884, 159856, 155815, 158542, 156292, 154332, 155055, 153606, 159474, 157971, 158055, 158607, 158193, 155035, 155038, 159420, 159643, 157699, 155778, 154224, 158652, 158461, 158551, 156465, 157728, 155541, 155307, 155297, 157894, 157894, 157813, 157570, 159511, 159616, 158914, 158833, 158935, 155045, 155770, 157962, 159491, 157384, 157398, 157077, 158859, 158693, 159419, 159565, 158796, 157484, 157248, 158029, 158348, 158271, 156337, 154890, 154638, 155268, 159637, 158827, 159646, 158179, 158159, 155244, 155280, 155283, 154488, 156999, 156279, 157007, 157891, 160081, 160000, 153478, 151675, 151712, 152596, 18830, 101417, 106909, 56568, 104138, 36587, 3837, 100132, 107133, 117320, 3837, 105146, 53153, 102155, 117927, 1773, 151713, 153477, 155303, 157049, 159236, 159344, 158936, 158693, 159425, 159424, 158609, 158266, 159972, 159969, 160052, 160055, 159971, 158375, 159670, 159427, 159428, 155783, 159321, 160053, 158284, 158132, 156714, 157470, 155283, 155286, 155167, 157324, 159511, 154330, 155635, 159613, 158481, 156348, 154160, 155572, 158516, 158326, 
154645, 155619, 155457, 157343, 158882, 158155, 154777, 154249, 158695, 158675, 155055, 154332, 155079, 159004, 158510, 158591, 158593, 157151, 158697, 158688, 158595, 158577, 158576, 157814, 157894, 157894, 157813, 157813, 157813, 157813, 157732, 157240, 157483, 155296, 155626, 155761, 153659, 154194, 155052, 157968, 159421, 158032, 159320, 159247, 159015, 157789, 159974, 159977, 159986, 154641, 154392, 156328, 156094, 154413, 157728, 159509, 153733, 155680, 155860, 154889, 155618, 155609, 156581, 158191, 156949, 157482, 158691, 158694, 160100, 159125, 159772, 155627, 155626, 157489, 159514, 158785, 156598, 154411, 155626, 160054, 160130, 160051, 160066, 159337, 158365, 158194, 154783, 154867, 158372, 158615, 158616, 155943, 157263, 157587, 155425, 155380, 157321, 155080, 155734, 155029, 158690, 154426, 157259, 155847, 158961, 159988, 157069, 156513, 155541, 154569, 155316, 155403, 157427, 158238, 158588, 158668, 155935, 159970, 160050, 159725, 159778, 159967, 155369, 155370, 156828, 158446, 155005, 154246, 156433, 158539, 156455, 154308, 153738, 153990, 153909, 154152, 156429, 158426, 158452, 158659, 157085, 157813, 157813, 157813, 153478, 151674, 151712, 152596, 50930, 100014, 36587, 112491, 54039, 100391, 107512, 113524, 9370, 115639, 100930, 34718, 104038, 26940, 110602, 17447, 99469, 28029, 25067, 69249, 109624, 116542, 3837, 108243, 82166, 99172, 58364, 52129, 104207, 9370, 22697, 104097, 99882, 99615, 70074, 3837, 120432, 18830, 101417, 110721, 11319, 151713, 153477, 154327, 153757, 159436, 157245, 153603, 156051, 158027, 158273, 154133, 153918, 157908, 159372, 158681, 158163, 157428, 159572, 159886, 155049, 154305, 158538, 153973, 153595, 153676, 159430, 155060, 154575, 156006, 160129, 160138, 158666, 160017, 155156, 155437, 157459, 154713, 154962, 157239, 156016, 158272, 156688, 158384, 158541, 154218, 156186, 159262, 158614, 158535, 155203, 159574, 157316, 159669, 159657, 155284, 157300, 157456, 157159, 154906, 155781, 154720, 153652, 157183, 159422, 155206, 158190, 158082, 157996, 159239, 158513, 156169, 154314, 155727, 159124, 155689, 158078, 157247, 155304, 153900, 154390, 159975, 159726, 159968, 158257, 158203, 155512, 158056, 158479, 158085, 156519, 154329, 154364, 155059, 159433, 159433, 158704, 155140, 155707, 153649, 156516, 155057, 154656, 154079, 155191, 155775, 157239, 156669, 154951, 158517, 158030, 158245, 158299, 154123, 154728, 154857, 157830, 159853, 158885, 158425, 158072, 157749, 155157, 159535, 159619, 158836, 156597, 155540, 155550, 154101, 156277, 156349, 156124, 158349, 153651, 153898, 156835, 156445, 159397, 157232, 155128, 157158, 156023, 159500, 155871, 154557, 157319, 159577, 158938, 158158, 158629, 159287, 157833, 157860, 157887, 159426, 156355, 158623, 154287, 155544, 157727, 159512, 158038, 158514, 158533, 155790, 154812, 154830, 155542, 155056, 157246, 157243, 155053, 155620, 154813, 154000, 158455, 158621, 158678, 155051, 154177, 160154, 158689, 153976, 158366, 156021, 154233, 154161, 158002, 158270, 153890, 158781, 158639, 158642, 160132, 160133, 160144, 157967, 155694, 157644, 156024, 158135, 155784, 155078, 155491, 155599, 154326, 155543, 154821, 156774, 159533, 158480, 158417, 158054, 157258, 155787, 153792, 159538, 158890, 158809, 158809, 158107, 156577, 155133, 154404, 155446, 158515, 157076, 158453, 159040, 158392, 157669, 159245, 158524, 153708, 155546, 154899, 158660, 160096, 158627, 159431, 157490, 154197, 156096, 158887, 156733, 154434, 154119, 160013, 159043, 155573, 155545, 157327, 159433, 156517, 155383, 155401, 155072, 153639, 
156294, 156888, 158664, 158663, 159562, 155920, 157242, 157726, 157241, 154517, 155356, 157708, 160151, 160153, 159418, 159101, 159263, 158534, 158364, 158004, 156524, 157567, 157894, 157894, 153478, 151675, 151712, 152596, 110931, 45629, 101454, 17447, 40814, 99164, 100753, 41683, 9370, 2073, 119176, 103162, 102781, 854, 100074, 99615, 47872, 6313, 115639, 112491, 99615, 81668, 104462, 99319, 100307, 102561, 99964, 101182, 99319, 108375, 100074, 3837, 99557, 99601, 104631, 108439, 100372, 100369, 99375, 99261, 6313, 151713, 153477, 155328, 159429, 159432, 153612, 157265, 155322, 155653, 159246, 156346, 153981, 155055, 155539, 155541, 154270, 155677, 160105, 160075, 157207, 160139, 158897, 158267, 158518, 158052, 156756, 154092, 155559, 155318, 155554, 155299, 155623, 155302, 159433, 159541, 157354, 155411, 154843, 158351, 158362, 159343, 156105, 154397, 158521, 154965, 154221, 156089, 157840, 159325, 159319, 160067, 160070, 159340, 159094, 159244, 158428, 156352, 158458, 159032, 158134, 158000, 157261, 155158, 153668, 153597, 153867, 154110, 154838, 155569, 156997, 157000, 154898, 155636, 157941, 155672, 155673, 155754, 160127, 160126, 158692, 158858, 156319, 157048, 157075, 154242, 154359, 153653, 155077, 155071, 158700, 159104, 158374, 158383, 158232, 157017, 157503, 157506, 155319, 155399, 155644, 155545, 155053, 157243, 159457, 157270, 155083, 157786, 159243, 155782, 158941, 159194, 159752, 153849, 155562, 155643, 155722, 154991, 159915, 155058, 154440, 156501, 158209, 155518, 153661, 158696, 158200, 158861, 157566, 155622, 155706, 155733, 155760, 155624, 154814, 157810, 157813, 157813, 157813, 157813, 157246, 159508, 159454, 157267, 155326, 158239, 158590, 159322, 159403, 158599, 158680, 156482, 155646, 157509, 155482, 154135, 159510, 159435, 153954, 154070, 153598, 155863, 159670, 157250, 154332, 154143, 158454, 157861, 160135, 156503, 158600, 159935, 159773, 159693, 159774, 157909, 159267, 155055, 154602, 154062, 158207, 156364, 156436, 156481, 156510, 154839, 157394, 157557, 159023, 159103, 156036, 155640, 155400, 155321, 155563, 155626, 155545, 159433, 159460, 158731, 154357, 155464, 155515, 157481, 159264, 153999, 153711, 153747, 158561, 159290, 156310, 158210, 158617, 158543, 158679, 154309, 155758, 154323, 153975, 159581, 158214, 159671, 154578, 155565, 155645, 154168, 154264, 160155, 160074, 160073, 159423, 155047, 155707, 155269, 157157, 156347, 153657, 154356, 153648, 159209, 157021, 158506, 158587, 158416, 155965, 159532, 157346, 159983, 155809, 158212, 159193, 159753, 157008, 154587, 155561, 155551, 157729, 157813, 157813, 157813, 157813, 157489, 159511, 159457, 156541, 155461, 154408, 153806, 153744, 156590, 159046, 158254, 158108, 158354, 158353, 156177, 155061, 153626, 159448, 159127, 159976, 158079, 158706, 155844, 154624, 159240, 156511, 159913, 154815, 154818, 156492, 158686, 158275, 159077, 160049, 159256, 157800, 157798, 158596, 159157, 159346, 158641, 155780, 155697, 158759, 158110, 154159, 154437, 155547, 154089, 156384, 158687, 156500, 157084, 159184, 159187, 154817, 155726, 157940, 157939, 158684, 158131, 158057, 153644, 157507, 155564, 158434, 157408, 156676, 158156, 155895, 155126, 157395, 153671, 156420, 155868, 156368, 160019, 153850, 156171, 158685, 157886, 156509, 159132, 155274, 155277, 155196, 155168, 155141, 155138, 155626, 153478, 151674, 151712, 152596, 102177, 73670, 99877, 2073, 102274, 108243, 854, 99257, 100896, 3837, 100345, 26940, 107513, 99815, 85361, 23656, 25067, 105636, 3837, 58364, 52129, 2073, 100917, 108471, 101187, 99243, 100859, 854, 
2073, 106957, 99476, 23305, 44729, 854, 100001, 100175, 107600, 9370, 119305, 100399, 70074, 6313, 151713, 153477, 157543, 155855, 158784, 158780, 158703, 155796, 156531, 155402, 153838, 156187, 156287, 158109, 158373, 156321, 154403, 157057, 156088, 160152, 159421, 155290, 157309, 157315, 157401, 158940, 156753, 159485, 159643, 158920, 154474, 154650, 154910, 159770, 158318, 158507, 154134, 153882, 153618, 159449, 156940, 154084, 156106, 159196, 159350, 158624, 154305, 154322, 154978, 154267, 160047, 159398, 155050, 154897, 154291, 155043, 154917, 159854, 156859, 158278, 157236, 154116, 154842, 160048, 157942, 157731, 159186, 158555, 159286, 157822, 160081, 157894, 157894, 157894, 157894, 157813, 157813, 157813, 154570, 153802, 153624, 158075, 160099, 159091, 159911, 159912, 156996, 153674, 155410, 154141, 153903, 154278, 157185, 159121, 158884, 157066, 158281, 154971, 153669, 159204, 159367, 159850, 156449, 154086, 154833, 154188, 156365, 159352, 158633, 159833, 159832, 159589, 157447, 154451, 157238, 157965, 153835, 154870, 158274, 154888, 155376, 155605, 156817, 153627, 159513, 158868, 156141, 155331, 155384, 156760, 159433, 159433, 159433, 159433, 155137, 157837, 160102, 160129, 157954, 155700, 154968, 159588, 159183, 158189, 158783, 155799, 153864, 156068, 158188, 159163, 156967, 158192, 157976, 156536, 155320, 159253, 154647, 153873, 153603, 158229, 156320, 157039, 158444, 158860, 158546, 157104, 155725, 154298, 159593, 156114, 153819, 154384, 157405, 159437, 159995, 154104, 155724, 155716, 155755, 154646, 154863, 154374, 157746, 159045, 158291, 159650, 157444, 159181, 158202, 153600, 155117, 157313, 157393, 155811, 159284, 160016, 159804, 159910, 158197, 158137, 155795, 157262, 155347, 159980, 157556, 156585, 153663, 155114, 156518, 158704, 155788, 155221, 155113, 156739, 153789, 155852, 159000, 154132, 156087, 158081, 155194, 154621, 156025, 154081, 155613, 154137, 158186, 158996, 159220, 158286, 153894, 153654, 158026, 158597, 156184, 158619, 158651, 159409, 155164, 159643, 156703, 155210, 157314, 157977, 156339, 154862, 154861, 154727, 155568, 155574, 155007, 158688, 156280, 158536, 158581, 158402, 156651, 159643, 154471, 154677, 156288, 159044, 155555, 157894, 157894, 157813, 153478, 151675, 151712, 152596, 99936, 30534, 104609, 106594, 99882, 99615, 70074, 9370, 101485, 99499, 57566, 6313, 102245, 101325, 99164, 45861, 102504, 9370, 2073, 100268, 93, 102634, 44729, 93, 121407, 93, 119921, 93, 99800, 101953, 93, 33590, 99601, 49187, 36407, 100132, 104666, 101941, 6313, 151713, 153477, 154258, 156489, 155054, 154331, 154349, 159775, 157831, 158516, 156148, 158443, 158165, 155817, 153636, 155074, 155419, 155329, 159433, 156517, 154816, 159235, 156015, 154896, 154230, 154948, 158515, 157222, 154275, 155540, 155567, 159914, 155971, 158515, 158608, 160071, 157884, 157155, 154320, 155039, 157807, 156754, 155323, 157030, 158347, 156504, 154296, 157914, 157590, 157617, 157724, 159668, 158198, 158162, 158001, 156533, 159453, 157266, 155105, 155330, 157246, 155086, 154870, 158111, 156427, 155976, 157001, 154098, 154206, 158669, 159370, 157906, 159266, 157244, 153927, 153675, 158753, 159969, 157060, 153660, 155315, 159776, 154633, 158025, 157998, 156054, 156027, 153840, 154083, 154595, 155299, 157240, 154412, 154826, 157642, 157480, 159664, 158206, 155940, 155180, 155103, 155102, 155183, 155200, 159665, 157725, 155295, 155441, 155479, 155477, 155898, 158445, 158427, 158319, 159047, 157823, 157813, 157813, 157813, 157813, 157813, 157813, 157813, 157813, 155398, 159445, 157970, 153623, 
158512, 156342, 158670, 158643, 160101, 159369, 155011, 157078, 159751, 157591, 155407, 154627, 156133, 158542, 154178, 154302, 153982, 158269, 158682, 158321, 159973, 158511, 158698, 159679, 159103, 158311, 159695, 155483, 158516, 155869, 156526, 157494, 154821, 154911, 155314, 155838, 158322, 158241, 158223, 158303, 158222, 155324, 155570, 155300, 155545, 155296, 157483, 157483, 155539, 155707, 157912, 157666, 156452, 158650, 158647, 158648, 158568, 158571, 156357, 154170, 154179, 154845, 154844, 155654, 155545, 155626, 157813, 157486, 159427, 157246, 155383, 154843, 158271, 156175, 158696, 155979, 156600, 153635, 159451, 155428, 159982, 159985, 159744, 159096, 158043, 158034, 158115, 158087, 158087, 158978, 157495, 157733, 157813, 157813, 157813, 157813, 157813, 157813, 155302, 159430, 159430, 157243, 155059, 155707, 153892, 158195, 159653, 159654, 157467, 157476, 157397, 159582, 154398, 158139, 158166, 158112, 155925, 156168, 156249, 154071, 154719, 155439, 155358, 155087, 155383, 157813, 157732, 157732, 157813, 157813, 157813, 157813, 157243, 159430, 156514, 155377, 157894, 154297, 158618, 159104, 158943, 158457, 156270, 156351, 154167, 154893, 155595, 155325, 155570, 155545, 155626, 157813, 157813, 157813, 157813, 157813, 157813, 157813, 155626, 159973, 159248, 158295, 159269, 158544, 156393, 158661, 158604, 158995, 160131, 156340, 158294, 159024, 159105, 158862, 159618, 159621, 159645, 159630, 159629, 159548, 157579, 159766, 157498, 157741, 157813, 157813, 157813, 157813, 157894, 155707, 157786, 158997, 157969, 158950, 159590, 157329, 153684, 154035, 155862, 155434, 153901, 156837, 156273, 156354, 156381, 158649, 158676, 158657, 158656, 159388, 160010, 157742, 157732, 157732, 157732, 157813, 157813, 157813, 157813, 157813, 157726, 159670, 159697, 157510, 157510, 157483, 157483, 157726, 157483, 157483, 157726, 157483, 157483, 157486, 157732, 160000, 160000, 160000, 160000, 160000, 159919, 159919, 159919, 159919, 159919, 159919, 157732, 157732, 159217, 158243, 158489, 153889, 154386, 154353, 156529, 157588, 156097, 158051, 157974, 153752, 155365, 157975, 156112, 153729, 155076, 157752, 157881, 156099, 158133, 155865, 153681, 153735, 154464, 155084, 155327, 155056, 157243, 157240, 157240, 157240, 157813, 160000, 160000, 157813, 157813, 157813, 157813, 157813, 157894, 155593, 158105, 158592, 156426, 154239, 156507, 157233, 158725, 158671, 158509, 157799, 157071, 155856, 153750, 153993, 153741, 153984, 154389, 155351, 155573, 155545, 155545, 155644, 157426, 156695, 158371, 158560, 158264, 157508, 157589, 155509, 156085, 156103, 155934, 154410, 155922, 158265, 159237, 157805, 157806, 156834, 156807, 158994, 159830, 159829, 158939, 156750, 153825, 153816, 153705, 154461, 155112, 155354, 155327, 155545, 157813, 157813, 153478, 151674, 151712, 152596, 110602, 17447, 99469, 99354, 75882, 29635, 18947, 2073, 70074, 85254, 101904, 23836, 33590, 99258, 104090, 111077, 18800, 101044, 106063, 18397, 115639, 119754, 46553, 3837, 45181, 99391, 22697, 9370, 101780, 17523, 99882, 99615, 104188, 99530, 22697, 9370, 120728, 121909, 99293, 115807, 101182, 6313, 151713, 153477, 156815, 158031, 156031, 154620, 154128, 159365, 157423, 158399, 158173, 158960, 157260, 159535, 156730, 157323, 155541, 154824, 156290, 156268, 156367, 158268, 153732, 159182, 158447, 159131, 159591, 157404, 155217, 154542, 155488, 153760, 155139, 155061, 158409, 158201, 158914, 158836, 155993, 156081, 154883, 154126, 156414, 153678, 156542, 159643, 158839, 153743, 154191, 155558, 159534, 157777, 159010, 159345, 158096, 159933, 
155481, 155353, 158380, 156283, 159316, 158668, 158606, 158220, 155061, 154341, 155159, 156301, 159164, 159622, 155176, 155538, 154566, 158211, 155115, 155627, 158947, 159676, 159433, 159433, 159433, 155222, 155392, 155294, 155136, 157002, 155802, 156374, 157156, 158932, 158683, 158613, 154903, 153919, 156034, 158529, 156324, 156297, 155919, 155275, 153895, 159516, 157491, 154556, 155530, 155046, 155052, 157882, 157951, 158187, 158435, 158138, 159753, 157809, 159834, 158891, 158806, 158915, 155902, 156162, 156405, 155767, 156009, 158355, 156411, 154080, 155500, 159562, 157915, 154174, 155306, 155058, 156123, 159155, 159884, 153828, 153948, 156392, 158620, 158448, 156267, 158039, 158672, 158356, 156498, 155025, 159368, 155443, 158024, 159515, 156762, 153823, 155524, 158616, 156060, 153621, 155617, 156823, 154881, 153972, 158798, 159041, 155582, 155626, 157894, 157894, 157813, 159757, 159430, 159430, 159430, 159430, 155383, 157759, 158350, 156358, 156483, 157101, 155480, 157795, 159242, 158106, 159270, 158625, 158674, 159167, 159643, 158914, 156638, 157475, 155127, 155814, 153742, 156082, 159265, 158665, 159424, 159425, 156274, 156454, 160154, 159176, 154150, 158290, 154407, 157734, 158630, 160093, 158385, 158787, 156033, 153736, 153628, 154096, 155349, 158718, 157585, 155837, 159996, 157083, 156462, 155937, 156428, 155590, 155591, 154161, 154629, 158865, 158930, 158914, 156655, 159662, 153642, 155892, 154957, 154243, 156844, 158184, 156014, 156584, 158436, 158696, 158282, 159081, 158488, 156348, 155261, 154722, 156492, 158565, 156506, 154987, 154294, 160155, 159424, 155691, 155708, 157813, 153478, 151675, 151712, 152596, 102177, 99360, 91777, 100569, 17340, 101376, 102073, 22697, 70074, 104387, 12857, 74393, 41505, 120965, 101240, 120965, 102565, 97907, 102138, 34718, 70074, 91956, 99662, 99318, 121657, 44729, 97907, 99318, 100893, 70074, 3837, 100132, 75606, 110261, 104754, 6313, 151713, 153477] + outputs = llm.generate(TokensPrompt(prompt_token_ids=input), + sampling_params, + use_tqdm=False)[0].outputs[0].token_ids + print(outputs) + # files = glob(f"{args.data_list}/*_result.json") + # files.sort() + # for file in files: + # with open(file) as fin: + # test_sets = json.load(fin) + # for test_set in test_sets: + # input = test_set["input"] + # set_all_random_seed(args.seed) + # llm_outputs = model.llm.generate(input, sampling_params)['token_ids'] + # set_all_random_seed(args.seed) + # import pdb;pdb.set_trace() + # outputs = llm.generate(TokensPrompt(prompt_token_ids=input), + # VllmSamplingParams(**asdict(sampling_params)), + # use_tqdm=False)[0].outputs[0].token_ids + # print(llm_outputs) + # print(outputs) + # print("=========") + # import pdb;pdb.set_trace() + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/soulxpodcast/cli/soulxpodcast.py b/soulxpodcast/cli/soulxpodcast.py new file mode 100644 index 0000000000000000000000000000000000000000..622cc8ef0c72361010417f3535818094f32a1a07 --- /dev/null +++ b/soulxpodcast/cli/soulxpodcast.py @@ -0,0 +1,273 @@ +import argparse +import json +import os +import random +import sys +import time +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime +import importlib.util + +import numpy as np +import onnxruntime +import s3tokenizer +import torch +import torchaudio +import torchaudio.compliance.kaldi as kaldi +from torch.utils.data import DataLoader, Dataset, SequentialSampler +from tqdm import tqdm + +from soulxpodcast.config import Config, SoulXPodcastLLMConfig, SamplingParams +from 
soulxpodcast.models.soulxpodcast import SoulXPodcast +from soulxpodcast.utils.dataloader import PodcastDataset +from soulxpodcast.utils.audio import mel_spectrogram + + +def set_all_random_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def save_file_async( + wav, uttid, + info, + is_sub, +): + """Save audio asynchronously.""" + try: + parentdir = f"{os.path.dirname(info['wav'])}" + basename = os.path.basename(info["wav"]).split(".")[0] + if is_sub: + parentdir = f"{parentdir}/individual_clips" + os.makedirs(parentdir, exist_ok=True) + if wav is not None: + wav = wav.cpu() + torchaudio.save(f'{parentdir}/{uttid}.wav', wav, 24000) + duration = wav.shape[-1] / 24000.0 + else: + duration = 0.0 + if not is_sub: + with open(f"{parentdir}/{basename}.json", "w") as f: + json.dump(info, f, ensure_ascii=False, indent=4) + return duration + except Exception as e: + timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3] + tqdm.write(f"[{timestamp}] - [ERROR] - Error saving audio {info.get('key', 'unknown')}: {e}") + return 0.0 + +def collate_fn(batch): + assert len(batch) == 1 + data = batch[0] + + # prepare prompt mels for llm, spk_emb + prompt mel for flow; + prompt_mels_for_llm, prompt_mels_lens_for_llm = s3tokenizer.padding(data["log_mel"]) # [B, num_mels=128, T] + spk_emb_for_flow = torch.tensor(data["spk_emb"]) + prompt_mels_for_flow = data["mel"] + + # prepare text + spk for llm; + text_tokens_for_llm = data["text_tokens"] + prompt_text_tokens_for_llm = data["prompt_text_tokens"] + spk_ids = data["spks_list"] + sampling_params = SamplingParams() + infos = [data["info"]] + + processed_data = { + "prompt_mels_for_llm": prompt_mels_for_llm, + "prompt_mels_lens_for_llm": prompt_mels_lens_for_llm, + "prompt_text_tokens_for_llm": prompt_text_tokens_for_llm, + "text_tokens_for_llm": text_tokens_for_llm, + "prompt_mels_for_flow_ori": prompt_mels_for_flow, + "spk_emb_for_flow": spk_emb_for_flow, + "sampling_params": sampling_params, + "spk_ids": spk_ids, + "infos": infos, + } + + if data.get("use_prompt_cot", False): + processed_data.update({ + "use_prompt_cot": True, + "prompt_cot_text_tokens_for_llm": data["prompt_cot_text_tokens"], + "prompt_cot_prefix": data["prompt_cot_prefix"], + }) + return processed_data + + +def get_args(): + parser = argparse.ArgumentParser(description='FlashCosyVoice') + parser.add_argument('--model_path', + required=True, + type=str, + help='model path') + parser.add_argument('--data_list', + required=True, + type=str, + help='data list') + parser.add_argument('--num_workers', + type=int, + default=4, + help='workers for dataloader') + parser.add_argument('--prefetch', + type=int, + default=5, + help='prefetch for dataloader') + parser.add_argument('--llm_engine', + type=str, + default="hf", + help='model execute engine') + parser.add_argument('--fp16_flow', + action='store_true', + help='enable fp16 flow') + parser.add_argument('--seed', + type=int, + default=1986, + help='random seed for generation') + parser.add_argument('--save_intermediate', + action='store_true', + help='enable save intermediate result in long form.') + args = parser.parse_args() + return args + + +def main(): + args = get_args() + + assert (torch.cuda.is_available()) + hf_config = SoulXPodcastLLMConfig.from_initial_and_json( + initial_values={"fp16_flow": args.fp16_flow}, + json_file=f"{args.model_path}/soulxpodcast_config.json") + + # Compatible with the absence of a VLLM installation + llm_engine = args.llm_engine + 
if llm_engine == "vllm": + if not importlib.util.find_spec("vllm"): + llm_engine = "hf" + timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3] + tqdm.write(f"[{timestamp}] - [WARNING]: No install VLLM, switch to hf engine.") + + config = Config(model=args.model_path, enforce_eager=True, llm_engine=llm_engine, + hf_config=hf_config) + model = SoulXPodcast(config) + + set_all_random_seed(args.seed) + + dataset = PodcastDataset(model.llm.tokenizer, args.data_list, config) + sampler = SequentialSampler(dataset,) + dataloader = DataLoader(dataset, batch_size=1, num_workers=args.num_workers, pin_memory=True, + sampler=sampler, shuffle=False, prefetch_factor=args.prefetch, collate_fn=collate_fn) + total_steps = len(dataset) + + timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3] + tqdm.write(f"[{timestamp}] - [INFO] - {args}") + progress_bar = tqdm(total=total_steps, desc="Processing samples", unit="wav", + position=0, leave=True, dynamic_ncols=True) + + cpu_counts = os.cpu_count() + executor = ThreadPoolExecutor(max_workers=min(args.num_workers, cpu_counts // 2)) + + pending_futures = [] + dataloader_iter = iter(dataloader) + succeed_duration = 0.01 # avoid division by zero + start_time = time.time() + estimated_total_wavs = 0 + succeed_wavs = 0 + failed_wavs = 0 + last_print_time = start_time + + while True: + try: + batch = next(dataloader_iter) + + if len(batch['infos']) == 0: + timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3] + tqdm.write(f"[{timestamp}] - [WARNING]: No valid batch found, skipping this batch...") + continue + + results_dict = model.forward_longform(**batch) + + estimated_total_wavs += len(results_dict['generated_wavs']) + + uttid = batch['infos'][0]['key'] + result = None + for i in range(len(results_dict['generated_wavs'])): + is_sub = True + if args.save_intermediate: + future = executor.submit( + save_file_async, results_dict['generated_wavs'][i], + f"{uttid}_turn_{str(i).zfill(2)}", batch['infos'][0].copy(), is_sub + ) + pending_futures.append(future) + if result is None: + result = results_dict['generated_wavs'][i] + else: + result = torch.concat([result, results_dict['generated_wavs'][i]], axis=1) + future = executor.submit( + save_file_async, result, + f"{uttid}", batch['infos'][0].copy(), False + ) + pending_futures.append(future) + completed_futures = [] + for future in pending_futures: + if future.done(): + try: + duration = future.result() + succeed_duration += duration + succeed_wavs += 1 + except Exception as e: + failed_wavs += 1 + timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3] + tqdm.write(f"[{timestamp}] - [ERROR]: Error in async save task: {e}") + completed_futures.append(future) + + for future in completed_futures: + pending_futures.remove(future) + + update_n = 1 + if progress_bar.n + update_n > progress_bar.total: + progress_bar.update(progress_bar.total - progress_bar.n) + else: + progress_bar.update(update_n) + + current_time = time.time() + if current_time - last_print_time >= 120: + elapsed_time = current_time - start_time + avg_duration = succeed_duration / succeed_wavs if succeed_wavs > 0 else 0 + estimated_total_duration = avg_duration * estimated_total_wavs + timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3] + tqdm.write(f"[{timestamp}] - [INFO]: Estimated total wavs: {estimated_total_wavs} ({estimated_total_wavs - succeed_wavs} pending to save), Succeed wavs: {succeed_wavs}, Failed wavs: {failed_wavs}, Estimated total duration: {estimated_total_duration:.2f}s 
({estimated_total_duration / 3600:.2f} h), Elapsed time: {elapsed_time:.2f}s") # noqa + last_print_time = current_time + except StopIteration: + break + except Exception as e: + failed_wavs += 1 + timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3] + tqdm.write(f"[{timestamp}] - [ERROR]: Error in main loop: {e}") + continue + + total_time = time.time() - start_time + + timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3] + tqdm.write(f"[{timestamp}] - [INFO] - Waiting for {len(pending_futures)} pending save tasks to complete...") + + for future in pending_futures: + try: + duration = future.result(timeout=60) + succeed_duration += duration + succeed_wavs += 1 + except Exception as e: + failed_wavs += 1 + timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3] + tqdm.write(f"[{timestamp}] - [ERROR]: Error in final async save task: {e}") + executor.shutdown(wait=True) + + timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3] + tqdm.write(f"[{timestamp}] - [INFO]: All async save tasks completed.") + progress_bar.close() + + timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3] + tqdm.write(f"[{timestamp}] - [INFO]: Final Report - Succeed wavs: {succeed_wavs}, Failed wavs: {failed_wavs}, Total duration: {succeed_duration:.2f}s ({succeed_duration / 3600:.2f} h).") # noqa + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/soulxpodcast/config.py b/soulxpodcast/config.py new file mode 100644 index 0000000000000000000000000000000000000000..8c58f0c1e81895d78f62f353dba62378800c807d --- /dev/null +++ b/soulxpodcast/config.py @@ -0,0 +1,141 @@ +import os +from dataclasses import dataclass, field, fields, is_dataclass, asdict +from typing import Any, Dict, List, Optional +from pathlib import Path +import json + +import torch +from transformers import AutoConfig +from transformers import PretrainedConfig + +@dataclass +class SoulXPodcastLLMConfig: + architectures: list[str] = field(default_factory=lambda: ["Qwen3ForCausalLM"]) + attention_dropout: float = 0.0 + bos_token_id: int = 151643 + eos_token_id: int = 151675 # speech eos + hidden_act: str = "silu" + hidden_size: int = 2048 + initializer_range: float = 0.02 + intermediate_size: int = 6144 + max_position_embeddings: int = 40960 + max_window_layers: int = 28 + model_type: str = "qwen3" + num_attention_heads: int = 16 + num_hidden_layers: int = 28 + num_key_value_heads: int = 8 + head_dim: int = 128 + rms_norm_eps: float = 1e-06 + rope_scaling: dict | None = None + rope_theta: float = 1000000.0 + sliding_window: int = 32768 + tie_word_embeddings: bool = True + torch_dtype: str = "bfloat16" + transformers_version: str = "4.52.3" + use_cache: bool = True + use_sliding_window: bool = False + vocab_size: int = 159488 # text_vocab_size + speech_vocab_size + 2 (eos and task_id) + lm_head_bias: bool = False + qkv_bias: bool = False + fp16_flow: bool = False + speech_token_offset: int = 152927 + + @classmethod + def from_initial_and_json( + cls, + initial_values: Dict[str, Any] = None, + json_file: Optional[str] = None + ): + """ + Create instance from initial values and JSON data + + Args: + initial_values: Initial key-value dict, which will overrides all other configurations + json_file: JSON file path + + Returns: + SoulXPodcastLLMConfig instance + """ + # Merge all data sources + merged_data = {} + + # 1. 
Load from JSON file first (lowest priority) + if json_file and os.path.exists(json_file): + file_data = cls._load_json_file(json_file) + merged_data.update(file_data) + + # 2. Override with initial values last (highest priority) + if initial_values: + merged_data.update(initial_values) + + # 3. Extract dataclass fields + valid_fields = {f.name for f in fields(cls)} + init_data = {k: v for k, v in merged_data.items() if k in valid_fields} + + return cls(**init_data) + + @staticmethod + def _load_json_file(file_path: str) -> Dict[str, Any]: + """从JSON文件加载数据""" + path = Path(file_path) + if not path.exists(): + return {} + with open(path, 'r', encoding='utf-8') as f: + return json.load(f) + +class AutoPretrainedConfig(PretrainedConfig): + model_type = "qwen3" + + def __init__(self, **kwargs): + # Remove non-configuration parameters + config_kwargs = {k: v for k, v in kwargs.items() + if not k.startswith('_') and k != 'self'} + super().__init__(**config_kwargs) + + @classmethod + def from_dataclass(cls, dataclass_config): + """Dynamically generate config from dataclass""" + if not is_dataclass(dataclass_config): + raise ValueError("Input must be a dataclass instance") + + dataclass_dict = asdict(dataclass_config) + return cls(**dataclass_dict) + + +@dataclass +class SamplingParams: + temperature: float = 0.6 + repetition_penalty: float = 1.25 + top_k: int = 100 + top_p: float = 0.9 + max_tokens: int = 3000 + min_tokens: int = 8 + stop_token_ids: list[int] = field(default_factory=lambda: [151675]) + # RasSampler parameters + use_ras: bool = True + win_size: int = 25 + tau_r: float = 0.2 + + +@dataclass +class Config: + model: str + max_model_len: int = 8192 # 15s prompt + 30s generated audio for 25hz audio tokenizer + gpu_memory_utilization: float = 0.9 + tensor_parallel_size: int = 1 + enforce_eager: bool = False + hf_config: SoulXPodcastLLMConfig | AutoConfig = field(default_factory=SoulXPodcastLLMConfig) + eos: int = -1 + llm_engine: str = "hf" # support hf, nano-vllm + max_turn_size: int = 14 + turn_tokens_threshold: int = 6192 + + prompt_context: int = 2 # default to 2 for two-speaker podcast; + history_context: int = 4 + history_text_context: int = 4 + + def __post_init__(self): + assert os.path.isdir(self.model) + + max_pos = getattr(self.hf_config, "max_position_embeddings", 8192) + self.max_model_len = min(self.max_model_len, max_pos) \ No newline at end of file diff --git a/soulxpodcast/engine/__init__.py b/soulxpodcast/engine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/soulxpodcast/engine/__pycache__/__init__.cpython-311.pyc b/soulxpodcast/engine/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14d644b95fbbd21fac163e1ad120da2056264cef Binary files /dev/null and b/soulxpodcast/engine/__pycache__/__init__.cpython-311.pyc differ diff --git a/soulxpodcast/engine/__pycache__/__init__.cpython-312.pyc b/soulxpodcast/engine/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1d9488d4bac5730f0e8df808dd05226077086e6 Binary files /dev/null and b/soulxpodcast/engine/__pycache__/__init__.cpython-312.pyc differ diff --git a/soulxpodcast/engine/__pycache__/llm_engine.cpython-311.pyc b/soulxpodcast/engine/__pycache__/llm_engine.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fab081725de6b51108f94538c4fcdfaacb07e5a5 Binary files /dev/null and 
b/soulxpodcast/engine/__pycache__/llm_engine.cpython-311.pyc differ diff --git a/soulxpodcast/engine/__pycache__/llm_engine.cpython-312.pyc b/soulxpodcast/engine/__pycache__/llm_engine.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..475fe94c5381449730c03c3d65dece6b183075a1 Binary files /dev/null and b/soulxpodcast/engine/__pycache__/llm_engine.cpython-312.pyc differ diff --git a/soulxpodcast/engine/llm_engine.py b/soulxpodcast/engine/llm_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..5b6dad9a65d98709933ad6fab611325580d6a48a --- /dev/null +++ b/soulxpodcast/engine/llm_engine.py @@ -0,0 +1,116 @@ +import types +import atexit +from dataclasses import fields, asdict +from time import perf_counter +import os +from functools import partial + +import torch +import torch.multiprocessing as mp +from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteriaList +from transformers import EosTokenCriteria, RepetitionPenaltyLogitsProcessor +try: + from vllm import LLM + from vllm import SamplingParams as VllmSamplingParams + from vllm.inputs import TokensPrompt as TokensPrompt + SUPPORT_VLLM = True +except ImportError: + SUPPORT_VLLM = False + +from soulxpodcast.models.modules.sampler import _ras_sample_hf_engine +from soulxpodcast.config import Config, SamplingParams + +class HFLLMEngine: + + def __init__(self, model, **kwargs): + config_fields = {field.name for field in fields(Config)} + config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields} + config = Config(model, **config_kwargs) + + self.tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True) + config.eos = config.hf_config.eos_token_id # speech eos token; + self.device = "cuda:0" if torch.cuda.is_available() else "cpu" + self.model = AutoModelForCausalLM.from_pretrained(model, torch_dtype=torch.bfloat16, device_map=self.device) + self.config = config + self.pad_token_id = self.tokenizer.pad_token_id + + def generate( + self, + prompt: list[str], + sampling_param: SamplingParams, + past_key_values=None, + ) -> dict: + + # Recreate stopping_criteria per request for thread safety + stopping_criteria = StoppingCriteriaList([EosTokenCriteria(eos_token_id=self.config.hf_config.eos_token_id)]) + if sampling_param.use_ras: + sample_hf_engine_handler = partial(_ras_sample_hf_engine, + use_ras=sampling_param.use_ras, + win_size=sampling_param.win_size, tau_r=sampling_param.tau_r) + else: + sample_hf_engine_handler = None + rep_pen_processor = RepetitionPenaltyLogitsProcessor( + penalty=sampling_param.repetition_penalty, + prompt_ignore_length=len(prompt) + ) # exclude the input prompt, consistent with vLLM implementation; + with torch.no_grad(): # Avoids gradient computation with no_grad + input_len = len(prompt) + generated_ids = self.model.generate( + input_ids = torch.tensor([prompt], dtype=torch.int64).to(self.device), + do_sample=True, + top_k=sampling_param.top_k, + top_p=sampling_param.top_p, + min_new_tokens=sampling_param.min_tokens, + max_new_tokens=sampling_param.max_tokens, + temperature=sampling_param.temperature, + repetition_penalty=sampling_param.repetition_penalty, + stopping_criteria=stopping_criteria, + past_key_values=past_key_values, + custom_generate=sample_hf_engine_handler, + use_cache=True, + logits_processor=[rep_pen_processor] + ) + generated_ids = generated_ids[:, input_len:].cpu().numpy().tolist()[0] + output = { + "text": self.tokenizer.decode(generated_ids), + "token_ids": generated_ids, + } + return output + +class 
VLLMEngine: + + def __init__(self, model, **kwargs): + + config_fields = {field.name for field in fields(Config)} + config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields} + config = Config(model, **config_kwargs) + + self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True) + config.eos = config.hf_config.eos_token_id # speech eos token; + self.device = "cuda:0" if torch.cuda.is_available() else "cpu" + os.environ["VLLM_USE_V1"] = "0" + if SUPPORT_VLLM: + self.model = LLM(model=model, enforce_eager=True, dtype="bfloat16", max_model_len=8192, enable_prefix_caching=True) + else: + raise ImportError("Not Support VLLM now!!!") + self.config = config + self.pad_token_id = self.tokenizer.pad_token_id + + def generate( + self, + prompt: list[str], + sampling_param: SamplingParams, + past_key_values=None, + ) -> dict: + sampling_param.stop_token_ids = [self.config.hf_config.eos_token_id] + with torch.no_grad(): # Avoids gradient computation with no_grad + generated_ids = self.model.generate( + TokensPrompt(prompt_token_ids=prompt), + VllmSamplingParams(**asdict(sampling_param)), + use_tqdm=False, + )[0].outputs[0].token_ids + output = { + "text": self.tokenizer.decode(generated_ids), + "token_ids": list(generated_ids), + } + return output \ No newline at end of file diff --git a/soulxpodcast/models/__pycache__/soulxpodcast.cpython-311.pyc b/soulxpodcast/models/__pycache__/soulxpodcast.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4208f5aa82555e1fc925d7e03f0c3a5c3fd0bda Binary files /dev/null and b/soulxpodcast/models/__pycache__/soulxpodcast.cpython-311.pyc differ diff --git a/soulxpodcast/models/__pycache__/soulxpodcast.cpython-312.pyc b/soulxpodcast/models/__pycache__/soulxpodcast.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a26efa784329dab9a79b09c1bbdda5ce23298ed2 Binary files /dev/null and b/soulxpodcast/models/__pycache__/soulxpodcast.cpython-312.pyc differ diff --git a/soulxpodcast/models/modules/__init__.py b/soulxpodcast/models/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/soulxpodcast/models/modules/__pycache__/__init__.cpython-311.pyc b/soulxpodcast/models/modules/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea133e7078e6b50482b521bbe4fa83e7dd99d6ab Binary files /dev/null and b/soulxpodcast/models/modules/__pycache__/__init__.cpython-311.pyc differ diff --git a/soulxpodcast/models/modules/__pycache__/__init__.cpython-312.pyc b/soulxpodcast/models/modules/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39214f38dc91a480a304bca47c7d038dd27e3f0e Binary files /dev/null and b/soulxpodcast/models/modules/__pycache__/__init__.cpython-312.pyc differ diff --git a/soulxpodcast/models/modules/__pycache__/flow.cpython-311.pyc b/soulxpodcast/models/modules/__pycache__/flow.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..037adef8890fb9edfe413bd6409590aa1bed051c Binary files /dev/null and b/soulxpodcast/models/modules/__pycache__/flow.cpython-311.pyc differ diff --git a/soulxpodcast/models/modules/__pycache__/flow.cpython-312.pyc b/soulxpodcast/models/modules/__pycache__/flow.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c48995602281f9df68d016834a171ad7fa955e6 Binary files /dev/null and 
b/soulxpodcast/models/modules/__pycache__/flow.cpython-312.pyc differ diff --git a/soulxpodcast/models/modules/__pycache__/hifigan.cpython-311.pyc b/soulxpodcast/models/modules/__pycache__/hifigan.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6589ac194ad7d4efa82abff4d83ce3091cd3dc0 Binary files /dev/null and b/soulxpodcast/models/modules/__pycache__/hifigan.cpython-311.pyc differ diff --git a/soulxpodcast/models/modules/__pycache__/hifigan.cpython-312.pyc b/soulxpodcast/models/modules/__pycache__/hifigan.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..764d6d1925de1359d41f4a9c18092bc61c4eb8e7 Binary files /dev/null and b/soulxpodcast/models/modules/__pycache__/hifigan.cpython-312.pyc differ diff --git a/soulxpodcast/models/modules/__pycache__/sampler.cpython-311.pyc b/soulxpodcast/models/modules/__pycache__/sampler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..691c2a7db18d963cf9c90b76537e6ea5bb2ac1c3 Binary files /dev/null and b/soulxpodcast/models/modules/__pycache__/sampler.cpython-311.pyc differ diff --git a/soulxpodcast/models/modules/__pycache__/sampler.cpython-312.pyc b/soulxpodcast/models/modules/__pycache__/sampler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c15ffa6a1f4e61eff137b850e9a3a194353729c2 Binary files /dev/null and b/soulxpodcast/models/modules/__pycache__/sampler.cpython-312.pyc differ diff --git a/soulxpodcast/models/modules/flow.py b/soulxpodcast/models/modules/flow.py new file mode 100644 index 0000000000000000000000000000000000000000..d54397cfe974e21bddeeb40bf3521da17527a040 --- /dev/null +++ b/soulxpodcast/models/modules/flow.py @@ -0,0 +1,197 @@ +from dataclasses import dataclass + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from soulxpodcast.models.modules.flow_components.estimator import \ + CausalConditionalDecoder +from soulxpodcast.models.modules.flow_components.upsample_encoder import ( + UpsampleConformerEncoder, make_pad_mask) + + +@dataclass +class CfmParams: + sigma_min: float = 1e-6 + solver: str = "euler" + t_scheduler: str = "cosine" + training_cfg_rate: float = 0.2 + inference_cfg_rate: float = 0.7 + + +class CausalConditionalCFM(torch.nn.Module): + def __init__(self, in_channels=320, cfm_params=CfmParams(), n_spks=1, spk_emb_dim=80, estimator: torch.nn.Module = None): + super().__init__() + self.n_feats = in_channels + self.n_spks = n_spks + self.spk_emb_dim = spk_emb_dim + self.solver = cfm_params.solver + if hasattr(cfm_params, "sigma_min"): + self.sigma_min = cfm_params.sigma_min + else: + self.sigma_min = 1e-4 + self.t_scheduler = cfm_params.t_scheduler + self.training_cfg_rate = cfm_params.training_cfg_rate + self.inference_cfg_rate = cfm_params.inference_cfg_rate + in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0) + # Just change the architecture of the estimator here + self.estimator = CausalConditionalDecoder() if estimator is None else estimator + + @torch.inference_mode() + def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, streaming=False): + """Forward diffusion + + Args: + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): output_mask + shape: (batch_size, 1, mel_timesteps) + n_timesteps (int): number of diffusion steps + temperature (float, optional): temperature for scaling noise. Defaults to 1.0. + spks (torch.Tensor, optional): speaker ids. 
Defaults to None. + shape: (batch_size, spk_emb_dim) + cond: Not used but kept for future purposes + + Returns: + sample: generated mel-spectrogram + shape: (batch_size, n_feats, mel_timesteps) + """ + z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature + # fix prompt and overlap part mu and z + t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype) + if self.t_scheduler == 'cosine': + t_span = 1 - torch.cos(t_span * 0.5 * torch.pi) + return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, streaming=streaming), None + + def solve_euler(self, x, t_span, mu, mask, spks, cond, streaming=False): + """ + Fixed euler solver for ODEs. + Args: + x (torch.Tensor): random noise + t_span (torch.Tensor): n_timesteps interpolated + shape: (n_timesteps + 1,) + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): output_mask + shape: (batch_size, 1, mel_timesteps) + spks (torch.Tensor, optional): speaker ids. Defaults to None. + shape: (batch_size, spk_emb_dim) + cond: Not used but kept for future purposes + """ + batch_size = x.size(0) + t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] + + # I am storing this because I can later plot it by putting a debugger here and saving it to a file + # Or in future might add like a return_all_steps flag + sol = [] + + # Do not use concat, it may cause memory format changed and trt infer with wrong results! + # Create tensors with double batch size for CFG (conditional + unconditional) + x_in = torch.zeros([batch_size * 2, x.size(1), x.size(2)], device=x.device, dtype=x.dtype) + mask_in = torch.zeros([batch_size * 2, mask.size(1), mask.size(2)], device=x.device, dtype=x.dtype) + mu_in = torch.zeros([batch_size * 2, mu.size(1), mu.size(2)], device=x.device, dtype=x.dtype) + t_in = torch.zeros([batch_size * 2], device=x.device, dtype=x.dtype) + spks_in = torch.zeros([batch_size * 2, spks.size(1)], device=x.device, dtype=x.dtype) + cond_in = torch.zeros([batch_size * 2, cond.size(1), cond.size(2)], device=x.device, dtype=x.dtype) + + for step in range(1, len(t_span)): + # Classifier-Free Guidance inference introduced in VoiceBox + # Copy conditional and unconditional input + x_in[:batch_size] = x + x_in[batch_size:] = x + mask_in[:batch_size] = mask + mask_in[batch_size:] = mask + mu_in[:batch_size] = mu + # Unconditional part remains 0 + t_in.fill_(t) + spks_in[:batch_size] = spks + cond_in[:batch_size] = cond + + dphi_dt = self.estimator( + x_in, mask_in, + mu_in, t_in, + spks_in, + cond_in, + streaming + ) + dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [batch_size, batch_size], dim=0) + dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt) + x = x + dt * dphi_dt + t = t + dt + sol.append(x) + if step < len(t_span) - 1: + dt = t_span[step + 1] - t + + return sol[-1].float() + + +class CausalMaskedDiffWithXvec(torch.nn.Module): + def __init__( + self, + input_size: int = 512, + output_size: int = 80, + spk_embed_dim: int = 192, + output_type: str = "mel", + vocab_size: int = 6561, + input_frame_rate: int = 25, + token_mel_ratio: int = 2, + pre_lookahead_len: int = 3, + encoder: torch.nn.Module = None, + decoder: torch.nn.Module = None, + ): + super().__init__() + self.input_size = input_size + self.output_size = output_size + self.vocab_size = vocab_size + self.output_type = output_type + self.input_frame_rate = input_frame_rate + self.input_embedding = nn.Embedding(vocab_size, input_size) + 
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size) + self.encoder = UpsampleConformerEncoder() if encoder is None else encoder + self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size) + self.decoder = CausalConditionalCFM() if decoder is None else decoder + self.token_mel_ratio = token_mel_ratio + self.pre_lookahead_len = pre_lookahead_len + + @torch.inference_mode() + def forward(self, + token, + token_len, + prompt_feat, + prompt_feat_len, + embedding, + streaming, + finalize): + # xvec projection + embedding = F.normalize(embedding, dim=1) + embedding = self.spk_embed_affine_layer(embedding) + + # concat text and prompt_text + mask = (~make_pad_mask(token_len, max_len=token.shape[1])).unsqueeze(-1).to(embedding) + token = self.input_embedding(torch.clamp(token, min=0)) * mask + + # text encode + if finalize is True: + h, h_lengths = self.encoder(token, token_len, streaming=streaming) + else: + token, context = token[:, :-self.pre_lookahead_len], token[:, -self.pre_lookahead_len:] + h, h_lengths = self.encoder(token, token_len, context=context, streaming=streaming) + h = self.encoder_proj(h) + + # get conditions + conds = torch.zeros_like(h, device=token.device) + for i, j in enumerate(prompt_feat_len): + conds[i, :j] = prompt_feat[i, :j] + conds = conds.transpose(1, 2) + + h_lengths = h_lengths.sum(dim=-1).squeeze(dim=1) + mask = (~make_pad_mask(h_lengths, max_len=h.shape[1])).to(h) + feat, _ = self.decoder( + mu=h.transpose(1, 2).contiguous(), + mask=mask.unsqueeze(1), + spks=embedding, + cond=conds, + n_timesteps=15, + streaming=streaming + ) # [B, num_mels, T] + return feat.float(), h_lengths diff --git a/soulxpodcast/models/modules/flow_components/__init__.py b/soulxpodcast/models/modules/flow_components/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/soulxpodcast/models/modules/flow_components/__pycache__/__init__.cpython-311.pyc b/soulxpodcast/models/modules/flow_components/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3da688e17b3ff363d626b553df09e8423882f2f Binary files /dev/null and b/soulxpodcast/models/modules/flow_components/__pycache__/__init__.cpython-311.pyc differ diff --git a/soulxpodcast/models/modules/flow_components/__pycache__/__init__.cpython-312.pyc b/soulxpodcast/models/modules/flow_components/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90c33cb43afdf7830bbf76cee7caf61295a24d4a Binary files /dev/null and b/soulxpodcast/models/modules/flow_components/__pycache__/__init__.cpython-312.pyc differ diff --git a/soulxpodcast/models/modules/flow_components/__pycache__/estimator.cpython-311.pyc b/soulxpodcast/models/modules/flow_components/__pycache__/estimator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5fc5fd0d9c9436d0323a68f9553f04f318006b53 Binary files /dev/null and b/soulxpodcast/models/modules/flow_components/__pycache__/estimator.cpython-311.pyc differ diff --git a/soulxpodcast/models/modules/flow_components/__pycache__/estimator.cpython-312.pyc b/soulxpodcast/models/modules/flow_components/__pycache__/estimator.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..839894013cee60867ff93f8f6a70a5cf0dbd257a Binary files /dev/null and b/soulxpodcast/models/modules/flow_components/__pycache__/estimator.cpython-312.pyc differ diff --git 
a/soulxpodcast/models/modules/flow_components/__pycache__/upsample_encoder.cpython-311.pyc b/soulxpodcast/models/modules/flow_components/__pycache__/upsample_encoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6bd7a6b997804cbd355d2fe4d1ee4e6ce7be45d2 Binary files /dev/null and b/soulxpodcast/models/modules/flow_components/__pycache__/upsample_encoder.cpython-311.pyc differ diff --git a/soulxpodcast/models/modules/flow_components/__pycache__/upsample_encoder.cpython-312.pyc b/soulxpodcast/models/modules/flow_components/__pycache__/upsample_encoder.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1499bac84884f8b07b7dbe15e26d7f8d9781786 Binary files /dev/null and b/soulxpodcast/models/modules/flow_components/__pycache__/upsample_encoder.cpython-312.pyc differ diff --git a/soulxpodcast/models/modules/flow_components/estimator.py b/soulxpodcast/models/modules/flow_components/estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..0c965a8e678219833f0646796e21fce7e604cb61 --- /dev/null +++ b/soulxpodcast/models/modules/flow_components/estimator.py @@ -0,0 +1,974 @@ +import math +from typing import Any, Dict, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers.models.attention import (GEGLU, GELU, AdaLayerNorm, + AdaLayerNormZero, ApproximateGELU) +from diffusers.models.attention_processor import Attention +from diffusers.models.lora import LoRACompatibleLinear +from diffusers.utils.torch_utils import maybe_allow_in_graph +from einops import pack, rearrange, repeat + +from soulxpodcast.models.modules.flow_components.upsample_encoder import \ + add_optional_chunk_mask + + +def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + assert mask.dtype == torch.bool + assert dtype in [torch.float32, torch.bfloat16, torch.float16] + mask = mask.to(dtype) + # attention mask bias + # NOTE(Mddct): torch.finfo jit issues + # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min + mask = (1.0 - mask) * -1.0e+10 + return mask + + +class SnakeBeta(nn.Module): + """ + A modified Snake function which uses separate parameters for the magnitude of the periodic components + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + References: + - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snakebeta(256) + >>> x = torch.randn(256) + >>> x = a1(x) + + Args: + in_features: shape of the input + out_features: shape of the output + alpha: trainable parameter that controls frequency + alpha_trainable: whether alpha is trainable + alpha_logscale: whether to use log scale for alpha + alpha is initialized to 1 by default, higher values = higher-frequency. + beta is initialized to 1 by default, higher values = higher-magnitude. + alpha will be trained along with the rest of your model. 
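+
+    A minimal shape sketch for this two-argument variant (the sizes below are
+    arbitrary assumptions; note that, unlike the snippet above, this layer
+    first projects from in_features to out_features and then applies the
+    snake nonlinearity):
+        >>> layer = SnakeBeta(256, 1024)
+        >>> y = layer(torch.randn(8, 10, 256))
+        >>> y.shape
+        torch.Size([8, 10, 1024])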
+ """ + + def __init__(self, in_features, out_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True): + super().__init__() + self.in_features = out_features if isinstance(out_features, list) else [out_features] + self.proj = LoRACompatibleLinear(in_features, out_features) + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = nn.Parameter(torch.zeros(self.in_features) * alpha) + self.beta = nn.Parameter(torch.zeros(self.in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = nn.Parameter(torch.ones(self.in_features) * alpha) + self.beta = nn.Parameter(torch.ones(self.in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + """ + Forward pass of the function. + Applies the function to the input elementwise. + SnakeBeta ∶= x + 1/b * sin^2 (xa) + """ + x = self.proj(x) + if self.alpha_logscale: + alpha = torch.exp(self.alpha) + beta = torch.exp(self.beta) + else: + alpha = self.alpha + beta = self.beta + + x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2) + + return x + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + ): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim) + if activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh") + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim) + elif activation_fn == "snakebeta": + act_fn = SnakeBeta(dim, inner_dim) + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(LoRACompatibleLinear(inner_dim, dim_out)) + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states): + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + +@maybe_allow_in_graph +class BasicTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. 
+ only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + final_dropout: bool = False, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_zero: + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + ) + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + # scale_qk=False, # uncomment this to not to use flash attention + ) # is self-attn if encoder_hidden_states is none + else: + self.norm2 = None + self.attn2 = None + + # 3. 
Feed-forward + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + ): + # Notice that normalization is always applied before the real computation in the following blocks. + # 1. Self-Attention + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + else: + norm_hidden_states = self.norm1(hidden_states) + + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask, + **cross_attention_kwargs, + ) + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = attn_output + hidden_states + + # 2. Cross-Attention + if self.attn2 is not None: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 3. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: + raise ValueError( + f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." 
+ ) + + num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size + ff_output = torch.cat( + [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)], + dim=self._chunk_dim, + ) + else: + ff_output = self.ff(norm_hidden_states) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = ff_output + hidden_states + + return hidden_states + + +class SinusoidalPosEmb(torch.nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even" + + def forward(self, x, scale=1000): + if x.ndim < 1: + x = x.unsqueeze(0) + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) + emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +class Block1D(torch.nn.Module): + def __init__(self, dim, dim_out, groups=8): + super().__init__() + self.block = torch.nn.Sequential( + torch.nn.Conv1d(dim, dim_out, 3, padding=1), + torch.nn.GroupNorm(groups, dim_out), + nn.Mish(), + ) + + def forward(self, x, mask): + output = self.block(x * mask) + return output * mask + + +class ResnetBlock1D(torch.nn.Module): + def __init__(self, dim, dim_out, time_emb_dim, groups=8): + super().__init__() + self.mlp = torch.nn.Sequential(nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out)) + + self.block1 = Block1D(dim, dim_out, groups=groups) + self.block2 = Block1D(dim_out, dim_out, groups=groups) + + self.res_conv = torch.nn.Conv1d(dim, dim_out, 1) + + def forward(self, x, mask, time_emb): + h = self.block1(x, mask) + h += self.mlp(time_emb).unsqueeze(-1) + h = self.block2(h, mask) + output = h + self.res_conv(x * mask) + return output + + +class Downsample1D(nn.Module): + def __init__(self, dim): + super().__init__() + self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1) + + def forward(self, x): + return self.conv(x) + + +class TimestepEmbedding(nn.Module): + def __init__( + self, + in_channels: int, + time_embed_dim: int, + act_fn: str = "silu", + out_dim: int = None, + post_act_fn: Optional[str] = None, + cond_proj_dim=None, + ): + super().__init__() + assert act_fn == "silu", "act_fn must be silu" + + self.linear_1 = nn.Linear(in_channels, time_embed_dim) + + if cond_proj_dim is not None: + self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) + else: + self.cond_proj = None + + self.act = nn.SiLU() + + if out_dim is not None: + time_embed_dim_out = out_dim + else: + time_embed_dim_out = time_embed_dim + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out) + + if post_act_fn is None: + self.post_act = None + else: + self.post_act = nn.SiLU() + + def forward(self, sample, condition=None): + if condition is not None: + sample = sample + self.cond_proj(condition) + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + + sample = self.linear_2(sample) + + if self.post_act is not None: + sample = self.post_act(sample) + return sample + + +class Upsample1D(nn.Module): + """A 1D upsampling layer with an optional convolution. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + use_conv_transpose (`bool`, default `False`): + option to use a convolution transpose. + out_channels (`int`, optional): + number of output channels. 
Defaults to `channels`. + """ + + def __init__(self, channels, use_conv=False, use_conv_transpose=True, out_channels=None, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + + self.conv = None + if use_conv_transpose: + self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1) + elif use_conv: + self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1) + + def forward(self, inputs): + assert inputs.shape[1] == self.channels + if self.use_conv_transpose: + return self.conv(inputs) + + outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest") + + if self.use_conv: + outputs = self.conv(outputs) + + return outputs + + +class Transpose(torch.nn.Module): + def __init__(self, dim0: int, dim1: int): + super().__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = torch.transpose(x, self.dim0, self.dim1) + return x + + +class CausalConv1d(torch.nn.Conv1d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros', + device=None, + dtype=None + ) -> None: + super(CausalConv1d, self).__init__(in_channels, out_channels, + kernel_size, stride, + padding=0, dilation=dilation, + groups=groups, bias=bias, + padding_mode=padding_mode, + device=device, dtype=dtype) + assert stride == 1 + self.causal_padding = kernel_size - 1 + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = F.pad(x, (self.causal_padding, 0), value=0.0) + x = super(CausalConv1d, self).forward(x) + return x + + +class CausalBlock1D(Block1D): + def __init__(self, dim: int, dim_out: int): + super(CausalBlock1D, self).__init__(dim, dim_out) + self.block = torch.nn.Sequential( + CausalConv1d(dim, dim_out, 3), + Transpose(1, 2), + nn.LayerNorm(dim_out), + Transpose(1, 2), + nn.Mish(), + ) + + def forward(self, x: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + output = self.block(x * mask) + return output * mask + + +class CausalResnetBlock1D(ResnetBlock1D): + def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8): + super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups) + self.block1 = CausalBlock1D(dim, dim_out) + self.block2 = CausalBlock1D(dim_out, dim_out) + + +class ConditionalDecoder(nn.Module): + """ + This decoder requires an input with the same shape of the target. So, if your text content + is shorter or longer than the outputs, please re-sampling it before feeding to the decoder. 
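+
+    A rough usage sketch (the sizes are assumptions chosen so that x, mu,
+    spks and cond, 80 channels each, pack into in_channels=320; act_fn is
+    passed explicitly because the transformer feed-forward only implements
+    the gelu/geglu/snakebeta variants):
+        >>> dec = ConditionalDecoder(in_channels=320, out_channels=80,
+        ...                          channels=(256, 256), act_fn="gelu")
+        >>> x = torch.randn(2, 80, 100); mu = torch.randn(2, 80, 100)
+        >>> spks = torch.randn(2, 80); cond = torch.randn(2, 80, 100)
+        >>> mask = torch.ones(2, 1, 100)
+        >>> out = dec(x, mask, mu, torch.rand(2), spks, cond)
+        >>> out.shape
+        torch.Size([2, 80, 100])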
+ + Args: + in_channels: number of input channels + out_channels: number of output channels + channels: tuple of channel dimensions + dropout: dropout rate + attention_head_dim: dimension of attention heads + n_blocks: number of transformer blocks + num_mid_blocks: number of middle blocks + num_heads: number of attention heads + act_fn: activation function name + """ + + def __init__( + self, + in_channels, + out_channels, + channels=(256, 256), + dropout=0.05, + attention_head_dim=64, + n_blocks=1, + num_mid_blocks=2, + num_heads=4, + act_fn="snake", + ): + super().__init__() + channels = tuple(channels) + self.in_channels = in_channels + self.out_channels = out_channels + + self.time_embeddings = SinusoidalPosEmb(in_channels) + time_embed_dim = channels[0] * 4 + self.time_mlp = TimestepEmbedding( + in_channels=in_channels, + time_embed_dim=time_embed_dim, + act_fn="silu", + ) + self.down_blocks = nn.ModuleList([]) + self.mid_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + output_channel = in_channels + for i in range(len(channels)): # pylint: disable=consider-using-enumerate + input_channel = output_channel + output_channel = channels[i] + is_last = i == len(channels) - 1 + resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) + transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + dim=output_channel, + num_attention_heads=num_heads, + attention_head_dim=attention_head_dim, + dropout=dropout, + activation_fn=act_fn, + ) + for _ in range(n_blocks) + ] + ) + downsample = ( + Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1) + ) + self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample])) + + for _ in range(num_mid_blocks): + input_channel = channels[-1] + out_channels = channels[-1] + resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) + + transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + dim=output_channel, + num_attention_heads=num_heads, + attention_head_dim=attention_head_dim, + dropout=dropout, + activation_fn=act_fn, + ) + for _ in range(n_blocks) + ] + ) + + self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks])) + + channels = channels[::-1] + (channels[0],) + for i in range(len(channels) - 1): + input_channel = channels[i] * 2 + output_channel = channels[i + 1] + is_last = i == len(channels) - 2 + resnet = ResnetBlock1D( + dim=input_channel, + dim_out=output_channel, + time_emb_dim=time_embed_dim, + ) + transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + dim=output_channel, + num_attention_heads=num_heads, + attention_head_dim=attention_head_dim, + dropout=dropout, + activation_fn=act_fn, + ) + for _ in range(n_blocks) + ] + ) + upsample = ( + Upsample1D(output_channel, use_conv_transpose=True) + if not is_last + else nn.Conv1d(output_channel, output_channel, 3, padding=1) + ) + self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample])) + self.final_block = Block1D(channels[-1], channels[-1]) + self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1) + self.initialize_weights() + + def initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv1d): + nn.init.kaiming_normal_(m.weight, nonlinearity="relu") + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.GroupNorm): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + 
nn.init.kaiming_normal_(m.weight, nonlinearity="relu") + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False): + """Forward pass of the UNet1DConditional model. + + Args: + x (torch.Tensor): shape (batch_size, in_channels, time) + mask (_type_): shape (batch_size, 1, time) + t (_type_): shape (batch_size) + spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None. + cond (_type_, optional): placeholder for future use. Defaults to None. + + Raises: + ValueError: _description_ + ValueError: _description_ + + Returns: + _type_: _description_ + """ + + t = self.time_embeddings(t).to(t.dtype) + t = self.time_mlp(t) + + x = pack([x, mu], "b * t")[0] + + if spks is not None: + spks = repeat(spks, "b c -> b c t", t=x.shape[-1]) + x = pack([x, spks], "b * t")[0] + if cond is not None: + x = pack([x, cond], "b * t")[0] + + hiddens = [] + masks = [mask] + for resnet, transformer_blocks, downsample in self.down_blocks: + mask_down = masks[-1] + x = resnet(x, mask_down, t) + x = rearrange(x, "b c t -> b t c").contiguous() + attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1) + attn_mask = mask_to_bias(attn_mask, x.dtype) + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, + attention_mask=attn_mask, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t").contiguous() + hiddens.append(x) # Save hidden states for skip connections + x = downsample(x * mask_down) + masks.append(mask_down[:, :, ::2]) + masks = masks[:-1] + mask_mid = masks[-1] + + for resnet, transformer_blocks in self.mid_blocks: + x = resnet(x, mask_mid, t) + x = rearrange(x, "b c t -> b t c").contiguous() + attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1) + attn_mask = mask_to_bias(attn_mask, x.dtype) + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, + attention_mask=attn_mask, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t").contiguous() + + for resnet, transformer_blocks, upsample in self.up_blocks: + mask_up = masks.pop() + skip = hiddens.pop() + x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0] + x = resnet(x, mask_up, t) + x = rearrange(x, "b c t -> b t c").contiguous() + attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1) + attn_mask = mask_to_bias(attn_mask, x.dtype) + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, + attention_mask=attn_mask, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t").contiguous() + x = upsample(x * mask_up) + x = self.final_block(x, mask_up) + output = self.final_proj(x * mask_up) + return output * mask + + +class CausalConditionalDecoder(ConditionalDecoder): + """ + This decoder requires an input with the same shape of the target. So, if your text content + is shorter or longer than the outputs, please re-sampling it before feeding to the decoder. 
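+
+    A lightweight smoke-test sketch (the block counts are reduced from the
+    defaults and the input sizes are arbitrary assumptions, just to keep the
+    example small; streaming=True exercises the chunk-wise causal attention
+    mask driven by static_chunk_size):
+        >>> dec = CausalConditionalDecoder(channels=[256], n_blocks=1,
+        ...                                num_mid_blocks=1)
+        >>> x = torch.randn(1, 80, 50); mu = torch.randn(1, 80, 50)
+        >>> spks = torch.randn(1, 80); cond = torch.randn(1, 80, 50)
+        >>> mask = torch.ones(1, 1, 50)
+        >>> out = dec(x, mask, mu, torch.rand(1), spks, cond, streaming=True)
+        >>> out.shape
+        torch.Size([1, 80, 50])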
+ + Args: + in_channels: number of input channels + out_channels: number of output channels + channels: list of channel dimensions + dropout: dropout rate + attention_head_dim: dimension of attention heads + n_blocks: number of transformer blocks + num_mid_blocks: number of middle blocks + num_heads: number of attention heads + act_fn: activation function name + static_chunk_size: size of static chunks + num_decoding_left_chunks: number of left chunks for decoding + """ + + def __init__( + self, + in_channels=320, + out_channels=80, + channels=[256], # noqa + dropout=0.0, + attention_head_dim=64, + n_blocks=4, + num_mid_blocks=12, + num_heads=8, + act_fn="gelu", + static_chunk_size=50, + num_decoding_left_chunks=-1, + ): + torch.nn.Module.__init__(self) + channels = tuple(channels) + self.in_channels = in_channels + self.out_channels = out_channels + self.time_embeddings = SinusoidalPosEmb(in_channels) + time_embed_dim = channels[0] * 4 + self.time_mlp = TimestepEmbedding( + in_channels=in_channels, + time_embed_dim=time_embed_dim, + act_fn="silu", + ) + self.static_chunk_size = static_chunk_size + self.num_decoding_left_chunks = num_decoding_left_chunks + self.down_blocks = nn.ModuleList([]) + self.mid_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + output_channel = in_channels + for i in range(len(channels)): # pylint: disable=consider-using-enumerate + input_channel = output_channel + output_channel = channels[i] + is_last = i == len(channels) - 1 + resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) + transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + dim=output_channel, + num_attention_heads=num_heads, + attention_head_dim=attention_head_dim, + dropout=dropout, + activation_fn=act_fn, + ) + for _ in range(n_blocks) + ] + ) + downsample = ( + Downsample1D(output_channel) if not is_last else CausalConv1d(output_channel, output_channel, 3) + ) + self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample])) + + for _ in range(num_mid_blocks): + input_channel = channels[-1] + out_channels = channels[-1] + resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) + + transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + dim=output_channel, + num_attention_heads=num_heads, + attention_head_dim=attention_head_dim, + dropout=dropout, + activation_fn=act_fn, + ) + for _ in range(n_blocks) + ] + ) + + self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks])) + + channels = channels[::-1] + (channels[0],) + for i in range(len(channels) - 1): + input_channel = channels[i] * 2 + output_channel = channels[i + 1] + is_last = i == len(channels) - 2 + resnet = CausalResnetBlock1D( + dim=input_channel, + dim_out=output_channel, + time_emb_dim=time_embed_dim, + ) + transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + dim=output_channel, + num_attention_heads=num_heads, + attention_head_dim=attention_head_dim, + dropout=dropout, + activation_fn=act_fn, + ) + for _ in range(n_blocks) + ] + ) + upsample = ( + Upsample1D(output_channel, use_conv_transpose=True) + if not is_last + else CausalConv1d(output_channel, output_channel, 3) + ) + self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample])) + self.final_block = CausalBlock1D(channels[-1], channels[-1]) + self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1) + self.initialize_weights() + + def forward(self, x, mask, mu, t, spks=None, cond=None, 
streaming=False): + """Forward pass of the UNet1DConditional model. + + Args: + x (torch.Tensor): shape (batch_size, in_channels, time) + mask (_type_): shape (batch_size, 1, time) + t (_type_): shape (batch_size) + spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None. + cond (_type_, optional): placeholder for future use. Defaults to None. + + Raises: + ValueError: _description_ + ValueError: _description_ + + Returns: + _type_: _description_ + """ + t = self.time_embeddings(t).to(t.dtype) + t = self.time_mlp(t) + + x = pack([x, mu], "b * t")[0] + + if spks is not None: + spks = repeat(spks, "b c -> b c t", t=x.shape[-1]) + x = pack([x, spks], "b * t")[0] + if cond is not None: + x = pack([x, cond], "b * t")[0] + + hiddens = [] + masks = [mask] + for resnet, transformer_blocks, downsample in self.down_blocks: + mask_down = masks[-1] + x = resnet(x, mask_down, t) + x = rearrange(x, "b c t -> b t c").contiguous() + if streaming is True: + attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1) + else: + attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1) + attn_mask = mask_to_bias(attn_mask, x.dtype) + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, + attention_mask=attn_mask, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t").contiguous() + hiddens.append(x) # Save hidden states for skip connections + x = downsample(x * mask_down) + masks.append(mask_down[:, :, ::2]) + masks = masks[:-1] + mask_mid = masks[-1] + + for resnet, transformer_blocks in self.mid_blocks: + x = resnet(x, mask_mid, t) + x = rearrange(x, "b c t -> b t c").contiguous() + if streaming is True: + attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1) + else: + attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1) + attn_mask = mask_to_bias(attn_mask, x.dtype) + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, + attention_mask=attn_mask, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t").contiguous() + + for resnet, transformer_blocks, upsample in self.up_blocks: + mask_up = masks.pop() + skip = hiddens.pop() + x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0] + x = resnet(x, mask_up, t) + x = rearrange(x, "b c t -> b t c").contiguous() + if streaming is True: + attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1) + else: + attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1) + attn_mask = mask_to_bias(attn_mask, x.dtype) + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, + attention_mask=attn_mask, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t").contiguous() + x = upsample(x * mask_up) + x = self.final_block(x, mask_up) + output = self.final_proj(x * mask_up) + return output * mask diff --git a/soulxpodcast/models/modules/flow_components/upsample_encoder.py b/soulxpodcast/models/modules/flow_components/upsample_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..5c753076211c3fd16599cd61a484f92ebdc0f23f --- /dev/null +++ b/soulxpodcast/models/modules/flow_components/upsample_encoder.py @@ -0,0 +1,998 @@ +import math +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def 
subsequent_chunk_mask( + size: int, + chunk_size: int, + num_left_chunks: int = -1, + device: torch.device = torch.device("cpu"), +) -> torch.Tensor: + """Create mask for subsequent steps (size, size) with chunk size, + this is for streaming encoder + + Args: + size (int): size of mask + chunk_size (int): size of chunk + num_left_chunks (int): number of left chunks + <0: use full chunk + >=0: use num_left_chunks + device (torch.device): "cpu" or "cuda" or torch.Tensor.device + + Returns: + torch.Tensor: mask + + Examples: + >>> subsequent_chunk_mask(4, 2) + [[1, 1, 0, 0], + [1, 1, 0, 0], + [1, 1, 1, 1], + [1, 1, 1, 1]] + """ + # NOTE this modified implementation meets onnx export requirements, but it doesn't support num_left_chunks + pos_idx = torch.arange(size, device=device) + block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size + ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1) + return ret + + +def add_optional_chunk_mask(xs: torch.Tensor, + masks: torch.Tensor, + use_dynamic_chunk: bool, + use_dynamic_left_chunk: bool, + decoding_chunk_size: int, + static_chunk_size: int, + num_decoding_left_chunks: int, + enable_full_context: bool = True): + """ Apply optional mask for encoder. + + Args: + xs (torch.Tensor): padded input, (B, L, D), L for max length + mask (torch.Tensor): mask for xs, (B, 1, L) + use_dynamic_chunk (bool): whether to use dynamic chunk or not + use_dynamic_left_chunk (bool): whether to use dynamic left chunk for + training. + decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's + 0: default for training, use random dynamic chunk. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + static_chunk_size (int): chunk size for static chunk training/decoding + if it's greater than 0, if use_dynamic_chunk is true, + this parameter will be ignored + num_decoding_left_chunks: number of left chunks, this is for decoding, + the chunk size is decoding_chunk_size. + >=0: use num_decoding_left_chunks + <0: use all left chunks + enable_full_context (bool): + True: chunk size is either [1, 25] or full context(max_len) + False: chunk size ~ U[1, 25] + + Returns: + torch.Tensor: chunk mask of the input xs. + """ + # Whether to use chunk mask or not + if use_dynamic_chunk: + max_len = xs.size(1) + if decoding_chunk_size < 0: + chunk_size = max_len + num_left_chunks = -1 + elif decoding_chunk_size > 0: + chunk_size = decoding_chunk_size + num_left_chunks = num_decoding_left_chunks + else: + # chunk size is either [1, 25] or full context(max_len). + # Since we use 4 times subsampling and allow up to 1s(100 frames) + # delay, the maximum frame is 100 / 4 = 25. 
+            chunk_size = torch.randint(1, max_len, (1, )).item()
+            num_left_chunks = -1
+            if chunk_size > max_len // 2 and enable_full_context:
+                chunk_size = max_len
+            else:
+                chunk_size = chunk_size % 25 + 1
+                if use_dynamic_left_chunk:
+                    max_left_chunks = (max_len - 1) // chunk_size
+                    num_left_chunks = torch.randint(0, max_left_chunks,
+                                                    (1, )).item()
+        chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
+                                            num_left_chunks,
+                                            xs.device)  # (L, L)
+        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
+        chunk_masks = masks & chunk_masks  # (B, L, L)
+    elif static_chunk_size > 0:
+        num_left_chunks = num_decoding_left_chunks
+        chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
+                                            num_left_chunks,
+                                            xs.device)  # (L, L)
+        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
+        chunk_masks = masks & chunk_masks  # (B, L, L)
+    else:
+        chunk_masks = masks
+    assert chunk_masks.dtype == torch.bool
+    if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0:
+        print('chunk_masks is all false at some timestep; force setting it to true. Make sure these positions are masked in future computation!')
+        chunk_masks[chunk_masks.sum(dim=-1) == 0] = True
+    return chunk_masks
+
+
+def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
+    """Make mask tensor containing indices of padded part.
+
+    See description of make_non_pad_mask.
+
+    Args:
+        lengths (torch.Tensor): Batch of lengths (B,).
+    Returns:
+        torch.Tensor: Mask tensor containing indices of padded part.
+
+    Examples:
+        >>> lengths = [5, 3, 2]
+        >>> make_pad_mask(lengths)
+        masks = [[0, 0, 0, 0, 0],
+                 [0, 0, 0, 1, 1],
+                 [0, 0, 1, 1, 1]]
+    """
+    batch_size = lengths.size(0)
+    max_len = max_len if max_len > 0 else lengths.max().item()
+    seq_range = torch.arange(0,
+                             max_len,
+                             dtype=torch.int64,
+                             device=lengths.device)
+    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
+    seq_length_expand = lengths.unsqueeze(-1)
+    mask = seq_range_expand >= seq_length_expand
+    return mask
+
+
+class EspnetRelPositionalEncoding(torch.nn.Module):
+    """Relative positional encoding module (new implementation).
+
+    Details can be found in https://github.com/espnet/espnet/pull/2816.
+
+    See: Appendix B in https://arxiv.org/abs/1901.02860
+
+    Args:
+        d_model (int): Embedding dimension.
+        max_len (int): Maximum input length.
+
+    """
+
+    def __init__(self, d_model: int, max_len: int = 5000):
+        super(EspnetRelPositionalEncoding, self).__init__()
+        self.d_model = d_model
+        self.xscale = math.sqrt(self.d_model)
+        self.pe = None
+        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
+
+    def extend_pe(self, x: torch.Tensor):
+        """Reset the positional encodings."""
+        if self.pe is not None:
+            # self.pe contains both positive and negative parts
+            # the length of self.pe is 2 * input_len - 1
+            if self.pe.size(1) >= x.size(1) * 2 - 1:
+                if self.pe.dtype != x.dtype or self.pe.device != x.device:
+                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+                return
+        # Suppose `i` is the position of the query vector and `j` the position
+        # of the key vector. We use positive relative positions when keys are
+        # to the left (i>j) and negative relative positions otherwise (i<j).
+        pe_positive = torch.zeros(x.size(1), self.d_model)
+        pe_negative = torch.zeros(x.size(1), self.d_model)
+        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, self.d_model, 2, dtype=torch.float32)
+            * -(math.log(10000.0) / self.d_model))
+        pe_positive[:, 0::2] = torch.sin(position * div_term)
+        pe_positive[:, 1::2] = torch.cos(position * div_term)
+        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
+        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
+
+        # Reverse the order of positive indices and concat both positive and
+        # negative indices. This is used to support the shifting trick
+        # as in https://arxiv.org/abs/1901.02860
+        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
+        pe_negative = pe_negative[1:].unsqueeze(0)
+        pe = torch.cat([pe_positive, pe_negative], dim=1)
+        self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+    def forward(self, x: torch.Tensor,
+                offset: Union[int, torch.Tensor] = 0) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Add positional encoding.
+
+        Args:
+            x (torch.Tensor): Input tensor (batch, time, `*`).
+
+        Returns:
+            torch.Tensor: Encoded tensor (batch, time, `*`).
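+
+        Example (a minimal sketch; the sizes are arbitrary assumptions). The
+        returned relative position embedding spans 2*time - 1 positions,
+        covering both positive and negative offsets:
+            >>> pos_enc = EspnetRelPositionalEncoding(d_model=512)
+            >>> x_scaled, pos_emb = pos_enc(torch.randn(1, 10, 512))
+            >>> x_scaled.shape, pos_emb.shape
+            (torch.Size([1, 10, 512]), torch.Size([1, 19, 512]))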
+ + """ + self.extend_pe(x) + x = x * self.xscale + pos_emb = self.position_encoding(size=x.size(1), offset=offset) + return x, pos_emb + + def position_encoding(self, + offset: Union[int, torch.Tensor], + size: int) -> torch.Tensor: + """ For getting encoding in a streaming fashion + + Attention!!!!! + we apply dropout only once at the whole utterance level in a none + streaming way, but will call this function several times with + increasing input size in a streaming scenario, so the dropout will + be applied several times. + + Args: + offset (int or torch.tensor): start offset + size (int): required size of position encoding + + Returns: + torch.Tensor: Corresponding encoding + """ + # How to subscript a Union type: + # https://github.com/pytorch/pytorch/issues/69434 + if isinstance(offset, int): + pos_emb = self.pe[ + :, + self.pe.size(1) // 2 - size - offset + 1: self.pe.size(1) // 2 + size + offset, + ] + elif isinstance(offset, torch.Tensor): + pos_emb = self.pe[ + :, + self.pe.size(1) // 2 - size - offset + 1: self.pe.size(1) // 2 + size + offset, + ] + return pos_emb + + +class LinearNoSubsampling(torch.nn.Module): + """Linear transform the input without subsampling + + Args: + idim (int): Input dimension. + odim (int): Output dimension. + pos_enc_class (torch.nn.Module): Positional encoding class. + + """ + + def __init__(self, idim: int, odim: int, + pos_enc_class: torch.nn.Module): + super().__init__() + self.out = torch.nn.Sequential( + torch.nn.Linear(idim, odim), + torch.nn.LayerNorm(odim, eps=1e-5), + ) + self.pos_enc = pos_enc_class + self.right_context = 0 + self.subsampling_rate = 1 + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + offset: Union[int, torch.Tensor] = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Input x. + + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: linear input tensor (#batch, time', odim), + where time' = time . + torch.Tensor: linear input mask (#batch, 1, time'), + where time' = time . + + """ + x = self.out(x) + x, pos_emb = self.pos_enc(x, offset) + return x, pos_emb, x_mask + + def position_encoding(self, offset: Union[int, torch.Tensor], + size: int) -> torch.Tensor: + return self.pos_enc.position_encoding(offset, size) + + +class Upsample1D(nn.Module): + """A 1D upsampling layer with an optional convolution. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + use_conv_transpose (`bool`, default `False`): + option to use a convolution transpose. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. 
+ """ + + def __init__(self, channels: int, out_channels: int, stride: int = 2): + super().__init__() + self.channels = channels + self.out_channels = out_channels + self.stride = stride + # In this mode, first repeat interpolate, than conv with stride=1 + self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0) + + def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest") + outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0) + outputs = self.conv(outputs) + return outputs, input_lengths * self.stride + + +class PreLookaheadLayer(nn.Module): + def __init__(self, channels: int, pre_lookahead_len: int = 1): + super().__init__() + self.channels = channels + self.pre_lookahead_len = pre_lookahead_len + self.conv1 = nn.Conv1d( + channels, channels, + kernel_size=pre_lookahead_len + 1, + stride=1, padding=0, + ) + self.conv2 = nn.Conv1d( + channels, channels, + kernel_size=3, stride=1, padding=0, + ) + + def forward(self, inputs: torch.Tensor, context: torch.Tensor = torch.zeros(0, 0, 0)) -> torch.Tensor: + """ + inputs: (batch_size, seq_len, channels) + """ + outputs = inputs.transpose(1, 2).contiguous() + context = context.transpose(1, 2).contiguous() + # look ahead + if context.size(2) == 0: + outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0) + else: + assert self.training is False, 'you have passed context, make sure that you are running inference mode' + assert context.size(2) == self.pre_lookahead_len + outputs = F.pad(torch.concat([outputs, context], dim=2), (0, self.pre_lookahead_len - context.size(2)), mode='constant', value=0.0) + outputs = F.leaky_relu(self.conv1(outputs)) + # outputs + outputs = F.pad(outputs, (self.conv2.kernel_size[0] - 1, 0), mode='constant', value=0.0) + outputs = self.conv2(outputs) + outputs = outputs.transpose(1, 2).contiguous() + + # residual connection + outputs = outputs + inputs + return outputs + + +class MultiHeadedAttention(nn.Module): + """Multi-Head Attention layer. + + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + key_bias (bool): Whether to use bias in key linear layer. + + """ + + def __init__(self, + n_head: int, + n_feat: int, + dropout_rate: float, + key_bias: bool = True): + super().__init__() + assert n_feat % n_head == 0 + # We assume d_v always equals d_k + self.d_k = n_feat // n_head + self.h = n_head + self.linear_q = nn.Linear(n_feat, n_feat) + self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias) + self.linear_v = nn.Linear(n_feat, n_feat) + self.linear_out = nn.Linear(n_feat, n_feat) + self.dropout = nn.Dropout(p=dropout_rate) + + def forward_qkv( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Transform query, key and value. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + + Returns: + torch.Tensor: Transformed query tensor, size + (#batch, n_head, time1, d_k). + torch.Tensor: Transformed key tensor, size + (#batch, n_head, time2, d_k). + torch.Tensor: Transformed value tensor, size + (#batch, n_head, time2, d_k). 
+ + """ + n_batch = query.size(0) + q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) + k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) + v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) + q = q.transpose(1, 2) # (batch, head, time1, d_k) + k = k.transpose(1, 2) # (batch, head, time2, d_k) + v = v.transpose(1, 2) # (batch, head, time2, d_k) + + return q, k, v + + def forward_attention( + self, + value: torch.Tensor, + scores: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool) + ) -> torch.Tensor: + """Compute attention context vector. + + Args: + value (torch.Tensor): Transformed value, size + (#batch, n_head, time2, d_k). + scores (torch.Tensor): Attention score, size + (#batch, n_head, time1, time2). + mask (torch.Tensor): Mask, size (#batch, 1, time2) or + (#batch, time1, time2), (0, 0, 0) means fake mask. + + Returns: + torch.Tensor: Transformed value (#batch, time1, d_model) + weighted by the attention score (#batch, time1, time2). + + """ + n_batch = value.size(0) + # NOTE(xcsong): When will `if mask.size(2) > 0` be True? + # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the + # 1st chunk to ease the onnx export.] + # 2. pytorch training + if mask.size(2) > 0: # time2 > 0 + mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) + # For last chunk, time2 might be larger than scores.size(-1) + mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2) + scores = scores.masked_fill(mask, -float('inf')) + attn = torch.softmax(scores, dim=-1).masked_fill( + mask, 0.0) # (batch, head, time1, time2) + # NOTE(xcsong): When will `if mask.size(2) > 0` be False? + # 1. onnx(16/-1, -1/-1, 16/0) + # 2. jit (16/-1, -1/-1, 16/0, 16/4) + else: + attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) + + p_attn = self.dropout(attn) + x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) + x = (x.transpose(1, 2).contiguous().view(n_batch, -1, + self.h * self.d_k) + ) # (batch, time1, d_model) + + return self.linear_out(x) # (batch, time1, d_model) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + pos_emb: torch.Tensor = torch.empty(0), + cache: torch.Tensor = torch.zeros((0, 0, 0, 0)) + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute scaled dot product attention. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2). + 1.When applying cross attention between decoder and encoder, + the batch padding mask for input is in (#batch, 1, T) shape. + 2.When applying self attention of encoder, + the mask is in (#batch, T, T) shape. + 3.When applying self attention of decoder, + the mask is in (#batch, L, L) shape. + 4.If the different position in decoder see different block + of the encoder, such as Mocha, the passed in mask could be + in (#batch, L, T) shape. But there is no such case in current + CosyVoice. + cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2), + where `cache_t == chunk_size * num_decoding_left_chunks` + and `head * d_k == size` + + + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). 
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2) + where `cache_t == chunk_size * num_decoding_left_chunks` + and `head * d_k == size` + + """ + q, k, v = self.forward_qkv(query, key, value) + + # NOTE(xcsong): + # when export onnx model, for 1st chunk, we feed + # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode) + # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode). + # In all modes, `if cache.size(0) > 0` will alwayse be `True` + # and we will always do splitting and + # concatnation(this will simplify onnx export). Note that + # it's OK to concat & split zero-shaped tensors(see code below). + # when export jit model, for 1st chunk, we always feed + # cache(0, 0, 0, 0) since jit supports dynamic if-branch. + # >>> a = torch.ones((1, 2, 0, 4)) + # >>> b = torch.ones((1, 2, 3, 4)) + # >>> c = torch.cat((a, b), dim=2) + # >>> torch.equal(b, c) # True + # >>> d = torch.split(a, 2, dim=-1) + # >>> torch.equal(d[0], d[1]) # True + if cache.size(0) > 0: + key_cache, value_cache = torch.split(cache, + cache.size(-1) // 2, + dim=-1) + k = torch.cat([key_cache, k], dim=2) + v = torch.cat([value_cache, v], dim=2) + # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's + # non-trivial to calculate `next_cache_start` here. + new_cache = torch.cat((k, v), dim=-1) + + scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) + return self.forward_attention(v, scores, mask), new_cache + + +class RelPositionMultiHeadedAttention(MultiHeadedAttention): + """Multi-Head Attention layer with relative position encoding. + Paper: https://arxiv.org/abs/1901.02860 + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + key_bias (bool): Whether to use bias in key linear layer. + """ + + def __init__(self, + n_head: int, + n_feat: int, + dropout_rate: float, + key_bias: bool = True): + super().__init__(n_head, n_feat, dropout_rate, key_bias) + # linear transformation for positional encoding + self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) + self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) + torch.nn.init.xavier_uniform_(self.pos_bias_u) + torch.nn.init.xavier_uniform_(self.pos_bias_v) + + def rel_shift(self, x: torch.Tensor) -> torch.Tensor: + """Compute relative positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1). + time1 means the length of query vector. + + Returns: + torch.Tensor: Output tensor. + + """ + zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1), + device=x.device, + dtype=x.dtype) + x_padded = torch.cat([zero_pad, x], dim=-1) + + x_padded = x_padded.view(x.size()[0], + x.size()[1], + x.size(3) + 1, x.size(2)) + x = x_padded[:, :, 1:].view_as(x)[ + :, :, :, : x.size(-1) // 2 + 1 + ] # only keep the positions from 0 to time2 + return x + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + pos_emb: torch.Tensor = torch.empty(0), + cache: torch.Tensor = torch.zeros((0, 0, 0, 0)) + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute 'Scaled Dot Product Attention' with rel. positional encoding. + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). 
+ key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2), (0, 0, 0) means fake mask. + pos_emb (torch.Tensor): Positional embedding tensor + (#batch, time2, size). + cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2), + where `cache_t == chunk_size * num_decoding_left_chunks` + and `head * d_k == size` + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). + torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2) + where `cache_t == chunk_size * num_decoding_left_chunks` + and `head * d_k == size` + """ + q, k, v = self.forward_qkv(query, key, value) + q = q.transpose(1, 2) # (batch, time1, head, d_k) + + # NOTE(xcsong): + # when export onnx model, for 1st chunk, we feed + # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode) + # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode). + # In all modes, `if cache.size(0) > 0` will alwayse be `True` + # and we will always do splitting and + # concatnation(this will simplify onnx export). Note that + # it's OK to concat & split zero-shaped tensors(see code below). + # when export jit model, for 1st chunk, we always feed + # cache(0, 0, 0, 0) since jit supports dynamic if-branch. + # >>> a = torch.ones((1, 2, 0, 4)) + # >>> b = torch.ones((1, 2, 3, 4)) + # >>> c = torch.cat((a, b), dim=2) + # >>> torch.equal(b, c) # True + # >>> d = torch.split(a, 2, dim=-1) + # >>> torch.equal(d[0], d[1]) # True + if cache.size(0) > 0: + key_cache, value_cache = torch.split(cache, + cache.size(-1) // 2, + dim=-1) + k = torch.cat([key_cache, k], dim=2) + v = torch.cat([value_cache, v], dim=2) + # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's + # non-trivial to calculate `next_cache_start` here. + new_cache = torch.cat((k, v), dim=-1) + + n_batch_pos = pos_emb.size(0) + p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) + p = p.transpose(1, 2) # (batch, head, time1, d_k) + + # (batch, head, time1, d_k) + q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) + # (batch, head, time1, d_k) + q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) + + # compute attention score + # first compute matrix a and matrix c + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) + + # compute matrix b and matrix d + # (batch, head, time1, time2) + matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) + # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used + if matrix_ac.shape != matrix_bd.shape: + matrix_bd = self.rel_shift(matrix_bd) + + scores = (matrix_ac + matrix_bd) / math.sqrt( + self.d_k) # (batch, head, time1, time2) + + return self.forward_attention(v, scores, mask), new_cache + + +class PositionwiseFeedForward(torch.nn.Module): + """Positionwise feed forward layer. + + FeedForward are appied on each position of the sequence. + The output dim is same with the input dim. + + Args: + idim (int): Input dimenstion. + hidden_units (int): The number of hidden units. + dropout_rate (float): Dropout rate. 
+ activation (torch.nn.Module): Activation function + """ + + def __init__( + self, + idim: int, + hidden_units: int, + dropout_rate: float, + activation: torch.nn.Module = torch.nn.ReLU(), + ): + super(PositionwiseFeedForward, self).__init__() + self.w_1 = torch.nn.Linear(idim, hidden_units) + self.activation = activation + self.dropout = torch.nn.Dropout(dropout_rate) + self.w_2 = torch.nn.Linear(hidden_units, idim) + + def forward(self, xs: torch.Tensor) -> torch.Tensor: + """Forward function. + + Args: + xs: input tensor (B, L, D) + Returns: + output tensor, (B, L, D) + """ + return self.w_2(self.dropout(self.activation(self.w_1(xs)))) + + +class ConformerEncoderLayer(nn.Module): + """Encoder layer module. + Args: + size (int): Input dimension. + self_attn (torch.nn.Module): Self-attention module instance. + `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` + instance can be used as the argument. + feed_forward (torch.nn.Module): Feed-forward module instance. + `PositionwiseFeedForward` instance can be used as the argument. + feed_forward_macaron (torch.nn.Module): Additional feed-forward module + instance. + `PositionwiseFeedForward` instance can be used as the argument. + conv_module (torch.nn.Module): Convolution module instance. + `ConvlutionModule` instance can be used as the argument. + dropout_rate (float): Dropout rate. + normalize_before (bool): + True: use layer_norm before each sub-block. + False: use layer_norm after each sub-block. + """ + + def __init__( + self, + size: int, + self_attn: torch.nn.Module, + feed_forward: Optional[nn.Module] = None, + feed_forward_macaron: Optional[nn.Module] = None, + conv_module: Optional[nn.Module] = None, + dropout_rate: float = 0.0, + normalize_before: bool = True, + ): + super().__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.feed_forward_macaron = feed_forward_macaron + self.conv_module = conv_module + self.norm_ff = nn.LayerNorm(size, eps=1e-12) # for the FNN module + self.norm_mha = nn.LayerNorm(size, eps=1e-12) # for the MHA module + if feed_forward_macaron is not None: + self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-12) + self.ff_scale = 0.5 + else: + self.ff_scale = 1.0 + if self.conv_module is not None: + self.norm_conv = nn.LayerNorm(size, eps=1e-12) # for the CNN module + self.norm_final = nn.LayerNorm( + size, eps=1e-12) # for the final output of the block + self.dropout = nn.Dropout(dropout_rate) + self.size = size + self.normalize_before = normalize_before + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor, + pos_emb: torch.Tensor, + mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute encoded features. + + Args: + x (torch.Tensor): (#batch, time, size) + mask (torch.Tensor): Mask tensor for the input (#batch, time,time), + (0, 0, 0) means fake mask. + pos_emb (torch.Tensor): positional encoding, must not be None + for ConformerEncoderLayer. + mask_pad (torch.Tensor): batch padding mask used for conv module. + (#batch, 1,time), (0, 0, 0) means fake mask. + att_cache (torch.Tensor): Cache tensor of the KEY & VALUE + (#batch=1, head, cache_t1, d_k * 2), head * d_k == size. + cnn_cache (torch.Tensor): Convolution cache in conformer layer + (#batch=1, size, cache_t2) + Returns: + torch.Tensor: Output tensor (#batch, time, size). 
+ torch.Tensor: Mask tensor (#batch, time, time). + torch.Tensor: att_cache tensor, + (#batch=1, head, cache_t1 + time, d_k * 2). + torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2). + """ + + # whether to use macaron style + if self.feed_forward_macaron is not None: + residual = x + if self.normalize_before: + x = self.norm_ff_macaron(x) + x = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(x)) + if not self.normalize_before: + x = self.norm_ff_macaron(x) + + # multi-headed self-attention module + residual = x + if self.normalize_before: + x = self.norm_mha(x) + x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb, + att_cache) + x = residual + self.dropout(x_att) + if not self.normalize_before: + x = self.norm_mha(x) + + # convolution module + # Fake new cnn cache here, and then change it in conv_module + new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) + if self.conv_module is not None: + residual = x + if self.normalize_before: + x = self.norm_conv(x) + x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache) + x = residual + self.dropout(x) + + if not self.normalize_before: + x = self.norm_conv(x) + + # feed forward module + residual = x + if self.normalize_before: + x = self.norm_ff(x) + + x = residual + self.ff_scale * self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm_ff(x) + + if self.conv_module is not None: + x = self.norm_final(x) + + return x, mask, new_att_cache, new_cnn_cache + + +class UpsampleConformerEncoder(torch.nn.Module): + """ + Args: + input_size (int): input dim + output_size (int): dimension of attention + attention_heads (int): the number of heads of multi head attention + linear_units (int): the hidden units number of position-wise feed + forward + num_blocks (int): the number of decoder blocks + static_chunk_size (int): chunk size for static chunk training and + decoding + use_dynamic_chunk (bool): whether use dynamic chunk size for + training or not, You can only use fixed chunk(chunk_size > 0) + or dyanmic chunk size(use_dynamic_chunk = True) + use_dynamic_left_chunk (bool): whether use dynamic left chunk in + dynamic chunk training + key_bias: whether use bias in attention.linear_k, False for whisper models. 
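+
+    Usage sketch (illustrative only; every constructor argument keeps its
+    default, and the batch/length values are arbitrary). The time axis comes
+    out twice as long as it goes in because of the internal ``Upsample1D``
+    stage with ``stride=2``:
+
+        >>> import torch
+        >>> encoder = UpsampleConformerEncoder().eval()
+        >>> xs, xs_lens = torch.randn(1, 50, 512), torch.tensor([50])
+        >>> with torch.no_grad():
+        ...     ys, masks = encoder(xs, xs_lens)
+        >>> ys.shape, masks.shape
+        (torch.Size([1, 100, 512]), torch.Size([1, 1, 100]))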
+ """ + + def __init__( + self, + input_size: int = 512, + output_size: int = 512, + attention_heads: int = 8, + linear_units: int = 2048, + num_blocks: int = 6, + static_chunk_size: int = 25, + use_dynamic_chunk: bool = False, + use_dynamic_left_chunk: bool = False, + key_bias: bool = True, + ): + super().__init__() + self._output_size = output_size + + self.embed = LinearNoSubsampling( + input_size, output_size, + EspnetRelPositionalEncoding(output_size), + ) + + self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5) + self.static_chunk_size = static_chunk_size + self.use_dynamic_chunk = use_dynamic_chunk + self.use_dynamic_left_chunk = use_dynamic_left_chunk + activation = torch.nn.SiLU() + # self-attention module definition + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + 0.0, + key_bias, + ) + # feed-forward module definition + positionwise_layer_args = ( + output_size, + linear_units, + 0.0, + activation, + ) + # convolution module definition + self.pre_lookahead_layer = PreLookaheadLayer(channels=512, pre_lookahead_len=3) + self.encoders = torch.nn.ModuleList([ + ConformerEncoderLayer( + output_size, + RelPositionMultiHeadedAttention(*encoder_selfattn_layer_args), + PositionwiseFeedForward(*positionwise_layer_args), + ) for _ in range(num_blocks) + ]) + self.up_layer = Upsample1D(channels=512, out_channels=512, stride=2) + self.up_embed = LinearNoSubsampling( + input_size, output_size, + EspnetRelPositionalEncoding(output_size), + ) + self.up_encoders = torch.nn.ModuleList([ + ConformerEncoderLayer( + output_size, + RelPositionMultiHeadedAttention(*encoder_selfattn_layer_args), + PositionwiseFeedForward(*positionwise_layer_args), + ) for _ in range(4) + ]) + + def output_size(self) -> int: + return self._output_size + + def forward( + self, + xs: torch.Tensor, + xs_lens: torch.Tensor, + context: torch.Tensor = torch.zeros(0, 0, 0), + decoding_chunk_size: int = 0, + num_decoding_left_chunks: int = -1, + streaming: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Embed positions in tensor. + + Args: + xs: padded input tensor (B, T, D) + xs_lens: input length (B) + decoding_chunk_size: decoding chunk size for dynamic chunk + 0: default for training, use random dynamic chunk. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + num_decoding_left_chunks: number of left chunks, this is for decoding, + the chunk size is decoding_chunk_size. + >=0: use num_decoding_left_chunks + <0: use all left chunks + Returns: + encoder output tensor xs, and subsampled masks + xs: padded output tensor (B, T' ~= T/subsample_rate, D) + masks: torch.Tensor batch padding mask after subsample + (B, 1, T' ~= T/subsample_rate) + NOTE(xcsong): + We pass the `__call__` method of the modules instead of `forward` to the + checkpointing API because `__call__` attaches all the hooks of the module. 
+ https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2 + """ + T = xs.size(1) + masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T) + xs, pos_emb, masks = self.embed(xs, masks) + if context.size(1) != 0: + assert self.training is False, 'you have passed context, make sure that you are running inference mode' + context_masks = torch.ones(1, 1, context.size(1)).to(masks) + context, _, _ = self.embed(context, context_masks, offset=xs.size(1)) + mask_pad = masks # (B, 1, T/subsample_rate) + chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size if streaming is True else 0, -1) + # lookahead + conformer encoder + xs = self.pre_lookahead_layer(xs, context=context) + xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad) + + # upsample + conformer encoder + xs = xs.transpose(1, 2).contiguous() + xs, xs_lens = self.up_layer(xs, xs_lens) + xs = xs.transpose(1, 2).contiguous() + T = xs.size(1) + masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T) + xs, pos_emb, masks = self.up_embed(xs, masks) + mask_pad = masks # (B, 1, T/subsample_rate) + chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size * self.up_layer.stride if streaming is True else 0, -1) + xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad) + + xs = self.after_norm(xs) + # Here we assume the mask is not changed in encoder layers, so just + # return the masks before encoder layers, and the masks will be used + # for cross attention with decoder later + return xs, masks + + def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor, + pos_emb: torch.Tensor, + mask_pad: torch.Tensor) -> torch.Tensor: + for layer in self.encoders: + xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) + return xs + + def forward_up_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor, + pos_emb: torch.Tensor, + mask_pad: torch.Tensor) -> torch.Tensor: + for layer in self.up_encoders: + xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) + return xs diff --git a/soulxpodcast/models/modules/hifigan.py b/soulxpodcast/models/modules/hifigan.py new file mode 100644 index 0000000000000000000000000000000000000000..bcc41bb4c03e228627bce1a4ec99ca7a44a581d8 --- /dev/null +++ b/soulxpodcast/models/modules/hifigan.py @@ -0,0 +1,249 @@ +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""HIFI-GAN""" + +from typing import Dict, List + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from scipy.signal import get_window +from torch.nn import Conv1d, ConvTranspose1d +from torch.nn.utils import remove_weight_norm + +try: + from torch.nn.utils.parametrizations import weight_norm +except ImportError: + from torch.nn.utils import weight_norm # noqa + +from soulxpodcast.models.modules.hifigan_components.layers import ( + ResBlock, SourceModuleHnNSF, SourceModuleHnNSF2, init_weights) + + +class ConvRNNF0Predictor(nn.Module): + def __init__(self, + num_class: int = 1, + in_channels: int = 80, + cond_channels: int = 512 + ): + super().__init__() + + self.num_class = num_class + self.condnet = nn.Sequential( + weight_norm( # noqa + nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1) + ), + nn.ELU(), + weight_norm( # noqa + nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) + ), + nn.ELU(), + weight_norm( # noqa + nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) + ), + nn.ELU(), + weight_norm( # noqa + nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) + ), + nn.ELU(), + weight_norm( # noqa + nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) + ), + nn.ELU(), + ) + self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.condnet(x) + x = x.transpose(1, 2) + return torch.abs(self.classifier(x).squeeze(-1)) + + +class HiFTGenerator(nn.Module): + """ + HiFTNet Generator: Neural Source Filter + ISTFTNet + https://arxiv.org/abs/2309.09493 + """ + def __init__( + self, + in_channels: int = 80, + base_channels: int = 512, + nb_harmonics: int = 8, + sampling_rate: int = 24000, + nsf_alpha: float = 0.1, + nsf_sigma: float = 0.003, + nsf_voiced_threshold: float = 10, + upsample_rates: List[int] = [8, 5, 3], # noqa + upsample_kernel_sizes: List[int] = [16, 11, 7], # noqa + istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4}, # noqa + resblock_kernel_sizes: List[int] = [3, 7, 11], # noqa + resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], # noqa + source_resblock_kernel_sizes: List[int] = [7, 7, 11], # noqa + source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], # noqa + lrelu_slope: float = 0.1, + audio_limit: float = 0.99, + f0_predictor: torch.nn.Module = None, + ): + super(HiFTGenerator, self).__init__() + + self.out_channels = 1 + self.nb_harmonics = nb_harmonics + self.sampling_rate = sampling_rate + self.istft_params = istft_params + self.lrelu_slope = lrelu_slope + self.audio_limit = audio_limit + + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + # NOTE in CosyVoice2, we use the original SourceModuleHnNSF implementation + this_SourceModuleHnNSF = SourceModuleHnNSF if self.sampling_rate == 22050 else SourceModuleHnNSF2 + self.m_source = this_SourceModuleHnNSF( + sampling_rate=sampling_rate, + upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"], + harmonic_num=nb_harmonics, + sine_amp=nsf_alpha, + add_noise_std=nsf_sigma, + voiced_threshod=nsf_voiced_threshold) + self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"]) + + self.conv_pre = weight_norm( # noqa + Conv1d(in_channels, base_channels, 7, 1, padding=3) + ) + + # Up + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, 
upsample_kernel_sizes)): + self.ups.append( + weight_norm( # noqa + ConvTranspose1d( + base_channels // (2**i), + base_channels // (2**(i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + # Down + self.source_downs = nn.ModuleList() + self.source_resblocks = nn.ModuleList() + downsample_rates = [1] + upsample_rates[::-1][:-1] + downsample_cum_rates = np.cumprod(downsample_rates) + for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)): + if u == 1: + self.source_downs.append( + Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1) + ) + else: + self.source_downs.append( + Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2)) + ) + + self.source_resblocks.append( + ResBlock(base_channels // (2 ** (i + 1)), k, d) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = base_channels // (2**(i + 1)) + for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): + self.resblocks.append(ResBlock(ch, k, d)) + + self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3)) # noqa + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + self.reflection_pad = nn.ReflectionPad1d((1, 0)) + self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32)) + self.f0_predictor = ConvRNNF0Predictor() if f0_predictor is None else f0_predictor + + def remove_weight_norm(self): + print('Removing weight norm...') + for up in self.ups: + remove_weight_norm(up) + for resblock in self.resblocks: + resblock.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + self.m_source.remove_weight_norm() + for source_down in self.source_downs: + remove_weight_norm(source_down) + for source_resblock in self.source_resblocks: + source_resblock.remove_weight_norm() + + def _stft(self, x): + spec = torch.stft( + x, + self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device), + return_complex=True) + spec = torch.view_as_real(spec) # [B, F, TT, 2] + return spec[..., 0], spec[..., 1] + + def _istft(self, magnitude, phase): + magnitude = torch.clip(magnitude, max=1e2) + real = magnitude * torch.cos(phase) + img = magnitude * torch.sin(phase) + inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"], + self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device)) + return inverse_transform + + def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor: + s_stft_real, s_stft_imag = self._stft(s.squeeze(1)) + s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1) + + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, self.lrelu_slope) + x = self.ups[i](x) + + if i == self.num_upsamples - 1: + x = self.reflection_pad(x) + + # fusion + si = self.source_downs[i](s_stft) + si = self.source_resblocks[i](si) + x = x + si + + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + + x = F.leaky_relu(x) + x = self.conv_post(x) + magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :]) + phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundancy + + x = 
self._istft(magnitude, phase) + x = torch.clamp(x*0.98, -self.audio_limit, self.audio_limit) + return x + + @torch.inference_mode() + def forward(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor: + # mel->f0 + f0 = self.f0_predictor(speech_feat) + # f0->source + s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t + s, _, _ = self.m_source(s) + s = s.transpose(1, 2) + # use cache_source to avoid glitch + if cache_source.shape[2] != 0: + s[:, :, :cache_source.shape[2]] = cache_source + generated_speech = self.decode(x=speech_feat, s=s) + return generated_speech, s diff --git a/soulxpodcast/models/modules/hifigan_components/__init__.py b/soulxpodcast/models/modules/hifigan_components/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/soulxpodcast/models/modules/hifigan_components/__pycache__/__init__.cpython-311.pyc b/soulxpodcast/models/modules/hifigan_components/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3203f3336a86b5f6c4db24ea32ac02012c7386c3 Binary files /dev/null and b/soulxpodcast/models/modules/hifigan_components/__pycache__/__init__.cpython-311.pyc differ diff --git a/soulxpodcast/models/modules/hifigan_components/__pycache__/__init__.cpython-312.pyc b/soulxpodcast/models/modules/hifigan_components/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6eee827b6ee3e3c3580a9e1479b297d2fc077e35 Binary files /dev/null and b/soulxpodcast/models/modules/hifigan_components/__pycache__/__init__.cpython-312.pyc differ diff --git a/soulxpodcast/models/modules/hifigan_components/__pycache__/layers.cpython-311.pyc b/soulxpodcast/models/modules/hifigan_components/__pycache__/layers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da7f3ea2af92b7effaea6e71a4a66d65e7e3fc2d Binary files /dev/null and b/soulxpodcast/models/modules/hifigan_components/__pycache__/layers.cpython-311.pyc differ diff --git a/soulxpodcast/models/modules/hifigan_components/__pycache__/layers.cpython-312.pyc b/soulxpodcast/models/modules/hifigan_components/__pycache__/layers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..303360a943ad0ad6fd581ae70bb3f1aec822fe07 Binary files /dev/null and b/soulxpodcast/models/modules/hifigan_components/__pycache__/layers.cpython-312.pyc differ diff --git a/soulxpodcast/models/modules/hifigan_components/layers.py b/soulxpodcast/models/modules/hifigan_components/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..e07c58725d8ee2bccffdb138300dac5500e2eda8 --- /dev/null +++ b/soulxpodcast/models/modules/hifigan_components/layers.py @@ -0,0 +1,433 @@ +from typing import List + +import numpy as np +import torch +import torch.nn as nn +from torch.distributions.uniform import Uniform +from torch.nn import Conv1d +from torch.nn.utils import remove_weight_norm + +try: + from torch.nn.utils.parametrizations import weight_norm +except ImportError: + from torch.nn.utils import weight_norm # noqa + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +"""hifigan based generator implementation. 
+ +This code is modified from https://github.com/jik876/hifi-gan + ,https://github.com/kan-bayashi/ParallelWaveGAN and + https://github.com/NVIDIA/BigVGAN + +""" + + +# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license. +# LICENSE is in incl_licenses directory. +class Snake(nn.Module): + ''' + Implementation of a sine-based periodic activation function + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter + References: + - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snake(256) + >>> x = torch.randn(256) + >>> x = a1(x) + + Args: + in_features: shape of the input + alpha: trainable parameter + alpha_trainable: whether alpha is trainable + alpha_logscale: whether to use log scale for alpha + alpha is initialized to 1 by default, higher values = higher-frequency. + alpha will be trained along with the rest of your model. + ''' + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): + super(Snake, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = nn.Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = nn.Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + ''' + Forward pass of the function. + Applies the function to the input elementwise. + Snake ∶= x + 1/a * sin^2 (xa) + ''' + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + if self.alpha_logscale: + alpha = torch.exp(alpha) + x = x + (1.0 / (alpha + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2) + + return x + + +class ResBlock(torch.nn.Module): + """Residual block module in HiFiGAN/BigVGAN.""" + def __init__( + self, + channels: int = 512, + kernel_size: int = 3, + dilations: List[int] = [1, 3, 5], # noqa + ): + super(ResBlock, self).__init__() + self.convs1 = nn.ModuleList() + self.convs2 = nn.ModuleList() + + for dilation in dilations: + self.convs1.append( + weight_norm( # noqa + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation, + padding=get_padding(kernel_size, dilation) + ) + ) + ) + self.convs2.append( + weight_norm( # noqa + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1) + ) + ) + ) + self.convs1.apply(init_weights) + self.convs2.apply(init_weights) + self.activations1 = nn.ModuleList([ + Snake(channels, alpha_logscale=False) + for _ in range(len(self.convs1)) + ]) + self.activations2 = nn.ModuleList([ + Snake(channels, alpha_logscale=False) + for _ in range(len(self.convs2)) + ]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for idx in range(len(self.convs1)): + xt = self.activations1[idx](x) + xt = self.convs1[idx](xt) + xt = self.activations2[idx](xt) + xt = self.convs2[idx](xt) + x = xt + x + return x + + def remove_weight_norm(self): + for idx in range(len(self.convs1)): + remove_weight_norm(self.convs1[idx]) + remove_weight_norm(self.convs2[idx]) + + +class SineGen(torch.nn.Module): + """ Definition of sine generator + SineGen(samp_rate, harmonic_num = 0, + sine_amp = 0.1, noise_std = 0.003, + voiced_threshold = 0, + flag_for_pulse=False) + samp_rate: sampling 
rate in Hz + harmonic_num: number of harmonic overtones (default 0) + sine_amp: amplitude of sine-wavefrom (default 0.1) + noise_std: std of Gaussian noise (default 0.003) + voiced_thoreshold: F0 threshold for U/V classification (default 0) + flag_for_pulse: this SinGen is used inside PulseGen (default False) + Note: when flag_for_pulse is True, the first time step of a voiced + segment is always sin(np.pi) or cos(0) + """ + + def __init__(self, samp_rate, harmonic_num=0, + sine_amp=0.1, noise_std=0.003, + voiced_threshold=0): + super(SineGen, self).__init__() + self.sine_amp = sine_amp + self.noise_std = noise_std + self.harmonic_num = harmonic_num + self.sampling_rate = samp_rate + self.voiced_threshold = voiced_threshold + + def _f02uv(self, f0): + # generate uv signal + uv = (f0 > self.voiced_threshold).type(torch.float32) + return uv + + @torch.no_grad() + def forward(self, f0): + """ + :param f0: [B, 1, sample_len], Hz + :return: [B, 1, sample_len] + """ + + F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device) + for i in range(self.harmonic_num + 1): + F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate + + theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1) + u_dist = Uniform(low=-np.pi, high=np.pi) + phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device) + phase_vec[:, 0, :] = 0 + + # generate sine waveforms + sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec) + + # generate uv signal + uv = self._f02uv(f0) + + # noise: for unvoiced should be similar to sine_amp + # std = self.sine_amp/3 -> max value ~ self.sine_amp + # . for voiced regions is self.noise_std + noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 + noise = noise_amp * torch.randn_like(sine_waves) + + # first: set the unvoiced part to 0 by uv + # then: additive noise + sine_waves = sine_waves * uv + noise + return sine_waves, uv, noise + + +class SourceModuleHnNSF(torch.nn.Module): + """ SourceModule for hn-nsf + SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0) + sampling_rate: sampling_rate in Hz + harmonic_num: number of harmonic above F0 (default: 0) + sine_amp: amplitude of sine source signal (default: 0.1) + add_noise_std: std of additive Gaussian noise (default: 0.003) + note that amplitude of noise in unvoiced is decided + by sine_amp + voiced_threshold: threhold to set U/V given F0 (default: 0) + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + uv (batchsize, length, 1) + """ + + def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0): + super(SourceModuleHnNSF, self).__init__() + + self.sine_amp = sine_amp + self.noise_std = add_noise_std + + # to produce sine waveforms + self.l_sin_gen = SineGen(sampling_rate, harmonic_num, + sine_amp, add_noise_std, voiced_threshod) + + # to merge source harmonics into a single excitation + self.l_linear = torch.nn.Linear(harmonic_num + 1, 1) + self.l_tanh = torch.nn.Tanh() + + def forward(self, x): + """ + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + """ + # source for harmonic branch + with torch.no_grad(): + sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2)) + sine_wavs = sine_wavs.transpose(1, 2) + uv = 
uv.transpose(1, 2) + sine_merge = self.l_tanh(self.l_linear(sine_wavs)) + + # source for noise branch, in the same shape as uv + noise = torch.randn_like(uv) * self.sine_amp / 3 + return sine_merge, noise, uv + + +class SineGen2(torch.nn.Module): + """ Definition of sine generator + SineGen(samp_rate, harmonic_num = 0, + sine_amp = 0.1, noise_std = 0.003, + voiced_threshold = 0, + flag_for_pulse=False) + samp_rate: sampling rate in Hz + harmonic_num: number of harmonic overtones (default 0) + sine_amp: amplitude of sine-wavefrom (default 0.1) + noise_std: std of Gaussian noise (default 0.003) + voiced_thoreshold: F0 threshold for U/V classification (default 0) + flag_for_pulse: this SinGen is used inside PulseGen (default False) + Note: when flag_for_pulse is True, the first time step of a voiced + segment is always sin(np.pi) or cos(0) + """ + + def __init__(self, samp_rate, upsample_scale, harmonic_num=0, + sine_amp=0.1, noise_std=0.003, + voiced_threshold=0, + flag_for_pulse=False): + super(SineGen2, self).__init__() + self.sine_amp = sine_amp + self.noise_std = noise_std + self.harmonic_num = harmonic_num + self.dim = self.harmonic_num + 1 + self.sampling_rate = samp_rate + self.voiced_threshold = voiced_threshold + self.flag_for_pulse = flag_for_pulse + self.upsample_scale = upsample_scale + + def _f02uv(self, f0): + # generate uv signal + uv = (f0 > self.voiced_threshold).type(torch.float32) + return uv + + def _f02sine(self, f0_values): + """ f0_values: (batchsize, length, dim) + where dim indicates fundamental tone and overtones + """ + # convert to F0 in rad. The interger part n can be ignored + # because 2 * np.pi * n doesn't affect phase + rad_values = (f0_values / self.sampling_rate) % 1 + + # initial phase noise (no noise for fundamental component) + rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device) + rand_ini[:, 0] = 0 + rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini + + # instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad) + if not self.flag_for_pulse: + rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2), + scale_factor=1 / self.upsample_scale, + mode="linear").transpose(1, 2) + + phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi + phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale, + scale_factor=self.upsample_scale, mode="linear").transpose(1, 2) + sines = torch.sin(phase) + else: + # If necessary, make sure that the first time step of every + # voiced segments is sin(pi) or cos(0) + # This is used for pulse-train generation + + # identify the last time step in unvoiced segments + uv = self._f02uv(f0_values) + uv_1 = torch.roll(uv, shifts=-1, dims=1) + uv_1[:, -1, :] = 1 + u_loc = (uv < 1) * (uv_1 > 0) + + # get the instantanouse phase + tmp_cumsum = torch.cumsum(rad_values, dim=1) + # different batch needs to be processed differently + for idx in range(f0_values.shape[0]): + temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :] + temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :] + # stores the accumulation of i.phase within + # each voiced segments + tmp_cumsum[idx, :, :] = 0 + tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum + + # rad_values - tmp_cumsum: remove the accumulation of i.phase + # within the previous voiced segment. 
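+            # tmp_cumsum is non-zero only at the last step of each unvoiced
+            # segment, so subtracting it before the cumulative sum below
+            # restarts the instantaneous phase near zero at the start of every
+            # voiced segment; cos(2*pi*i_phase) then begins each voiced
+            # segment at cos(0), as stated in the class docstring.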
+ i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1) + + # get the sines + sines = torch.cos(i_phase * 2 * np.pi) + return sines + + def forward(self, f0): + """ sine_tensor, uv = forward(f0) + input F0: tensor(batchsize=1, length, dim=1) + f0 for unvoiced steps should be 0 + output sine_tensor: tensor(batchsize=1, length, dim) + output uv: tensor(batchsize=1, length, 1) + """ + # fundamental component + fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device)) + + # generate sine waveforms + sine_waves = self._f02sine(fn) * self.sine_amp + + # generate uv signal + uv = self._f02uv(f0) + + # noise: for unvoiced should be similar to sine_amp + # std = self.sine_amp/3 -> max value ~ self.sine_amp + # . for voiced regions is self.noise_std + noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 + noise = noise_amp * torch.randn_like(sine_waves) + + # first: set the unvoiced part to 0 by uv + # then: additive noise + sine_waves = sine_waves * uv + noise + return sine_waves, uv, noise + + +class SourceModuleHnNSF2(torch.nn.Module): + """ SourceModule for hn-nsf + SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0) + sampling_rate: sampling_rate in Hz + harmonic_num: number of harmonic above F0 (default: 0) + sine_amp: amplitude of sine source signal (default: 0.1) + add_noise_std: std of additive Gaussian noise (default: 0.003) + note that amplitude of noise in unvoiced is decided + by sine_amp + voiced_threshold: threhold to set U/V given F0 (default: 0) + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + uv (batchsize, length, 1) + """ + + def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0): + super(SourceModuleHnNSF2, self).__init__() + + self.sine_amp = sine_amp + self.noise_std = add_noise_std + + # to produce sine waveforms + self.l_sin_gen = SineGen2(sampling_rate, upsample_scale, harmonic_num, + sine_amp, add_noise_std, voiced_threshod) + + # to merge source harmonics into a single excitation + self.l_linear = torch.nn.Linear(harmonic_num + 1, 1) + self.l_tanh = torch.nn.Tanh() + + def forward(self, x): + """ + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + """ + # source for harmonic branch + with torch.no_grad(): + sine_wavs, uv, _ = self.l_sin_gen(x) + sine_merge = self.l_tanh(self.l_linear(sine_wavs)) + + # source for noise branch, in the same shape as uv + noise = torch.randn_like(uv) * self.sine_amp / 3 + return sine_merge, noise, uv diff --git a/soulxpodcast/models/modules/sampler.py b/soulxpodcast/models/modules/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..28c48c71e0d10af04e39c9bed9bb63d9ae2ddd33 --- /dev/null +++ b/soulxpodcast/models/modules/sampler.py @@ -0,0 +1,221 @@ +import os + +from typing import Any, Callable, Optional, Union +import torch +from torch import nn +from transformers.generation.logits_process import ( + LogitsProcessorList +) +from transformers.generation.stopping_criteria import ( + StoppingCriteriaList +) +from transformers.generation.configuration_utils import ( + GenerationConfig +) +from transformers.generation.streamers import BaseStreamer +from transformers.generation.utils import ( + GenerateNonBeamOutput, + 
GenerateEncoderDecoderOutput, + GenerateDecoderOnlyOutput, +) +from transformers import StoppingCriteria + + +def _ras_sample_hf_engine( + self, + input_ids: torch.LongTensor, + logits_processor: LogitsProcessorList, + stopping_criteria: StoppingCriteriaList, + generation_config: GenerationConfig, + synced_gpus: bool = False, + streamer: Optional["BaseStreamer"] = None, + use_ras=False, + win_size=25, + tau_r=0.2, + **model_kwargs, +) -> Union[GenerateNonBeamOutput, torch.LongTensor]: + r""" + Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and + can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + + Parameters: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + logits_processor (`LogitsProcessorList`): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + generation_config ([`~generation.GenerationConfig`]): + The generation configuration to be used as parametrization of the decoding method. + synced_gpus (`bool`): + Whether to continue running the while loop until max_length (needed to avoid deadlocking with + `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). + streamer (`BaseStreamer`, *optional*): + Streamer object that will be used to stream the generated sequences. Generated tokens are passed + through `streamer.put(token_ids)` and the streamer is responsible for any further processing. + model_kwargs: + Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is + an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`: + A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. 
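+
+    RAS-specific keyword arguments (read from the signature below; they are
+    not covered by the standard description above):
+
+        use_ras (`bool`, defaults to `False`):
+            Enable Repetition Aware Sampling (VALL-E 2). A candidate token is
+            drawn from the processed scores first; if it (counting the
+            candidate itself) appears at least `win_size * tau_r` times among
+            the last `win_size` generated tokens, this step samples from the
+            unprocessed logits instead.
+        win_size (`int`, defaults to `25`):
+            Look-back window used to count repetitions.
+        tau_r (`float`, defaults to `0.2`):
+            Repetition-ratio threshold.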
+ """ + # init values + pad_token_id = generation_config._pad_token_tensor + output_attentions = generation_config.output_attentions + output_hidden_states = generation_config.output_hidden_states + output_scores = generation_config.output_scores + output_logits = generation_config.output_logits + return_dict_in_generate = generation_config.return_dict_in_generate + has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) + do_sample = generation_config.do_sample + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + raw_logits = () if (return_dict_in_generate and output_logits) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) + + # keep track of which sequences are already finished + batch_size, cur_len = input_ids.shape[:2] + this_peer_finished = False + unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) + model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs) + + model_forward = self.__call__ + compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config) + if compile_forward: + os.environ["TOKENIZERS_PARALLELISM"] = "0" + model_forward = self.get_compiled_call(generation_config.compile_config) + + if generation_config.prefill_chunk_size is not None: + model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs) + is_prefill = False + else: + is_prefill = True + + while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): + # prepare model inputs + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + # prepare variable output controls (note: some models won't accept all output controls) + model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) + model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) + + if is_prefill: + outputs = self(**model_inputs, return_dict=True) + is_prefill = False + else: + outputs = model_forward(**model_inputs, return_dict=True) + + # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping + model_kwargs = self._update_model_kwargs_for_generation( + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, + ) + if synced_gpus and this_peer_finished: + continue + + # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration + # (the clone itself is always small) + next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device) + + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + + # Repetition Aware Sampling in VALL-E 2 + if use_ras: + probs_candidate = 
nn.functional.softmax(next_token_scores, dim=-1) + next_tokens_candidate = torch.multinomial(probs_candidate, num_samples=1).squeeze(1) + rep_num = (input_ids[:,-win_size:] == next_tokens_candidate).sum().item() + 1 + if rep_num >= win_size * tau_r: + next_token_scores = next_token_logits + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_token_scores,) + if output_logits: + raw_logits += (next_token_logits,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + # token selection + if do_sample: + probs = nn.functional.softmax(next_token_scores, dim=-1) + # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + else: + next_tokens = torch.argmax(next_token_scores, dim=-1) + + # finished sentences should have their next token be a padding token + if has_eos_stopping_criteria: + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + if streamer is not None: + streamer.put(next_tokens.cpu()) + + unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) + this_peer_finished = unfinished_sequences.max() == 0 + cur_len += 1 + + # This is needed to properly delete outputs.logits which may be very large for first iteration + # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration + del outputs + + if streamer is not None: + streamer.end() + + if return_dict_in_generate: + if self.config.is_encoder_decoder: + return GenerateEncoderDecoderOutput( + sequences=input_ids, + scores=scores, + logits=raw_logits, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return GenerateDecoderOnlyOutput( + sequences=input_ids, + scores=scores, + logits=raw_logits, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return input_ids \ No newline at end of file diff --git a/soulxpodcast/models/soulxpodcast.py b/soulxpodcast/models/soulxpodcast.py new file mode 100644 index 0000000000000000000000000000000000000000..96dbebae4f73aedee74ca36ba5a20999e607bb93 --- /dev/null +++ b/soulxpodcast/models/soulxpodcast.py @@ -0,0 +1,192 @@ +import time +from datetime import datetime +from itertools import chain +from tqdm import tqdm +from copy import deepcopy + +import numpy as np +import s3tokenizer +import torch + +from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache +from soulxpodcast.config import Config, SamplingParams, AutoPretrainedConfig +from soulxpodcast.engine.llm_engine import ( + HFLLMEngine, VLLMEngine +) +from soulxpodcast.models.modules.flow import CausalMaskedDiffWithXvec +from soulxpodcast.models.modules.hifigan 
import HiFTGenerator + +class SoulXPodcast(torch.nn.Module): + def __init__(self, config: Config = None): + super().__init__() + self.config = Config() if config is None else config + + self.audio_tokenizer = s3tokenizer.load_model("speech_tokenizer_v2_25hz").cuda().eval() + if self.config.llm_engine == "hf": + self.llm = HFLLMEngine(**self.config.__dict__) + elif self.config.llm_engine == "vllm": + self.llm = VLLMEngine(**self.config.__dict__) + else: + raise NotImplementedError + + self.use_tqdm = True + + self.flow = CausalMaskedDiffWithXvec() + if self.config.hf_config.fp16_flow: + timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3] + tqdm.write(f"[{timestamp}] - [INFO] - Casting flow to fp16") + self.flow.half() + self.flow.load_state_dict(torch.load(f"{self.config.model}/flow.pt", map_location="cpu", weights_only=True), strict=True) + self.flow.cuda().eval() + + self.hift = HiFTGenerator() + hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(f"{self.config.model}/hift.pt", map_location="cpu", weights_only=True).items()} + self.hift.load_state_dict(hift_state_dict, strict=True) + self.hift.cuda().eval() + + + @torch.inference_mode() + def forward_longform( + self, prompt_mels_for_llm, + prompt_mels_lens_for_llm: torch.Tensor, + prompt_text_tokens_for_llm: list[list[int]], + text_tokens_for_llm: list[list[int]], + prompt_mels_for_flow_ori, + spk_emb_for_flow: torch.Tensor, + sampling_params: SamplingParams | list[SamplingParams], + spk_ids: list[list[int]], + use_prompt_cot: bool = False, + prompt_cot_text_tokens_for_llm: list[list[int]] = None, + prompt_cot_prefix: list[list[int]] = None, + **kwargs, # for compatibility + ): + prompt_size, turn_size = len(prompt_mels_for_llm), len(text_tokens_for_llm) + + # Audio tokenization + prompt_speech_tokens_ori, prompt_speech_tokens_lens_ori = self.audio_tokenizer.quantize( + prompt_mels_for_llm.cuda(), prompt_mels_lens_for_llm.cuda() + ) + + # align speech token with speech feat as to reduce + # the noise ratio during the generation process. 
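+        # The flow decoder consumes two mel frames per 25 Hz speech token
+        # (hence the factor of 2 below), so the longer of the two prompt
+        # representations is trimmed until mel_len == 2 * token_len.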
+ prompt_speech_tokens = [] + prompt_mels_for_flow, prompt_mels_lens_for_flow = [], [] + + for prompt_index in range(prompt_size): + prompt_speech_token_len = prompt_speech_tokens_lens_ori[prompt_index].item() + prompt_speech_token = prompt_speech_tokens_ori[prompt_index, :prompt_speech_token_len] + prompt_mel = prompt_mels_for_flow_ori[prompt_index] + prompt_mel_len = prompt_mel.shape[0] + if prompt_speech_token_len * 2 > prompt_mel_len: + prompt_speech_token = prompt_speech_token[:int(prompt_mel_len/2)] + prompt_mel_len = torch.tensor([prompt_mel_len]).cuda() + else: + prompt_mel = prompt_mel.detach().clone()[:prompt_speech_token_len * 2].cuda() + prompt_mel_len = torch.tensor([prompt_speech_token_len * 2]).cuda() + prompt_speech_tokens.append(prompt_speech_token) + prompt_mels_for_flow.append(prompt_mel) + prompt_mels_lens_for_flow.append(prompt_mel_len) + + # Prepare LLM inputs + prompt_inputs = [] + history_inputs = [] + + # for i in range(prompt_size): + # prompt_mels = prompt_mels_for_flow[i][None] + # prompt_mels_lens = prompt_mels_lens_for_flow[i] + # spk_emb = spk_emb_for_flow[i:i+1] + + # # Flow generation + # with torch.amp.autocast("cuda", dtype=torch.float16 if self.config.hf_config.fp16_flow else torch.float32): + # flow_input = torch.concat([prompt_speech_tokens[i].detach().clone(), prompt_speech_tokens[1].detach().clone()], axis=0)[None] + # flow_inputs_len = torch.tensor(flow_input.shape[1])[None] + # generated_mels, generated_mels_lens = self.flow( + # flow_input.cuda(), flow_inputs_len.cuda(), + # prompt_mels, prompt_mels_lens, spk_emb.cuda(), + # streaming=False, finalize=True + # ) + + # # HiFi-GAN generation + # mel = generated_mels[:, :, prompt_mels_lens[0].item():generated_mels_lens[0].item()] + # wav, _ = self.hift(speech_feat=mel) + # import soundfile as sf + # sf.write(f"{str(i).zfill(2)}.wav", wav.cpu().squeeze(0).numpy(), 24000) + + for i in range(prompt_size): + speech_tokens_i = [token+self.config.hf_config.speech_token_offset for token in prompt_speech_tokens[i].tolist()] + speech_tokens_i += [self.config.hf_config.eos_token_id] + if use_prompt_cot and len(prompt_cot_text_tokens_for_llm[i])>0: + prompt_cot_input = prompt_text_tokens_for_llm[i] + speech_tokens_i + prompt_cot_text_tokens_for_llm[i] + if i>0: + prompt_cot_input = prompt_cot_prefix[0] + prompt_cot_input + cot_input = self.llm.generate(prompt_cot_input, sampling_params, past_key_values=None)['token_ids'] + prompt_inputs.append(prompt_cot_prefix[i+1]+prompt_cot_text_tokens_for_llm[i] + cot_input) + history_inputs.append(prompt_cot_prefix[i+1]+prompt_cot_text_tokens_for_llm[i] + cot_input) + else: + prompt_inputs.append(prompt_text_tokens_for_llm[i] + speech_tokens_i ) + history_inputs.append(prompt_text_tokens_for_llm[i] + speech_tokens_i ) + + generated_wavs, results_dict = [], {} + + # LLM generation + inputs = list(chain.from_iterable(prompt_inputs)) + cache_config = AutoPretrainedConfig().from_dataclass(self.llm.config.hf_config) + past_key_values = DynamicCache(config=cache_config) + valid_turn_size = prompt_size + for i in range(turn_size): + + # # set ratio: reach the reset cache ratio; + if valid_turn_size > self.config.max_turn_size or len(inputs)>self.config.turn_tokens_threshold: + assert self.config.max_turn_size >= self.config.prompt_context + self.config.history_context, "Invalid Long history size setting, " + prompt_text_bound = max(self.config.prompt_context, len(history_inputs)-self.config.history_text_context-self.config.history_context) + inputs = list(chain.from_iterable( + 
history_inputs[:self.config.prompt_context]+ \
+                        history_inputs[prompt_text_bound:-self.config.history_context]+ \
+                        prompt_inputs[-self.config.history_context:]
+                ))
+                valid_turn_size = self.config.prompt_context + len(history_inputs) - prompt_text_bound
+                past_key_values = DynamicCache(config=cache_config)
+            valid_turn_size += 1
+
+            inputs.extend(text_tokens_for_llm[i])
+            start_time = time.time()
+            llm_outputs = self.llm.generate(inputs, sampling_params, past_key_values=past_key_values)
+
+            inputs.extend(llm_outputs['token_ids'])
+            prompt_inputs.append(text_tokens_for_llm[i]+llm_outputs['token_ids'])
+            history_inputs.append(text_tokens_for_llm[i][:-1]) # remove the <|audio_start|>
+
+            # Prepare Flow inputs
+            turn_spk = spk_ids[i]
+            generated_speech_tokens = [token - self.config.hf_config.speech_token_offset for token in llm_outputs['token_ids'][:-1]] # ignore last eos
+            prompt_speech_token = prompt_speech_tokens[turn_spk].tolist()
+            flow_input = torch.tensor([prompt_speech_token + generated_speech_tokens])
+            flow_inputs_len = torch.tensor([len(prompt_speech_token) + len(generated_speech_tokens)])
+
+            # Flow generation and HiFi-GAN generation
+            start_idx = spk_ids[i]
+            prompt_mels = prompt_mels_for_flow[start_idx][None]
+            prompt_mels_lens = prompt_mels_lens_for_flow[start_idx][None]
+            spk_emb = spk_emb_for_flow[start_idx:start_idx+1]
+
+            # Flow generation
+            with torch.amp.autocast("cuda", dtype=torch.float16 if self.config.hf_config.fp16_flow else torch.float32):
+                generated_mels, generated_mels_lens = self.flow(
+                    flow_input.cuda(), flow_inputs_len.cuda(),
+                    prompt_mels, prompt_mels_lens, spk_emb.cuda(),
+                    streaming=False, finalize=True
+                )
+
+            # HiFi-GAN generation
+            mel = generated_mels[:, :, prompt_mels_lens[0].item():generated_mels_lens[0].item()]
+            try:
+                wav, _ = self.hift(speech_feat=mel)
+            except Exception as e:
+                timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
+                tqdm.write(f"[{timestamp}] - [ERROR] - HiFi-GAN vocoding failed for turn {i}: {e}")
+                raise
+            generated_wavs.append(wav)
+
+        # Collect the generated wavs
+        results_dict['generated_wavs'] = generated_wavs
+        return results_dict
\ No newline at end of file
diff --git a/soulxpodcast/utils/__init__.py b/soulxpodcast/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/soulxpodcast/utils/__pycache__/__init__.cpython-311.pyc b/soulxpodcast/utils/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5a217a056fd34d04039728c6fab61154737eabd6
Binary files /dev/null and b/soulxpodcast/utils/__pycache__/__init__.cpython-311.pyc differ
diff --git a/soulxpodcast/utils/__pycache__/__init__.cpython-312.pyc b/soulxpodcast/utils/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ce131b4d26e4fb140382fe850d529df59908aa7a
Binary files /dev/null and b/soulxpodcast/utils/__pycache__/__init__.cpython-312.pyc differ
diff --git a/soulxpodcast/utils/__pycache__/audio.cpython-311.pyc b/soulxpodcast/utils/__pycache__/audio.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4c193bad769ca567572132b7ad809c12eaae4038
Binary files /dev/null and b/soulxpodcast/utils/__pycache__/audio.cpython-311.pyc differ
diff --git a/soulxpodcast/utils/__pycache__/audio.cpython-312.pyc b/soulxpodcast/utils/__pycache__/audio.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0d62e88b22bd174ba00a61ad52cea27ebef8a4a0
Binary files /dev/null and b/soulxpodcast/utils/__pycache__/audio.cpython-312.pyc differ
diff --git 
a/soulxpodcast/utils/__pycache__/commons.cpython-311.pyc b/soulxpodcast/utils/__pycache__/commons.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3df83532d581747b9b100ecc078978351589b61 Binary files /dev/null and b/soulxpodcast/utils/__pycache__/commons.cpython-311.pyc differ diff --git a/soulxpodcast/utils/__pycache__/dataloader.cpython-311.pyc b/soulxpodcast/utils/__pycache__/dataloader.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14b09ea06964e96e3cff37a5fad5f47553198340 Binary files /dev/null and b/soulxpodcast/utils/__pycache__/dataloader.cpython-311.pyc differ diff --git a/soulxpodcast/utils/__pycache__/dataloader.cpython-312.pyc b/soulxpodcast/utils/__pycache__/dataloader.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc0d9344e32ad272e91c08ff85cd5317ed6f663f Binary files /dev/null and b/soulxpodcast/utils/__pycache__/dataloader.cpython-312.pyc differ diff --git a/soulxpodcast/utils/__pycache__/infer_utils.cpython-311.pyc b/soulxpodcast/utils/__pycache__/infer_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ab3ce2a6b9a87434ed406b2c80d4fcc34f41d40 Binary files /dev/null and b/soulxpodcast/utils/__pycache__/infer_utils.cpython-311.pyc differ diff --git a/soulxpodcast/utils/__pycache__/parser.cpython-311.pyc b/soulxpodcast/utils/__pycache__/parser.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7922018e42cb15020e61108f1ffe66f9632b6392 Binary files /dev/null and b/soulxpodcast/utils/__pycache__/parser.cpython-311.pyc differ diff --git a/soulxpodcast/utils/__pycache__/text.cpython-311.pyc b/soulxpodcast/utils/__pycache__/text.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3b53fcac97808f239e8b5f022c75812ba609aea Binary files /dev/null and b/soulxpodcast/utils/__pycache__/text.cpython-311.pyc differ diff --git a/soulxpodcast/utils/__pycache__/text.cpython-312.pyc b/soulxpodcast/utils/__pycache__/text.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b9ff2304c9fe193143469020244016df315c2ab Binary files /dev/null and b/soulxpodcast/utils/__pycache__/text.cpython-312.pyc differ diff --git a/soulxpodcast/utils/audio.py b/soulxpodcast/utils/audio.py new file mode 100644 index 0000000000000000000000000000000000000000..e0a81447fd77ca1518513d6ae09ea21db570f465 --- /dev/null +++ b/soulxpodcast/utils/audio.py @@ -0,0 +1,77 @@ +import numpy as np +import torch +from librosa.filters import mel as librosa_mel_fn +from scipy.io.wavfile import read + +MAX_WAV_VALUE = 32768.0 + + +def load_wav(full_path): + sampling_rate, data = read(full_path) + return data, sampling_rate + + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + + +def spectral_de_normalize_torch(magnitudes): + output = dynamic_range_decompression_torch(magnitudes) + return output + + +mel_basis = {} +hann_window = {} + + +def mel_spectrogram(y, n_fft=1920, num_mels=80, sampling_rate=24000, hop_size=480, + win_size=1920, fmin=0, 
fmax=8000, center=False): + global mel_basis, hann_window # pylint: disable=global-statement + if f"{str(fmax)}_{str(y.device)}" not in mel_basis: + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) + mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device) + hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect" + ) + y = y.squeeze(1) + + spec = torch.view_as_real( + torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[str(y.device)], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) + + spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec) + spec = spectral_normalize_torch(spec) + + return spec diff --git a/soulxpodcast/utils/dataloader.py b/soulxpodcast/utils/dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..d4910cb0b705c6b16bd8cb180e6832591729b65e --- /dev/null +++ b/soulxpodcast/utils/dataloader.py @@ -0,0 +1,194 @@ +import os +import json +from tqdm import tqdm +from datetime import datetime + +import onnxruntime +import torch +import torchaudio +from torch.utils.data import DataLoader, Dataset, DistributedSampler +import torchaudio.compliance.kaldi as kaldi + +import s3tokenizer + +from soulxpodcast.utils.audio import mel_spectrogram +from soulxpodcast.utils.text import normalize_text +from soulxpodcast.config import Config, SamplingParams + + +SPK_DICT = ["<|SPEAKER_0|>", "<|SPEAKER_1|>", "<|SPEAKER_2|>", "<|SPEAKER_3|>",] +TEXT_START, TEXT_END, AUDIO_START = "<|text_start|>", "<|text_end|>", "<|semantic_token_start|>" +TASK_PODCAST = "<|task_podcast|>" + + +class PodcastDataset(Dataset): + + def __init__(self, text_tokenizer, data_list, model_config: Config): + self.datas = [] + self.model_config = model_config + + """Example data_list: + ``` + {"key": "uttid_1", "prompt_text": ["prompt_text1", "prompt_text2"], "prompt_cot_text": ["prompt_cot_text1", "prompt_cot_text2"], + "text": ["text1", "text2], "spk": [0, 1], "prompt_wav": ["/mnt/data/audio/00000000.wav", "/mnt/data/audio/00000001.wav"], "wav": "/mnt/data/audio_synthetic/uttid_1.wav"} + ``` + Note: + - `key` is the key of this sample. + - `prompt_text` is the text used for prompt. + - `prompt_cot_text` is the reshot text used for prompt. + - `text` is the text used for generating real audio. + - `spk` is the target speaker id to synthesize, corresponds to the prompt order. Default SPEAKER_0. + - `prompt_wav` is the audio used for prompt. + - `wav` is the path to the generated audio to be saved (we highly recommend to pre-define the save path before running the script). 
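+        - `data_list` is read line by line, so each record in the schema above must be a single JSON object on its own line (JSONL).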
+        """
+        missing = 0
+        with open(data_list, 'r', encoding='utf-8') as f:
+            lines = f.readlines()
+            total_lines = len(lines)
+            iterator = tqdm(lines, desc='Loading data')
+            for line in iterator:
+                data = json.loads(line.strip())
+                valid = True
+                for k in ['key', 'prompt_text', 'text', 'prompt_wav']:
+                    if k not in data:
+                        valid = False
+                        break
+                    if data[k] is None:
+                        valid = False
+                        break
+                if valid:
+                    for url in data["prompt_wav"]:
+                        if not os.path.exists(url):
+                            valid = False
+                            break
+                if valid:
+                    self.datas.append(data)
+                else:
+                    missing += 1
+        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
+        tqdm.write(f'[{timestamp}] - [INFO] - Loaded {total_lines} lines, found {missing} invalid lines, total valid lines == {len(self.datas)}.')
+
+        self.text_tokenizer = text_tokenizer
+
+        option = onnxruntime.SessionOptions()
+        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+        option.intra_op_num_threads = 1
+        self.spk_model = onnxruntime.InferenceSession(f"{self.model_config.model}/campplus.onnx", sess_options=option,
+                                                      providers=["CPUExecutionProvider"])
+
+    def __len__(self):
+        return len(self.datas)
+
+    def __getitem__(self, idx):
+        data = self.datas[idx]
+        try:
+            prompt_text_ids_list, prompt_cot_text_ids_list, spk_emb_list, mel_list, mel_len_list, log_mel_list = (
+                [], [], [], [], [], []
+            )
+            # Prepare prompt information
+            use_prompt_cot = "prompt_cot_text" in data
+            prompt_cot_prefix_list = []
+            prompt_cot_prefix_list.append(self.text_tokenizer.encode(f"{TASK_PODCAST}"))
+            for spk_idx, (prompt_text, prompt_wav) in enumerate(zip(data["prompt_text"], data["prompt_wav"])):
+                # 1. feature for s3tokenizer
+                audio = s3tokenizer.load_audio(prompt_wav, sr=16000)  # [T]
+                log_mel = s3tokenizer.log_mel_spectrogram(audio)  # [num_mels, T]
+
+                # 2. feature for speaker embedding
+                spk_feat = kaldi.fbank(audio.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000)
+                spk_feat = spk_feat - spk_feat.mean(dim=0, keepdim=True)
+                spk_emb = self.spk_model.run(
+                    None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()}
+                )[0].flatten().tolist()
+
+                # 3. feature for flow
+                audio, sample_rate = torchaudio.load(prompt_wav, backend='soundfile')
+                audio = audio.mean(dim=0, keepdim=True)  # [1, T]
+                if sample_rate != 24000:
+                    audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=24000)(audio)
+                mel = mel_spectrogram(audio).transpose(1, 2).squeeze(0)  # [T, num_mels]
+                if mel.shape[0] % 2 != 0:
+                    mel = mel[:-1]
+                mel_len = mel.shape[0]
+
+                # 4. 
feature for llm + prompt_text = normalize_text(prompt_text) # remove some space and strange character + prompt_text = f"{SPK_DICT[spk_idx]}{TEXT_START}{prompt_text}{TEXT_END}{AUDIO_START}" + if spk_idx == 0: + prompt_text = f"{TASK_PODCAST}{prompt_text}" + prompt_text_ids = self.text_tokenizer.encode(prompt_text) + prompt_text_ids_list.append(prompt_text_ids) + if use_prompt_cot: + prompt_cot_text = normalize_text(data["prompt_cot_text"][spk_idx]) + prompt_cot_text = f"{SPK_DICT[spk_idx]}{TEXT_START}{prompt_cot_text}{TEXT_END}{AUDIO_START}" + prompt_cot_text_ids = self.text_tokenizer.encode(prompt_cot_text) + prompt_cot_text_ids_list.append(prompt_cot_text_ids) + if spk_idx == 0: + prompt_cot_prefix_list.append(self.text_tokenizer.encode(f"{TASK_PODCAST}")) + else: + prompt_cot_prefix_list.append([]) + log_mel_list.append(log_mel) + spk_emb_list.append(spk_emb) + mel_list.append(mel); mel_len_list.append(mel_len) + item = { + "prompt_text_tokens": prompt_text_ids_list, + "spk_emb": spk_emb_list, "mel": mel_list, "mel_len": mel_len_list, "log_mel": log_mel_list, "info": data, + } + if use_prompt_cot: + item.update({ + "use_prompt_cot": True, + "prompt_cot_text_tokens": prompt_cot_text_ids_list, + "prompt_cot_prefix": prompt_cot_prefix_list, + }) + text_ids_list, spks_list = [], [] + if "spk" not in data: + data["spk"] = [0] * len(data["text"]) + + for text, spk in zip(data["text"], data["spk"]): + # 4. feature for llm + text = normalize_text(text) + text = f"{SPK_DICT[spk]}{TEXT_START}{text}{TEXT_END}{AUDIO_START}" + text_ids = self.text_tokenizer.encode(text) + + text_ids_list.append(text_ids) + spks_list.append(spk) + + item.update({ + "text_tokens": text_ids_list, "spks_list": spks_list, + }) + except Exception as e: + timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3] + tqdm.write(f"[{timestamp}] - [WARNING] - Error processing data item {data.get('key', idx)}: {e}") + return None + return item + +class PodcastInferHandler(PodcastDataset): + + def __init__(self, text_tokenizer, data_list, model_config: Config): + self.datas = [] + self.model_config = model_config + + """Example data_list: + ``` + {"key": "uttid_1", "prompt_text": ["prompt_text1", "prompt_text2"], "prompt_cot_text": ["prompt_cot_text1", "prompt_cot_text2"], "text": ["text1", "text2], "spk": [0, 1], "prompt_wav": ["/mnt/data/audio/00000000.wav", "/mnt/data/audio/00000001.wav"], "wav": "/mnt/data/audio_synthetic/uttid_1.wav"} + ``` + Note: + - `key` is the key of this sample. + - `prompt_text` is the text used for prompt. + - `prompt_cot_text` is the cot text used for prompt as to activate specific ability. + - `text` is the text used for generating real audio. + - `spk` is the target speaker id to synthesize, corresponds to the prompt order. Default SPEAKER_0. + - `prompt_wav` is the audio used for prompt. + - `wav` is the path to the generated audio to be saved (we highly recommend to pre-define the save path before running the script). 
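+        - Unlike `PodcastDataset`, this handler does not read `data_list` from disk; records in the schema above are supplied later via `update_datasource`.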
+        """
+        missing = 0
+        self.text_tokenizer = text_tokenizer
+
+        option = onnxruntime.SessionOptions()
+        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+        option.intra_op_num_threads = 1
+        self.spk_model = onnxruntime.InferenceSession(f"{self.model_config.model}/campplus.onnx", sess_options=option,
+                                                      providers=["CPUExecutionProvider"])
+
+    def update_datasource(self, data_list):
+        self.datas = data_list
\ No newline at end of file
diff --git a/soulxpodcast/utils/text.py b/soulxpodcast/utils/text.py
new file mode 100644
index 0000000000000000000000000000000000000000..825da1cde7a610395df6653cac1d343e8a955171
--- /dev/null
+++ b/soulxpodcast/utils/text.py
@@ -0,0 +1,43 @@
+import re
+
+# Remove spaces between consecutive Chinese characters and at Chinese/English boundaries
+def remove_space_between_chinese(text):
+    # Remove spaces between consecutive Chinese characters
+    text = re.sub(r'([\u4e00-\u9fff])\s+([\u4e00-\u9fff])', r'\1\2', text)
+
+    # Remove the space at an English-to-Chinese boundary
+    text = re.sub(r'([a-zA-Z])\s+([\u4e00-\u9fff])', r'\1\2', text)
+
+    # Remove the space at a Chinese-to-English boundary
+    text = re.sub(r'([\u4e00-\u9fff])\s+([a-zA-Z])', r'\1\2', text)
+
+    return text
+
+# Decide whether the text ends in Chinese or English and append the appropriate sentence-final punctuation
+def normalize_text(current_text):
+    # keep_punctuation=',。?!.,?!<| |>'
+    # pattern = f'[\\p{{P}}--[{keep_punctuation}]]'
+    # current_text = re.sub(pattern, '', current_text)
+
+    # Remove spaces between consecutive Chinese characters
+    current_text = re.sub(r'([\u4e00-\u9fff])\s+([\u4e00-\u9fff])', r'\1\2', current_text)
+
+    # Remove the space at an English-to-Chinese boundary
+    current_text = re.sub(r'([a-zA-Z])\s+([\u4e00-\u9fff])', r'\1\2', current_text)
+
+    # Remove the space at a Chinese-to-English boundary
+    current_text = re.sub(r'([\u4e00-\u9fff])\s+([a-zA-Z])', r'\1\2', current_text)
+
+    # Text ends with a Chinese character
+    if re.search(r'[\u4e00-\u9fff]$', current_text):
+        # Append a Chinese full stop if the text does not already end with terminal punctuation
+        if current_text[-1] not in ",.?!。,?!":
+            current_text += "。"
+
+    # Text ends with an English letter
+    elif re.search(r'[a-zA-Z]$', current_text):
+        # Append an ASCII period if the text does not already end with terminal punctuation
+        if current_text[-1] not in ".!?":
+            current_text += "."
+
+    return current_text
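+
+# Example behaviour (illustrative, derived from the rules above):
+#   normalize_text("你好 世界")   -> "你好世界。"   (inner space removed, Chinese full stop appended)
+#   normalize_text("hello world") -> "hello world." (trailing ASCII period appended)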