test

Files changed:
 - app.py +318 -0
 - diffrhythm/config/defaults.ini +94 -0
 - diffrhythm/config/diffrhythm-1b.json +13 -0
 - diffrhythm/model/__init__.py +6 -0
 - diffrhythm/model/__pycache__/__init__.cpython-310.pyc +0 -0
 - diffrhythm/model/__pycache__/__init__.cpython-312.pyc +0 -0
 - diffrhythm/model/__pycache__/cfm.cpython-310.pyc +0 -0
 - diffrhythm/model/__pycache__/cfm.cpython-312.pyc +0 -0
 - diffrhythm/model/__pycache__/custom_dataset.cpython-310.pyc +0 -0
 - diffrhythm/model/__pycache__/custom_dataset_lrc_emb.cpython-310.pyc +0 -0
 - diffrhythm/model/__pycache__/dataset.cpython-310.pyc +0 -0
 - diffrhythm/model/__pycache__/dit.cpython-310.pyc +0 -0
 - diffrhythm/model/__pycache__/modules.cpython-310.pyc +0 -0
 - diffrhythm/model/__pycache__/trainer.cpython-310.pyc +0 -0
 - diffrhythm/model/__pycache__/utils.cpython-310.pyc +0 -0
 - diffrhythm/model/cfm.py +315 -0
 - diffrhythm/model/dit.py +221 -0
 - diffrhythm/model/modules.py +611 -0
 - diffrhythm/model/trainer.py +350 -0
 - diffrhythm/model/utils.py +182 -0
 - prompt/gift_of_the_world.wav +0 -0
 - prompt/little_happiness.wav +0 -0
 - prompt/little_talks.wav +0 -0
 - prompt/ltwyl.wav +0 -0
 - prompt/most_beautiful_expectation.wav +0 -0
 
    	
app.py ADDED
@@ -0,0 +1,318 @@
import gradio as gr
from openai import OpenAI
import requests
import json
# from volcenginesdkarkruntime import Ark
import torch
import torchaudio
from einops import rearrange
import argparse
import os
from tqdm import tqdm
import random
import numpy as np
import sys
from diffrhythm.infer.infer_utils import (
    get_reference_latent,
    get_lrc_token,
    get_style_prompt,
    prepare_model,
    get_negative_style_prompt
)
from diffrhythm.infer.infer import inference

device = 'cuda'
cfm, tokenizer, muq, vae = prepare_model(device)
cfm = torch.compile(cfm)

def infer_music(lrc, ref_audio_path, max_frames=2048, device='cuda'):
    lrc_prompt, start_time = get_lrc_token(lrc, tokenizer, device)
    style_prompt = get_style_prompt(muq, ref_audio_path)
    negative_style_prompt = get_negative_style_prompt(device)
    latent_prompt = get_reference_latent(device, max_frames)
    generated_song = inference(cfm_model=cfm,
                               vae_model=vae,
                               cond=latent_prompt,
                               text=lrc_prompt,
                               duration=max_frames,
                               style_prompt=style_prompt,
                               negative_style_prompt=negative_style_prompt,
                               start_time=start_time
                               )
    return generated_song

def R1_infer1(theme, tags_gen, language):
    try:
        client = OpenAI(api_key="XXXX", base_url="https://ark.cn-beijing.volces.com/api/v3")

        llm_prompt = """
        Write a complete set of lyrics on the theme "{theme}" in the "{tags}" style, in the {language} language.
        ### **Song structure requirements**
        1. The lyrics should vary in shape so the emotion builds and the whole stays coherent and layered. **Line lengths must vary naturally**; do not make every line the same length, which reads as formulaic.
        2. **Timestamps must be inferred from the song's tags and the emotion and rhythm of the lyrics**, not assigned mechanically by line length.
        ### **Song content requirements**
        1. **The first line's timestamp must account for the intro length**; do not start the lyrics at `[00:00.00]`.
        2. **Output strictly in LRC format**, one line per entry: `[mm:ss.xx]lyric text`.
        3. The output must contain no blank lines or brackets, and no extra commentary such as "chorus", "bridge", or "outro".
        4. The output must be **pure LRC**.
        """

        response = client.chat.completions.create(
            model="ep-20250215195652-lrff7",
            messages=[
                {"role": "system", "content": "You are a professional musician who has been invited to make music-related comments."},
                {"role": "user", "content": llm_prompt.format(theme=theme, tags=tags_gen, language=language)},
            ],
            stream=False
        )

        info = response.choices[0].message.content

        return info

    except Exception as e:
        print(f'Request failed: {e}')
        return {}


def R1_infer2(tags_lyrics, lyrics_input):
    client = OpenAI(api_key="XXX", base_url="https://ark.cn-beijing.volces.com/api/v3")

    llm_prompt = """
    {lyrics_input} are the lyrics of a song, one lyric per line, and {tags_lyrics} is the style I want for the song. I want to give each lyric line a timestamp so as to produce LRC. Timestamps must be inferred from the song's tags and the emotion and rhythm of the lyrics, not assigned mechanically by line length. The first line's timestamp must account for the intro length; do not start the lyrics at `[00:00.00]`. Output strictly in LRC format, one line per entry: `[mm:ss.xx]lyric text`. Output only the final LRC, with no other explanation.
    """

    response = client.chat.completions.create(
        model="ep-20250215195652-lrff7",
        messages=[
            {"role": "system", "content": "You are a professional musician who has been invited to make music-related comments."},
            {"role": "user", "content": llm_prompt.format(lyrics_input=lyrics_input, tags_lyrics=tags_lyrics)},
        ],
        stream=False
    )

    info = response.choices[0].message.content

    return info

css = """
/* Fix the lyrics textarea height and force a scrollbar */
.lyrics-scroll-box textarea {
    height: 300px !important;      /* fixed height */
    max-height: 500px !important;  /* maximum height */
    overflow-y: auto !important;   /* vertical scrolling */
    white-space: pre-wrap;         /* preserve line breaks */
    line-height: 1.5;              /* readable line height */
}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("# DiffRhythm")

    with gr.Tabs() as tabs:

        # page 1
        with gr.Tab("Music Generate", id=0):
            with gr.Row():
                with gr.Column():
                    with gr.Accordion("Best Practices Guide", open=False):
                        gr.Markdown("""
                        1. **Lyrics Format Requirements**
                        - Each line must follow: `[mm:ss.xx]Lyric content`
                        - Example of valid format:
                            ```
                            [00:07.23]Fight me fight me fight me
                            [00:08.73]You made me so unlike me
                            ```

                        2. **Generation Duration Limits**
                        - The current version supports a maximum of **95 seconds** of generated music
                        - Total timestamps should not exceed 01:35.00 (95 seconds)

                        3. **Audio Prompt Requirements**
                        - Reference audio should be ≥10 seconds for optimal results
                        - Shorter clips may lead to incoherent generation
                        """)
                    lrc = gr.Textbox(
                        label="Lrc",
                        placeholder="Input the full lyrics",
                        lines=12,
                        max_lines=50,
                        elem_classes="lyrics-scroll-box"
                    )
                    audio_prompt = gr.Audio(label="Audio Prompt", type="filepath")

                with gr.Column():
                    lyrics_btn = gr.Button("Submit", variant="primary")
                    audio_output = gr.Audio(label="Audio Result", type="filepath", elem_id="audio_output")

            gr.Examples(
                examples=[
                    ["./gift_of_the_world.wav"],
                    ["./most_beautiful_expectation.wav"],
                    ["./ltwyl.wav"]
                ],
                inputs=[audio_prompt],
                label="Audio Examples",
                examples_per_page=3
            )

            gr.Examples(
                examples=[
                    ["""[00:10.00]Moonlight spills through broken blinds
[00:13.20]Your shadow dances on the dashboard shrine
[00:16.85]Neon ghosts in gasoline rain
[00:20.40]I hear your laughter down the midnight train
[00:24.15]Static whispers through frayed wires
[00:27.65]Guitar strings hum our cathedral choirs
[00:31.30]Flicker screens show reruns of June
[00:34.90]I'm drowning in this mercury lagoon
[00:38.55]Electric veins pulse through concrete skies
[00:42.10]Your name echoes in the hollow where my heartbeat lies
[00:45.75]We're satellites trapped in parallel light
[00:49.25]Burning through the atmosphere of endless night
[01:00.00]Dusty vinyl spins reverse
[01:03.45]Our polaroid timeline bleeds through the verse
[01:07.10]Telescope aimed at dead stars
[01:10.65]Still tracing constellations through prison bars
[01:14.30]Electric veins pulse through concrete skies
[01:17.85]Your name echoes in the hollow where my heartbeat lies
[01:21.50]We're satellites trapped in parallel light
[01:25.05]Burning through the atmosphere of endless night
[02:10.00]Clockwork gears grind moonbeams to rust
[02:13.50]Our fingerprint smudged by interstellar dust
[02:17.15]Velvet thunder rolls through my veins
[02:20.70]Chasing phantom trains through solar plane
[02:24.35]Electric veins pulse through concrete skies
[02:27.90]Your name echoes in the hollow where my heartbeat lies"""],
                    ["""[00:05.00]Stardust whispers in your eyes
[00:09.30]Moonlight paints our silhouettes
[00:13.75]Tides bring secrets from the deep
[00:18.20]Where forever's breath is kept
[00:22.90]We dance through constellations' maze
[00:27.15]Footprints melt in cosmic waves
[00:31.65]Horizons hum our silent vow
[00:36.10]Time unravels here and now
[00:40.85]Eternal embers in the night oh oh oh
[00:45.25]Healing scars with liquid light
[00:49.70]Galaxies write our refrain
[00:54.15]Love reborn in endless rain
[01:15.30]Paper boats of memories
[01:19.75]Float through veins of ancient trees
[01:24.20]Your laughter spins aurora threads
[01:28.65]Weaving dawn through featherbed"""]
                ],
                inputs=[lrc],  # bind only to the lyrics input
                label="Lrc Examples",
                examples_per_page=2
            )

        # page 2
        with gr.Tab("LLM Generate LRC", id=1):
            with gr.Row():
                with gr.Column():
                    with gr.Accordion("Notice", open=False):
                        gr.Markdown("**Two Generation Modes:**\n1. Generate from theme & tags\n2. Add timestamps to existing lyrics")

                    with gr.Group():
                        gr.Markdown("### Method 1: Generate from Theme")
                        theme = gr.Textbox(label="theme", placeholder="Enter song theme, e.g. Love and Heartbreak")
                        tags_gen = gr.Textbox(label="tags", placeholder="Example: male pop confidence healing")
                        language = gr.Dropdown(["zh", "en"], label="language", value="en")
                        gen_from_theme_btn = gr.Button("Generate LRC (From Theme)", variant="primary")

                    with gr.Group(visible=True):
                        gr.Markdown("### Method 2: Add Timestamps to Lyrics")
                        tags_lyrics = gr.Textbox(label="tags", placeholder="Example: female ballad piano slow")
                        lyrics_input = gr.Textbox(
                            label="Raw Lyrics (without timestamps)",
                            placeholder="Enter plain lyrics (without timestamps), e.g.:\nYesterday\nAll my troubles...",
                            lines=12,
                            max_lines=50,
                            elem_classes="lyrics-scroll-box"
                        )
                        gen_from_lyrics_btn = gr.Button("Generate LRC (From Lyrics)", variant="primary")

                with gr.Column():
                    lrc_output = gr.Textbox(
                        label="Generated LRC Lyrics",
                        placeholder="Timed lyrics will appear here",
                        lines=50,
                        elem_classes="lrc-output",
                        show_copy_button=True
                    )

            # Examples section
            gr.Examples(
                examples=[
                    [
                        "Love and Heartbreak",
                        "female vocal emotional piano pop",
                        "en"
                    ],
                    [
                        "Heroic Epic",
                        "male choir orchestral powerful",
                        "zh"
                    ]
                ],
                inputs=[theme, tags_gen, language],
                label="Examples: Generate from Theme"
            )

            gr.Examples(
                examples=[
                    [
                        "acoustic folk happy",
                        """I'm sitting here in the boring room
                        It's just another rainy Sunday afternoon"""
                    ],
                    [
                        "electronic dance energetic",
                        """We're living in a material world
                        And I am a material girl"""
                    ]
                ],
                inputs=[tags_lyrics, lyrics_input],
                label="Examples: Generate from Lyrics"
            )

            # Bind functions
            gen_from_theme_btn.click(
                fn=R1_infer1,
                inputs=[theme, tags_gen, language],
                outputs=lrc_output
            )

            gen_from_lyrics_btn.click(
                fn=R1_infer2,
                inputs=[tags_lyrics, lyrics_input],
                outputs=lrc_output
            )

    tabs.select(
        lambda s: None,
        None,
        None
    )

    lyrics_btn.click(
        fn=infer_music,
        inputs=[lrc, audio_prompt],
        outputs=audio_output
    )


if __name__ == "__main__":
    demo.queue().launch(show_api=False, show_error=True)
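The Best Practices Guide in app.py pins down two hard constraints: every lyric line must match `[mm:ss.xx]` and the last timestamp must stay within 95 seconds. A minimal standalone check of those constraints could look like the sketch below; the `validate_lrc` helper is illustrative and not part of this commit:

import re

LRC_LINE = re.compile(r"^\[(\d{2}):(\d{2})\.(\d{2})\](.*)$")
MAX_SECONDS = 95.0  # generation limit of the current version

def validate_lrc(lrc: str) -> list[tuple[float, str]]:
    """Parse LRC text, enforcing the [mm:ss.xx] format and the 95 s cap."""
    entries = []
    for line in lrc.strip().splitlines():
        m = LRC_LINE.match(line)
        if m is None:
            raise ValueError(f"not valid LRC: {line!r}")
        mm, ss, xx, text = m.groups()
        t = int(mm) * 60 + int(ss) + int(xx) / 100
        if t > MAX_SECONDS:
            raise ValueError(f"timestamp {t:.2f}s exceeds the 95s limit")
        entries.append((t, text))
    return entries

print(validate_lrc("[00:07.23]Fight me fight me fight me"))
# [(7.23, 'Fight me fight me fight me')]

Note that the first bundled Lrc Example runs to [02:27.90], well past the stated 95-second cap, so a check like this would reject it.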
    	
diffrhythm/config/defaults.ini ADDED
@@ -0,0 +1,94 @@

[DEFAULTS]

# name of the run
exp_name = F5

# the batch size
batch_size = 8

# the chunk size
max_frames = 3000
min_frames = 10

# number of CPU workers for the DataLoader
num_workers = 4

# the random seed
seed = 42

# batches for gradient accumulation
accum_batches = 1

# number of steps between checkpoints
checkpoint_every = 10000

# trainer checkpoint file to restart training from
ckpt_path = ''

# model checkpoint file to start a new training run from
pretrained_ckpt_path = ''

# checkpoint path for the pretransform model if needed
pretransform_ckpt_path = ''

# model configuration file specifying model hyperparameters
model_config = ''

# configuration for datasets
dataset_config = ''

# directory to save the checkpoints in
save_dir = ''

# gradient norm clipping threshold
max_grad_norm = 1.0

# gradient accumulation steps
grad_accumulation_steps = 1

# learning rate
learning_rate = 7.5e-5

# number of training epochs
epochs = 110

# warmup steps
num_warmup_updates = 2000

# save a checkpoint every this many steps
save_per_updates = 5000

# save the last checkpoint every this many steps
last_per_steps = 5000

prompt_path = "/mnt/sfs/music/lance/style-lance-full|/mnt/sfs/music/lance/style-lance-cnen-music-second"
lrc_path = "/mnt/sfs/music/lance/lrc-lance-emb-full|/mnt/sfs/music/lance/lrc-lance-cnen-second"
latent_path = "/mnt/sfs/music/lance/latent-lance|/mnt/sfs/music/lance/latent-lance-cnen-music-second-1|/mnt/sfs/music/lance/latent-lance-cnen-music-second-2"

audio_drop_prob = 0.3
cond_drop_prob = 0.0
style_drop_prob = 0.1
lrc_drop_prob = 0.1

align_lyrics = 0
lyrics_slice = 0
parse_lyrics = 1
skip_empty_lyrics = 0
lyrics_shift = -1

use_style_prompt = 1

tokenizer_type = gpt2

reset_lr = 0

resumable_with_seed = 666

downsample_rate = 2048

grad_ckpt = 0

dataset_path = "/mnt/sfs/music/hkchen/workspace/F5-TTS-HW/filelists/music123latent_asred_bpmstyle_cnen_pure1"

pure_prob = 0.0
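The commit does not include the training entry point that consumes this file, but assuming it is read with Python's standard configparser (a common pattern for .ini defaults like these), the values can be pulled out as in this sketch:

import configparser

config = configparser.ConfigParser()
config.read("diffrhythm/config/defaults.ini")

defaults = config["DEFAULTS"]
batch_size = defaults.getint("batch_size")          # 8
learning_rate = defaults.getfloat("learning_rate")  # 7.5e-5
# multi-root dataset paths are "|"-separated strings in this file
latent_roots = defaults.get("latent_path").strip('"').split("|")
print(batch_size, learning_rate, len(latent_roots))  # 8 7.5e-05 3

The "|"-separated paths (prompt_path, lrc_path, latent_path) each name several dataset roots, so whatever loads them has to split on "|" as shown.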
    	
diffrhythm/config/diffrhythm-1b.json ADDED
@@ -0,0 +1,13 @@
{
    "model_type": "diffrhythm",
    "model": {
        "dim": 2048,
        "depth": 16,
        "heads": 32,
        "ff_mult": 4,
        "text_dim": 512,
        "conv_layers": 4,
        "mel_dim": 64,
        "text_num_embeds": 363
    }
}
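Every architecture hyperparameter sits under a single "model" key, which suggests splatting that block straight into the network constructor. The sketch below assumes DiT (exported by diffrhythm/model/__init__.py, next file) accepts these keys as keyword arguments; its real signature is not shown in this commit, so treat this as a plausible reading rather than the project's loader:

import json
from diffrhythm.model.dit import DiT

with open("diffrhythm/config/diffrhythm-1b.json") as f:
    cfg = json.load(f)

assert cfg["model_type"] == "diffrhythm"
# assumed: DiT takes dim/depth/heads/ff_mult/text_dim/conv_layers/mel_dim/text_num_embeds
model = DiT(**cfg["model"])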
    	
diffrhythm/model/__init__.py ADDED
@@ -0,0 +1,6 @@
from diffrhythm.model.cfm import CFM

from diffrhythm.model.dit import DiT


__all__ = ["CFM"]
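As written, `__all__` lists only CFM, so a star-import will not pick up DiT even though the module imports it; explicit imports still resolve both names:

from diffrhythm.model import CFM, DiT  # explicit imports: both work

from diffrhythm.model import *  # star-import: binds CFM only, per __all__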
    	
diffrhythm/model/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (290 Bytes).

diffrhythm/model/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (508 Bytes).

diffrhythm/model/__pycache__/cfm.cpython-310.pyc ADDED
Binary file (6.28 kB).

diffrhythm/model/__pycache__/cfm.cpython-312.pyc ADDED
Binary file (10.7 kB).

diffrhythm/model/__pycache__/custom_dataset.cpython-310.pyc ADDED
Binary file (11.5 kB).

diffrhythm/model/__pycache__/custom_dataset_lrc_emb.cpython-310.pyc ADDED
Binary file (10.5 kB).

diffrhythm/model/__pycache__/dataset.cpython-310.pyc ADDED
Binary file (8.04 kB).

diffrhythm/model/__pycache__/dit.cpython-310.pyc ADDED
Binary file (5.61 kB).

diffrhythm/model/__pycache__/modules.cpython-310.pyc ADDED
Binary file (15.9 kB).

diffrhythm/model/__pycache__/trainer.cpython-310.pyc ADDED
Binary file (9.13 kB).

diffrhythm/model/__pycache__/utils.cpython-310.pyc ADDED
Binary file (6.03 kB).
    	
diffrhythm/model/cfm.py ADDED
@@ -0,0 +1,315 @@
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""

from __future__ import annotations
from typing import Callable
from random import random

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

from torchdiffeq import odeint

from diffrhythm.model.modules import MelSpec
from diffrhythm.model.utils import (
    default,
    exists,
    list_str_to_idx,
    list_str_to_tensor,
    lens_to_mask,
    mask_from_frac_lengths,
)


def custom_mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"], device, max_seq_len):  # noqa: F722 F821
    # boolean [b, max_seq_len] mask that is True inside [start, end) per batch element
    seq = torch.arange(max_seq_len, device=device).long()
    start_mask = seq[None, :] >= start[:, None]
    end_mask = seq[None, :] < end[:, None]
    return start_mask & end_mask

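A quick sanity check of the span-mask helper above (illustrative only, not part of the commit; the `seq_len` argument is unused by the helper, so `None` is passed here):

import torch

start = torch.tensor([0, 2])   # hypothetical per-batch start frames
end = torch.tensor([3, 5])     # hypothetical per-batch end frames (exclusive)
mask = custom_mask_from_start_end_indices(None, start, end, device="cpu", max_seq_len=6)
# mask[0] -> [True, True, True, False, False, False]
# mask[1] -> [False, False, True, True, True, False]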
class CFM(nn.Module):
    def __init__(
        self,
        transformer: nn.Module,
        sigma=0.0,
        odeint_kwargs: dict = dict(
            # atol = 1e-5,
            # rtol = 1e-5,
            method="euler"  # 'midpoint'
            # method="adaptive_heun"  # dopri5
        ),
        odeint_options: dict = dict(
            min_step=0.05
        ),
        audio_drop_prob=0.3,
        cond_drop_prob=0.2,
        style_drop_prob=0.1,
        lrc_drop_prob=0.1,
        num_channels=None,
        frac_lengths_mask: tuple[float, float] = (0.7, 1.0),
        vocab_char_map: dict[str, int] | None = None,
        use_style_prompt: bool = False,
    ):
        super().__init__()

        self.frac_lengths_mask = frac_lengths_mask

        self.num_channels = num_channels

        # classifier-free guidance
        self.audio_drop_prob = audio_drop_prob
        self.cond_drop_prob = cond_drop_prob
        self.style_drop_prob = style_drop_prob
        self.lrc_drop_prob = lrc_drop_prob

        print(f"audio_drop_prob -> {self.audio_drop_prob}; style_drop_prob -> {self.style_drop_prob}; lrc_drop_prob -> {self.lrc_drop_prob}")

        # transformer
        self.transformer = transformer
        dim = transformer.dim
        self.dim = dim

        # conditional flow related
        self.sigma = sigma

        # sampling related
        self.odeint_kwargs = odeint_kwargs
        # print(f"ODE SOLVER: {self.odeint_kwargs['method']}")

        self.odeint_options = odeint_options

        # vocab map for tokenization
        self.vocab_char_map = vocab_char_map

        self.use_style_prompt = use_style_prompt

    @property
    def device(self):
        return next(self.parameters()).device

    @torch.no_grad()
    def sample(
        self,
        cond: float["b n d"] | float["b nw"],  # noqa: F722
        text: int["b nt"] | list[str],  # noqa: F722
        duration: int | int["b"],  # noqa: F821
        *,
        style_prompt=None,
        style_prompt_lens=None,
        negative_style_prompt=None,
        lens: int["b"] | None = None,  # noqa: F821
        steps=32,
        cfg_strength=4.0,
        sway_sampling_coef=None,
        seed: int | None = None,
        max_duration=4096,
        vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None,  # noqa: F722
        no_ref_audio=False,
        duplicate_test=False,
        t_inter=0.1,
        edit_mask=None,
        start_time=None,
        latent_pred_start_frame=0,
        latent_pred_end_frame=2048,
    ):
        self.eval()

        if next(self.parameters()).dtype == torch.float16:
            cond = cond.half()

        # raw wave

        if cond.shape[1] > duration:
            cond = cond[:, :duration, :]

        if cond.ndim == 2:
            # NOTE: this raw-wave path expects a `self.mel_spec` module (e.g. MelSpec); CFM as defined here does not create one
            cond = self.mel_spec(cond)
            cond = cond.permute(0, 2, 1)
            assert cond.shape[-1] == self.num_channels

        batch, cond_seq_len, device = *cond.shape[:2], cond.device
        if not exists(lens):
            lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long)

        # text

        if isinstance(text, list):
            if exists(self.vocab_char_map):
                text = list_str_to_idx(text, self.vocab_char_map).to(device)
            else:
                text = list_str_to_tensor(text).to(device)
            assert text.shape[0] == batch

        if exists(text):
            text_lens = (text != -1).sum(dim=-1)
            # lens = torch.maximum(text_lens, lens)  # make sure lengths are at least those of the text characters

        # duration
        # import pdb; pdb.set_trace()
        cond_mask = lens_to_mask(lens)
        if edit_mask is not None:
            cond_mask = cond_mask & edit_mask

        latent_pred_start_frame = torch.tensor([latent_pred_start_frame]).to(cond.device)
        latent_pred_end_frame = duration
        latent_pred_end_frame = torch.tensor([latent_pred_end_frame]).to(cond.device)
        fixed_span_mask = custom_mask_from_start_end_indices(cond_seq_len, latent_pred_start_frame, latent_pred_end_frame, device=cond.device, max_seq_len=duration)

        fixed_span_mask = fixed_span_mask.unsqueeze(-1)
        step_cond = torch.where(fixed_span_mask, torch.zeros_like(cond), cond)

        if isinstance(duration, int):
            duration = torch.full((batch,), duration, device=device, dtype=torch.long)

        # duration = torch.maximum(lens + 1, duration)  # just add one token so something is generated
        duration = duration.clamp(max=max_duration)
        max_duration = duration.amax()

        # duplicate test corner for inner time step observation
        if duplicate_test:
            test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0)

        # cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0)  # [b, t, d]
        # cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False)  # [b, max_duration]
        # cond_mask = cond_mask.unsqueeze(-1)  # [b, t, d]
        # step_cond = torch.where(
        #     cond_mask, cond, torch.zeros_like(cond)
        # )  # allow direct control (cut cond audio) with lens passed in

        if batch > 1:
            mask = lens_to_mask(duration)
        else:  # save memory and speed up, as single inference needs no mask currently
            mask = None

        # test for no ref audio
        if no_ref_audio:
            cond = torch.zeros_like(cond)

        def fn(t, x):
            # at each step, conditioning is fixed
            # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))

            # predict flow
            pred = self.transformer(
                x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False, drop_prompt=False,
                style_prompt=style_prompt, style_prompt_lens=style_prompt_lens, start_time=start_time
            )
            if cfg_strength < 1e-5:
                return pred

            null_pred = self.transformer(
                x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True, drop_prompt=False,
                style_prompt=negative_style_prompt, style_prompt_lens=style_prompt_lens, start_time=start_time
            )
            return pred + (pred - null_pred) * cfg_strength

        # noise input
        # reseed per item so the batch inference result matches across batch sizes and single inference;
        # some difference may remain, likely due to convolutional layers
        y0 = []
        for dur in duration:
            if exists(seed):
                torch.manual_seed(seed)
            y0.append(torch.randn(dur, self.num_channels, device=self.device, dtype=step_cond.dtype))
        y0 = pad_sequence(y0, padding_value=0, batch_first=True)

        t_start = 0

        # duplicate test corner for inner time step observation
        if duplicate_test:
            t_start = t_inter
            y0 = (1 - t_start) * y0 + t_start * test_cond
            steps = int(steps * (1 - t_start))

        t = torch.linspace(t_start, 1, steps, device=self.device, dtype=step_cond.dtype)
        if sway_sampling_coef is not None:
            t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)

        trajectory = odeint(fn, y0, t, **self.odeint_kwargs)

        sampled = trajectory[-1]
        out = sampled
        # out = torch.where(cond_mask, cond, out)
        out = torch.where(fixed_span_mask, out, cond)

        if exists(vocoder):
            out = out.permute(0, 2, 1)
            out = vocoder(out)

        return out, trajectory

    def forward(
        self,
        inp: float["b n d"] | float["b nw"],  # mel or raw wave  # noqa: F722
        text: int["b nt"] | list[str],  # noqa: F722
        style_prompt=None,
        style_prompt_lens=None,
        lens: int["b"] | None = None,  # noqa: F821
        noise_scheduler: str | None = None,
        grad_ckpt=False,
        start_time=None,
    ):

        batch, seq_len, dtype, device, _σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma

        # lens and mask
        if not exists(lens):
            lens = torch.full((batch,), seq_len, device=device)

        mask = lens_to_mask(lens, length=seq_len)  # useless here, as collate_fn will pad to max length in batch

        # get a random span to mask out for training conditionally
        frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask)
        rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)

        if exists(mask):
            rand_span_mask = mask
            # rand_span_mask &= mask

        # mel is x1
        x1 = inp

        # x0 is gaussian noise
        x0 = torch.randn_like(x1)

        # time step
        # time = torch.rand((batch,), dtype=dtype, device=self.device)
        time = torch.normal(mean=0, std=1, size=(batch,), device=self.device)
        time = torch.nn.functional.sigmoid(time)
        # TODO. noise_scheduler

        # sample xt (φ_t(x) in the paper)
        t = time.unsqueeze(-1).unsqueeze(-1)
        φ = (1 - t) * x0 + t * x1
        flow = x1 - x0

        # only predict what is within the random mask span for infilling
        cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1)

        # transformer and cfg training with a drop rate
        drop_audio_cond = random() < self.audio_drop_prob  # p_drop in voicebox paper
        drop_text = random() < self.lrc_drop_prob
        drop_prompt = random() < self.style_drop_prob

        # to rigorously mask out padding, record lengths in collate_fn in dataset.py and pass them in here;
        # adding the mask uses more memory, so the batch sampler threshold also needs scaling down for long sequences
        pred = self.transformer(
            x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text, drop_prompt=drop_prompt,
            style_prompt=style_prompt, style_prompt_lens=style_prompt_lens, grad_ckpt=grad_ckpt, start_time=start_time
        )

        # flow matching loss
        loss = F.mse_loss(pred, flow, reduction="none")
        loss = loss[rand_span_mask]

        return loss.mean(), cond, pred
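Two sampling details in `CFM.sample` are worth a note. Classifier-free guidance is the usual extrapolation `pred + (pred - null_pred) * cfg_strength`, and when `sway_sampling_coef` is set, the uniform ODE time grid is warped so that more solver steps land early in generation, where the signal takes shape. A standalone sketch of that warp (illustrative only; the coefficient value is made up):

import torch

steps = 8
t = torch.linspace(0, 1, steps)
coef = -1.0  # hypothetical; negative values push steps toward t = 0
t_sway = t + coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
# endpoints are preserved: t_sway[0] == 0.0 and t_sway[-1] == 1.0,
# but interior points shift earlier, e.g. t ≈ 0.43 maps to ≈ 0.22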
    	
diffrhythm/model/dit.py ADDED
@@ -0,0 +1,221 @@
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""

from __future__ import annotations

import torch
from torch import nn
import torch.nn.functional as F

from x_transformers.x_transformers import RotaryEmbedding
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.llama import LlamaConfig
from torch.utils.checkpoint import checkpoint

from diffrhythm.model.modules import (
    TimestepEmbedding,
    ConvNeXtV2Block,
    ConvPositionEmbedding,
    DiTBlock,
    AdaLayerNormZero_Final,
    precompute_freqs_cis,
    get_pos_embed_indices,
)


# Text embedding


class TextEmbedding(nn.Module):
    def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
        super().__init__()
        self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim)  # use 0 as filler token

        if conv_layers > 0:
            self.extra_modeling = True
            self.precompute_max_pos = 4096  # ~44s of 24khz audio
            self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
            self.text_blocks = nn.Sequential(
                *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
            )
        else:
            self.extra_modeling = False

    def forward(self, text: int["b nt"], seq_len, drop_text=False):  # noqa: F722
        # text = text + 1  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
        # text = text[:, :seq_len]  # curtail if character tokens are more than the mel spec tokens
        batch, text_len = text.shape[0], text.shape[1]
        # text = F.pad(text, (0, seq_len - text_len), value=0)

        if drop_text:  # cfg for text
            text = torch.zeros_like(text)

        text = self.text_embed(text)  # b n -> b n d

        # possible extra modeling
        if self.extra_modeling:
            # sinus pos emb
            batch_start = torch.zeros((batch,), dtype=torch.long)
            pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
            text_pos_embed = self.freqs_cis[pos_idx]
            text = text + text_pos_embed

            # convnextv2 blocks
            text = self.text_blocks(text)

        return text

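A minimal smoke test of `TextEmbedding` (illustrative only; sizes are made up, and it assumes the helpers from diffrhythm.model.modules behave as used above). With `drop_text=True` the token ids are zeroed, which is how the classifier-free-guidance null branch reuses the same module:

import torch

emb = TextEmbedding(text_num_embeds=64, text_dim=32, conv_layers=2)
tokens = torch.randint(0, 64, (2, 100))            # [b, nt]
out = emb(tokens, seq_len=100)                     # [2, 100, 32]
null = emb(tokens, seq_len=100, drop_text=True)    # same shape, ids replaced by filler token 0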
# noised input audio and context mixing embedding


class InputEmbedding(nn.Module):
    def __init__(self, mel_dim, text_dim, out_dim, cond_dim):
        super().__init__()
        self.proj = nn.Linear(mel_dim * 2 + text_dim + cond_dim * 2, out_dim)
        self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)

    def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], style_emb, time_emb, drop_audio_cond=False):  # noqa: F722
        if drop_audio_cond:  # cfg for cond audio
            cond = torch.zeros_like(cond)

        style_emb = style_emb.unsqueeze(1).repeat(1, x.shape[1], 1)
        time_emb = time_emb.unsqueeze(1).repeat(1, x.shape[1], 1)
        # print(x.shape, cond.shape, text_embed.shape, style_emb.shape, time_emb.shape)
        x = self.proj(torch.cat((x, cond, text_embed, style_emb, time_emb), dim=-1))
        x = self.conv_pos_embed(x) + x
        return x


# Transformer backbone using DiT blocks


class DiT(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth=8,
        heads=8,
        dim_head=64,
        dropout=0.1,
        ff_mult=4,
        mel_dim=100,
        text_num_embeds=256,
        text_dim=None,
        conv_layers=0,
        long_skip_connection=False,
        use_style_prompt=False,
    ):
        super().__init__()

        cond_dim = 512
        self.time_embed = TimestepEmbedding(cond_dim)
        self.start_time_embed = TimestepEmbedding(cond_dim)
        if text_dim is None:
            text_dim = mel_dim
        self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
        self.input_embed = InputEmbedding(mel_dim, text_dim, dim, cond_dim=cond_dim)

        # self.rotary_embed = RotaryEmbedding(dim_head)

        self.dim = dim
        self.depth = depth

        # self.transformer_blocks = nn.ModuleList(
        #     [DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout, use_style_prompt=use_style_prompt) for _ in range(depth)]
        # )
        llama_config = LlamaConfig(hidden_size=dim, intermediate_size=dim * ff_mult, hidden_act='silu')
        llama_config._attn_implementation = 'sdpa'
        self.transformer_blocks = nn.ModuleList(
            [LlamaDecoderLayer(llama_config, layer_idx=i) for i in range(depth)]
        )
        self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None

        # text is fused into the first depth // 2 blocks; zero-initialized so fusion is a no-op at init
        # (the Linear input size assumes text_dim == cond_dim)
        self.text_fusion_linears = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(cond_dim, dim),
                    nn.SiLU()
                ) for _ in range(depth // 2)
            ]
        )
        for layer in self.text_fusion_linears:
            for p in layer.parameters():
                p.detach().zero_()

        self.norm_out = AdaLayerNormZero_Final(dim, cond_dim)  # final modulation
        self.proj_out = nn.Linear(dim, mel_dim)

        # if use_style_prompt:
        #     self.prompt_rnn = nn.LSTM(64, cond_dim, 1, batch_first=True)

    def forward(
        self,
        x: float["b n d"],  # noised input audio  # noqa: F722
        cond: float["b n d"],  # masked cond audio  # noqa: F722
        text: int["b nt"],  # text  # noqa: F722
        time: float["b"] | float[""],  # time step  # noqa: F821 F722
        drop_audio_cond,  # cfg for cond audio
        drop_text,  # cfg for text
        drop_prompt=False,
        style_prompt=None,  # [b d t]
        style_prompt_lens=None,
        mask: bool["b n"] | None = None,  # noqa: F722
        grad_ckpt=False,
        start_time=None,
    ):
        batch, seq_len = x.shape[0], x.shape[1]
        if time.ndim == 0:
            time = time.repeat(batch)

        # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
        t = self.time_embed(time)
        s_t = self.start_time_embed(start_time)
        c = t + s_t
        text_embed = self.text_embed(text, seq_len, drop_text=drop_text)

        # import pdb; pdb.set_trace()
        if drop_prompt:
            style_prompt = torch.zeros_like(style_prompt)
        # if self.training:
        #     packed_style_prompt = torch.nn.utils.rnn.pack_padded_sequence(style_prompt.transpose(1, 2), style_prompt_lens.cpu(), batch_first=True, enforce_sorted=False)
        # else:
        #     packed_style_prompt = style_prompt.transpose(1, 2)
        # print(packed_style_prompt.shape)
        # _, style_emb = self.prompt_rnn.forward(packed_style_prompt)
        # _, (h_n, c_n) = self.prompt_rnn.forward(packed_style_prompt)
        # style_emb = h_n.squeeze(0)  # 1, B, dim -> B, dim

        style_emb = style_prompt  # [b, 512]

        x = self.input_embed(x, cond, text_embed, style_emb, c, drop_audio_cond=drop_audio_cond)

        if self.long_skip_connection is not None:
            residual = x

        pos_ids = torch.arange(x.shape[1], device=x.device)
        pos_ids = pos_ids.unsqueeze(0).repeat(x.shape[0], 1)
        for i, block in enumerate(self.transformer_blocks):
            if not grad_ckpt:
                x, *_ = block(x, position_ids=pos_ids)
            else:
                x, *_ = checkpoint(block, x, position_ids=pos_ids, use_reentrant=False)
            if i < self.depth // 2:
                x = x + self.text_fusion_linears[i](text_embed)

        if self.long_skip_connection is not None:
            x = self.long_skip_connection(torch.cat((x, residual), dim=-1))

        x = self.norm_out(x, c)
        output = self.proj_out(x)

        return output
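To close the loop on dit.py, a rough shape-check sketch for `DiT` (illustrative only; all sizes are made up, `text_dim` is set to 512 to match the hard-coded `cond_dim` as the text-fusion layers require, and `text` is padded to the latent length for simplicity):

import torch

model = DiT(dim=256, depth=4, mel_dim=100, text_num_embeds=256, text_dim=512, conv_layers=2)
b, n = 2, 128
x = torch.randn(b, n, 100)               # noised input latent
cond = torch.randn(b, n, 100)            # masked conditioning audio
text = torch.randint(0, 256, (b, n))     # lyric tokens, length n here
time = torch.rand(b)
start = torch.rand(b)
style = torch.randn(b, 512)              # style_prompt is consumed as a [b, 512] embedding
out = model(x=x, cond=cond, text=text, time=time, drop_audio_cond=False, drop_text=False,
            style_prompt=style, start_time=start)
# out.shape == (b, n, 100), the predicted flow for the flow-matching objective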
    	
diffrhythm/model/modules.py ADDED
@@ -0,0 +1,611 @@
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""

from __future__ import annotations
from typing import Optional
import math

import torch
from torch import nn
import torch.nn.functional as F
import torchaudio

from x_transformers.x_transformers import apply_rotary_pos_emb

class FiLMLayer(nn.Module):
    """
    Feature-wise Linear Modulation (FiLM) layer
    Reference: https://arxiv.org/abs/1709.07871
    """

    def __init__(self, in_channels, cond_channels):
        super().__init__()
        self.in_channels = in_channels
        self.film = nn.Conv1d(cond_channels, in_channels * 2, 1)

    def forward(self, x, c):
        gamma, beta = torch.chunk(self.film(c.unsqueeze(2)), chunks=2, dim=1)
        gamma = gamma.transpose(1, 2)
        beta = beta.transpose(1, 2)
        return gamma * x + beta

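# --- hedged usage sketch (added for illustration; not part of the original file) ---
# FiLM predicts a per-channel (gamma, beta) pair from a global conditioning
# vector and applies gamma * x + beta across the sequence. Sizes are assumptions.
def _demo_film_layer():
    film = FiLMLayer(in_channels=64, cond_channels=128)
    x = torch.randn(2, 50, 64)   # [b, n, d] features
    c = torch.randn(2, 128)      # [b, cond] global condition
    y = film(x, c)               # broadcasts [b, 1, d] gamma/beta over n
    assert y.shape == (2, 50, 64)
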
# raw wav to mel spec


class MelSpec(nn.Module):
    def __init__(
        self,
        filter_length=1024,
        hop_length=256,
        win_length=1024,
        n_mel_channels=100,
        target_sample_rate=24_000,
        normalize=False,
        power=1,
        norm=None,
        center=True,
    ):
        super().__init__()
        self.n_mel_channels = n_mel_channels

        self.mel_stft = torchaudio.transforms.MelSpectrogram(
            sample_rate=target_sample_rate,
            n_fft=filter_length,
            win_length=win_length,
            hop_length=hop_length,
            n_mels=n_mel_channels,
            power=power,
            center=center,
            normalized=normalize,
            norm=norm,
        )

        self.register_buffer("dummy", torch.tensor(0), persistent=False)

    def forward(self, inp):
        if len(inp.shape) == 3:
            inp = inp.squeeze(1)  # 'b 1 nw -> b nw'

        assert len(inp.shape) == 2

        if self.dummy.device != inp.device:
            self.to(inp.device)

        mel = self.mel_stft(inp)
        mel = mel.clamp(min=1e-5).log()
        return mel

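# --- hedged usage sketch (added for illustration; not part of the original file) ---
# MelSpec maps raw waveforms to log-mel frames; with hop_length=256 and
# center=True, a waveform of nw samples yields roughly nw // 256 + 1 frames.
def _demo_mel_spec():
    mel_fn = MelSpec()                # defaults: 24 kHz, 100 mel bins
    wav = torch.randn(2, 24_000)      # [b, nw]: one second per item
    mel = mel_fn(wav)                 # [b, n_mel, frames], log-compressed
    assert mel.shape[:2] == (2, 100)
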
# sinusoidal position embedding


class SinusPositionEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x, scale=1000):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

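# --- hedged usage sketch (added for illustration; not part of the original file) ---
# Typically used for continuous conditioning values such as flow timesteps:
# a [b] tensor becomes a [b, dim] embedding (scale=1000 spreads small values
# across the frequency bands). Sizes below are illustrative assumptions.
def _demo_sinus_embedding():
    embed = SinusPositionEmbedding(dim=256)
    t = torch.rand(4)     # [b] values in [0, 1]
    e = embed(t)          # [b, 256]: sin half then cos half
    assert e.shape == (4, 256)
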
# convolutional position embedding


class ConvPositionEmbedding(nn.Module):
    def __init__(self, dim, kernel_size=31, groups=16):
        super().__init__()
        assert kernel_size % 2 != 0
        self.conv1d = nn.Sequential(
            nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
            nn.Mish(),
            nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
            nn.Mish(),
        )

    def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):  # noqa: F722
        if mask is not None:
            mask = mask[..., None]
            x = x.masked_fill(~mask, 0.0)

        x = x.permute(0, 2, 1)
        x = self.conv1d(x)
        out = x.permute(0, 2, 1)

        if mask is not None:
            out = out.masked_fill(~mask, 0.0)

        return out

# rotary positional embedding related


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
    # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
    # has some connection to NTK literature
    # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
    # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
    theta *= theta_rescale_factor ** (dim / (dim - 2))
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cos = torch.cos(freqs)  # real part
    freqs_sin = torch.sin(freqs)  # imaginary part
    return torch.cat([freqs_cos, freqs_sin], dim=-1)


def get_pos_embed_indices(start, length, max_pos, scale=1.0):
    # length = length if isinstance(length, int) else length.max()
    scale = scale * torch.ones_like(start, dtype=torch.float32)  # in case scale is a scalar
    pos = (
        start.unsqueeze(1)
        + (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long()
    )
    # clamp indices so extra-long sequences stay within the precomputed table
    pos = torch.where(pos < max_pos, pos, max_pos - 1)
    return pos

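# --- hedged usage sketch (added for illustration; not part of the original file) ---
# precompute_freqs_cis returns a [end, dim] table (cos half, then sin half);
# get_pos_embed_indices picks per-sample rows starting at arbitrary offsets,
# clamped to max_pos. All sizes below are illustrative assumptions.
def _demo_rope_tables():
    max_pos, head_dim = 4096, 64
    freqs = precompute_freqs_cis(head_dim, max_pos)                  # [4096, 64]
    start = torch.tensor([0, 100])                                   # per-sample offsets
    idx = get_pos_embed_indices(start, length=256, max_pos=max_pos)  # [2, 256]
    rope = freqs[idx]                                                # [2, 256, 64]
    assert rope.shape == (2, 256, head_dim)
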
# Global Response Normalization layer (from ConvNeXt-V2)


class GRN(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, dim))

    def forward(self, x):
        Gx = torch.norm(x, p=2, dim=1, keepdim=True)
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * Nx) + self.beta + x

# ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
# ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108


class ConvNeXtV2Block(nn.Module):
    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        dilation: int = 1,
    ):
        super().__init__()
        padding = (dilation * (7 - 1)) // 2
        self.dwconv = nn.Conv1d(
            dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
        )  # depthwise conv
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, intermediate_dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.grn = GRN(intermediate_dim)
        self.pwconv2 = nn.Linear(intermediate_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = x.transpose(1, 2)  # b n d -> b d n
        x = self.dwconv(x)
        x = x.transpose(1, 2)  # b d n -> b n d
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.grn(x)
        x = self.pwconv2(x)
        return residual + x

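# --- hedged usage sketch (added for illustration; not part of the original file) ---
# The block is shape-preserving: a depthwise conv over time, then an inverted
# bottleneck (dim -> intermediate_dim -> dim) with GRN, plus a residual.
def _demo_convnext_block():
    block = ConvNeXtV2Block(dim=512, intermediate_dim=1024)
    x = torch.randn(2, 100, 512)   # [b, n, d]
    y = block(x)                   # same shape as the input
    assert y.shape == x.shape
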
# AdaLayerNormZero
# return with modulated x for attn input, and params for later mlp modulation


class AdaLayerNormZero(nn.Module):
    def __init__(self, dim):
        super().__init__()

        self.silu = nn.SiLU()
        self.linear = nn.Linear(dim, dim * 6)

        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb=None):
        emb = self.linear(self.silu(emb))
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)

        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp

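# --- hedged usage sketch (added for illustration; not part of the original file) ---
# adaLN-Zero (DiT-style): one SiLU+Linear on the conditioning embedding yields
# six [b, dim] vectors; shift/scale modulate the pre-attention norm here, and
# the gates plus shift_mlp/scale_mlp are consumed later inside the block.
def _demo_ada_layer_norm_zero():
    norm = AdaLayerNormZero(dim=512)
    x = torch.randn(2, 100, 512)   # [b, n, d]
    t = torch.randn(2, 512)        # [b, d] time embedding
    x_mod, gate_msa, shift_mlp, scale_mlp, gate_mlp = norm(x, emb=t)
    assert x_mod.shape == x.shape and gate_msa.shape == (2, 512)
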
# AdaLayerNormZero for final layer
# returns only the modulated x for attn input, since there is no further mlp modulation


class AdaLayerNormZero_Final(nn.Module):
    def __init__(self, dim, cond_dim):
        super().__init__()

        self.silu = nn.SiLU()
        self.linear = nn.Linear(cond_dim, dim * 2)

        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb):
        emb = self.linear(self.silu(emb))
        scale, shift = torch.chunk(emb, 2, dim=1)

        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
        return x

# FeedForward


class FeedForward(nn.Module):
    def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        activation = nn.GELU(approximate=approximate)
        project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
        self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))

    def forward(self, x):
        return self.ff(x)

# Attention with possible joint part
# modified from diffusers/src/diffusers/models/attention_processor.py


class Attention(nn.Module):
    def __init__(
        self,
        processor: JointAttnProcessor | AttnProcessor,
        dim: int,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        context_dim: Optional[int] = None,  # if not None -> joint attention
        context_pre_only=None,
    ):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("Attention requires PyTorch 2.0; please upgrade PyTorch to use it.")

        self.processor = processor

        self.dim = dim
        self.heads = heads
        self.inner_dim = dim_head * heads
        self.dropout = dropout

        self.context_dim = context_dim
        self.context_pre_only = context_pre_only

        self.to_q = nn.Linear(dim, self.inner_dim)
        self.to_k = nn.Linear(dim, self.inner_dim)
        self.to_v = nn.Linear(dim, self.inner_dim)

        if self.context_dim is not None:
            self.to_k_c = nn.Linear(context_dim, self.inner_dim)
            self.to_v_c = nn.Linear(context_dim, self.inner_dim)
            if self.context_pre_only is not None:
                self.to_q_c = nn.Linear(context_dim, self.inner_dim)

        self.to_out = nn.ModuleList([])
        self.to_out.append(nn.Linear(self.inner_dim, dim))
        self.to_out.append(nn.Dropout(dropout))

        if self.context_pre_only is not None and not self.context_pre_only:
            self.to_out_c = nn.Linear(self.inner_dim, dim)

    def forward(
        self,
        x: float["b n d"],  # noised input x  # noqa: F722
        c: float["b n d"] = None,  # context c  # noqa: F722
        mask: bool["b n"] | None = None,  # noqa: F722
        rope=None,  # rotary position embedding for x
        c_rope=None,  # rotary position embedding for c
    ) -> torch.Tensor:
        if c is not None:
            return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope)
        else:
            return self.processor(self, x, mask=mask, rope=rope)

# Attention processor


class AttnProcessor:
    def __init__(self):
        pass

    def __call__(
        self,
        attn: Attention,
        x: float["b n d"],  # noised input x  # noqa: F722
        mask: bool["b n"] | None = None,  # noqa: F722
        rope=None,  # rotary position embedding
    ) -> torch.FloatTensor:
        batch_size = x.shape[0]

        # `sample` projections.
        query = attn.to_q(x)
        key = attn.to_k(x)
        value = attn.to_v(x)

        # apply rotary position embedding
        if rope is not None:
            freqs, xpos_scale = rope
            q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)

            query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
            key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)

        # attention
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # mask. e.g. at inference, a batch may hold different target durations; mask out the padding
        if mask is not None:
            attn_mask = mask
            attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)  # 'b n -> b 1 1 n'
            attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
        else:
            attn_mask = None

        x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
        x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        x = x.to(query.dtype)

        # linear proj
        x = attn.to_out[0](x)
        # dropout
        x = attn.to_out[1](x)

        if mask is not None:
            mask = mask.unsqueeze(-1)
            x = x.masked_fill(~mask, 0.0)

        return x

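# --- hedged usage sketch (added for illustration; not part of the original file) ---
# Self-attention path: Attention delegates to its processor, so a plain
# AttnProcessor plus an optional padding mask is all the non-joint path needs.
def _demo_self_attention():
    attn = Attention(processor=AttnProcessor(), dim=512, heads=8, dim_head=64)
    x = torch.randn(2, 100, 512)                 # [b, n, d]
    mask = torch.ones(2, 100, dtype=torch.bool)
    mask[1, 80:] = False                         # second item padded past frame 80
    y = attn(x, mask=mask)                       # padded positions zeroed in y
    assert y.shape == x.shape
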
# Joint Attention processor for MM-DiT
# modified from diffusers/src/diffusers/models/attention_processor.py


class JointAttnProcessor:
    def __init__(self):
        pass

    def __call__(
        self,
        attn: Attention,
        x: float["b n d"],  # noised input x  # noqa: F722
        c: float["b nt d"] = None,  # context c, here text  # noqa: F722
        mask: bool["b n"] | None = None,  # noqa: F722
        rope=None,  # rotary position embedding for x
        c_rope=None,  # rotary position embedding for c
    ) -> torch.FloatTensor:
        residual = x

        batch_size = c.shape[0]

        # `sample` projections.
        query = attn.to_q(x)
        key = attn.to_k(x)
        value = attn.to_v(x)

        # `context` projections.
        c_query = attn.to_q_c(c)
        c_key = attn.to_k_c(c)
        c_value = attn.to_v_c(c)

        # apply rope for context and noised input independently
        if rope is not None:
            freqs, xpos_scale = rope
            q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
            query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
            key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
        if c_rope is not None:
            freqs, xpos_scale = c_rope
            q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
            c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
            c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)

        # attention: concatenate noised input and context along the sequence dim
        query = torch.cat([query, c_query], dim=1)
        key = torch.cat([key, c_key], dim=1)
        value = torch.cat([value, c_value], dim=1)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # mask. e.g. at inference, a batch may hold different target durations; mask out the padding
        if mask is not None:
            attn_mask = F.pad(mask, (0, c.shape[1]), value=True)  # no mask for c (text)
            attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)  # 'b n -> b 1 1 n'
            attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
        else:
            attn_mask = None

        x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
        x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        x = x.to(query.dtype)

        # Split the attention outputs back into the x and c parts.
        x, c = (
            x[:, : residual.shape[1]],
            x[:, residual.shape[1] :],
        )

        # linear proj
        x = attn.to_out[0](x)
        # dropout
        x = attn.to_out[1](x)
        if not attn.context_pre_only:
            c = attn.to_out_c(c)

        if mask is not None:
            mask = mask.unsqueeze(-1)
            x = x.masked_fill(~mask, 0.0)
            # c = c.masked_fill(~mask, 0.)  # no mask for c (text)

        return x, c

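# --- hedged usage sketch (added for illustration; not part of the original file) ---
# Joint path: x and c are projected separately, attend as one concatenated
# sequence, and are split back afterwards, so each keeps its own length.
def _demo_joint_attention():
    attn = Attention(
        processor=JointAttnProcessor(), dim=512, heads=8, dim_head=64,
        context_dim=512, context_pre_only=False,
    )
    x = torch.randn(2, 100, 512)   # noised input [b, n, d]
    c = torch.randn(2, 30, 512)    # context/text [b, nt, d]
    x_out, c_out = attn(x, c=c)
    assert x_out.shape == (2, 100, 512) and c_out.shape == (2, 30, 512)
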
# DiT Block


class DiTBlock(nn.Module):
    def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, use_style_prompt=False):
        super().__init__()

        self.attn_norm = AdaLayerNormZero(dim)
        self.attn = Attention(
            processor=AttnProcessor(),
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            dropout=dropout,
        )

        self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")

        self.use_style_prompt = use_style_prompt
        if use_style_prompt:
            # self.film = FiLMLayer(dim, dim)
            # AdaLayerNormZero_Final takes (dim, cond_dim); the original call passed
            # only `dim`, which would raise a TypeError. cond_dim = dim is assumed here.
            self.prompt_norm = AdaLayerNormZero_Final(dim, dim)

    def forward(self, x, t, c=None, mask=None, rope=None):  # x: noised input, t: time embedding
        if c is not None and self.use_style_prompt:
            # x = self.film(x, c)
            x = self.prompt_norm(x, c)

        # pre-norm & modulation for attention input
        norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)

        # attention
        attn_output = self.attn(x=norm, mask=mask, rope=rope)

        # process attention output for input x
        x = x + gate_msa.unsqueeze(1) * attn_output

        norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        ff_output = self.ff(norm)
        x = x + gate_mlp.unsqueeze(1) * ff_output

        return x

| 528 | 
         
            +
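# Minimal usage sketch for DiTBlock (illustrative only; the sizes below are
# assumptions, not values fixed by this module):
#     block = DiTBlock(dim=512, heads=8, dim_head=64)
#     x = torch.randn(2, 100, 512)  # (batch, seq_len, dim) noised input
#     t = torch.randn(2, 512)       # (batch, dim) time embedding
#     y = block(x, t)               # same shape as x: (2, 100, 512)
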
# MMDiT Block https://arxiv.org/abs/2403.03206


class MMDiTBlock(nn.Module):
    r"""
    modified from diffusers/src/diffusers/models/attention.py

    notes:
    _c: context related. text, cond, etc. (left part in SD3 fig. 2b)
    _x: noised input related. (right part)
    context_pre_only: the last layer only does pre-norm + modulation, since no FFN follows
    """

    def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False):
        super().__init__()

        self.context_pre_only = context_pre_only

        self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
        self.attn_norm_x = AdaLayerNormZero(dim)
        self.attn = Attention(
            processor=JointAttnProcessor(),
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            dropout=dropout,
            context_dim=dim,
            context_pre_only=context_pre_only,
        )

        if not context_pre_only:
            self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
            self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
        else:
            self.ff_norm_c = None
            self.ff_c = None
        self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")

    def forward(self, x, c, t, mask=None, rope=None, c_rope=None):  # x: noised input, c: context, t: time embedding
        # pre-norm & modulation for attention input
        if self.context_pre_only:
            norm_c = self.attn_norm_c(c, t)
        else:
            norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t)
        norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t)

        # attention
        x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope)

        # process attention output for context c
        if self.context_pre_only:
            c = None
        else:  # if not last layer
            c = c + c_gate_msa.unsqueeze(1) * c_attn_output

            norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
            c_ff_output = self.ff_c(norm_c)
            c = c + c_gate_mlp.unsqueeze(1) * c_ff_output

        # process attention output for input x
        x = x + x_gate_msa.unsqueeze(1) * x_attn_output

        norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
        x_ff_output = self.ff_x(norm_x)
        x = x + x_gate_mlp.unsqueeze(1) * x_ff_output

        return c, x

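# Minimal usage sketch for MMDiTBlock (illustrative; sizes are assumptions):
#     block = MMDiTBlock(dim=512, heads=8, dim_head=64, context_pre_only=False)
#     x = torch.randn(2, 100, 512)   # noised input (right branch in SD3 fig. 2b)
#     c = torch.randn(2, 32, 512)    # context, e.g. text/cond (left branch)
#     t = torch.randn(2, 512)        # time embedding
#     c_out, x_out = block(x, c, t)  # c_out is None when context_pre_only=True
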
# time step conditioning embedding


class TimestepEmbedding(nn.Module):
    def __init__(self, dim, freq_embed_dim=256):
        super().__init__()
        self.time_embed = SinusPositionEmbedding(freq_embed_dim)
        self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))

    def forward(self, timestep: float["b"]):  # noqa: F821
        time_hidden = self.time_embed(timestep)
        time_hidden = time_hidden.to(timestep.dtype)
        time = self.time_mlp(time_hidden)  # b d
        return time
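# Minimal usage sketch for TimestepEmbedding (illustrative; dim=512 is an assumption):
#     temb = TimestepEmbedding(dim=512)
#     t = torch.rand(2)   # one time value per batch element, e.g. in [0, 1]
#     e = temb(t)         # (2, 512), used as `emb` by the AdaLayerNormZero modulation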
    	
diffrhythm/model/trainer.py ADDED
@@ -0,0 +1,350 @@
from __future__ import annotations

import os
import gc
from tqdm import tqdm
import wandb

import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR, SequentialLR, ConstantLR

from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
from diffrhythm.dataset.custom_dataset_align2f5 import LanceDiffusionDataset

from torch.utils.data import DataLoader, DistributedSampler

from ema_pytorch import EMA

from diffrhythm.model import CFM
from diffrhythm.model.utils import exists, default

import time

# from apex.optimizers.fused_adam import FusedAdam

# trainer


class Trainer:
    def __init__(
        self,
        model: CFM,
        args,
        epochs,
        learning_rate,
        # dataloader,
        num_warmup_updates=20000,
        save_per_updates=1000,
        checkpoint_path=None,
        batch_size=32,
        batch_size_type: str = "sample",
        max_samples=32,
        grad_accumulation_steps=1,
        max_grad_norm=1.0,
        noise_scheduler: str | None = None,
        duration_predictor: torch.nn.Module | None = None,
        wandb_project="test_e2-tts",
        wandb_run_name="test_run",
        wandb_resume_id: str | None = None,
        last_per_steps=None,
        accelerate_kwargs: dict = dict(),
        ema_kwargs: dict = dict(),
        bnb_optimizer: bool = False,
        reset_lr: bool = False,
        use_style_prompt: bool = False,
        grad_ckpt: bool = False,
    ):
        self.args = args

        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)

        logger = "wandb" if wandb.api.api_key else None
        # logger = None
        print(f"Using logger: {logger}")
        import tbe.common

        self.accelerator = Accelerator(
            log_with=logger,
            kwargs_handlers=[ddp_kwargs],
            gradient_accumulation_steps=grad_accumulation_steps,
            **accelerate_kwargs,
        )

        if logger == "wandb":
            if exists(wandb_resume_id):
                init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name, "id": wandb_resume_id}}
            else:
                init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
            self.accelerator.init_trackers(
                project_name=wandb_project,
                init_kwargs=init_kwargs,
                config={
                    "epochs": epochs,
                    "learning_rate": learning_rate,
                    "num_warmup_updates": num_warmup_updates,
                    "batch_size": batch_size,
                    "batch_size_type": batch_size_type,
                    "max_samples": max_samples,
                    "grad_accumulation_steps": grad_accumulation_steps,
                    "max_grad_norm": max_grad_norm,
                    "gpus": self.accelerator.num_processes,
                    "noise_scheduler": noise_scheduler,
                },
            )

        self.precision = self.accelerator.state.mixed_precision
        self.precision = self.precision.replace("no", "fp32")
        print(f"Using precision: {self.precision}")

        self.model = model
        # self.model = torch.compile(model)

        # self.dataloader = dataloader

        if self.is_main:
            self.ema_model = EMA(model, include_online_model=False, **ema_kwargs)

            self.ema_model.to(self.accelerator.device)
            if self.accelerator.state.distributed_type in ["DEEPSPEED", "FSDP"]:
                self.ema_model.half()

        self.epochs = epochs
        self.num_warmup_updates = num_warmup_updates
        self.save_per_updates = save_per_updates
        self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps)
        self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")

        self.max_samples = max_samples
        self.grad_accumulation_steps = grad_accumulation_steps
        self.max_grad_norm = max_grad_norm

        self.noise_scheduler = noise_scheduler

        self.duration_predictor = duration_predictor

        self.reset_lr = reset_lr

        self.use_style_prompt = use_style_prompt

        self.grad_ckpt = grad_ckpt

        if bnb_optimizer:
            import bitsandbytes as bnb

            self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
        else:
            self.optimizer = AdamW(model.parameters(), lr=learning_rate)
        # self.optimizer = FusedAdam(model.parameters(), lr=learning_rate)

        # self.model = torch.compile(self.model)
        if self.accelerator.state.distributed_type == "DEEPSPEED":
            self.accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = batch_size

        self.get_dataloader()
        self.get_scheduler()
        # self.get_constant_scheduler()

        self.model, self.optimizer, self.scheduler, self.train_dataloader = self.accelerator.prepare(
            self.model, self.optimizer, self.scheduler, self.train_dataloader
        )

    def get_scheduler(self):
        warmup_steps = (
            self.num_warmup_updates * self.accelerator.num_processes
        )  # consider a fixed warmup steps while using accelerate multi-gpu ddp
        total_steps = len(self.train_dataloader) * self.epochs / self.grad_accumulation_steps
        decay_steps = total_steps - warmup_steps
        warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
        decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
        # constant_scheduler = ConstantLR(self.optimizer, factor=1, total_iters=decay_steps)
        self.scheduler = SequentialLR(
            self.optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_steps]
        )

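    # Worked example of the schedule above (numbers are hypothetical): with
    # num_warmup_updates=20000 on 8 processes, warmup_steps = 160000; with
    # len(train_dataloader)=100000, epochs=2, grad_accumulation_steps=1,
    # total_steps = 200000 and decay_steps = 40000. The LR thus ramps linearly
    # from ~0 (start_factor=1e-8) to the base LR over the first 160000 steps,
    # then decays linearly back toward ~0 over the remaining 40000.
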
    def get_constant_scheduler(self):
        total_steps = len(self.train_dataloader) * self.epochs / self.grad_accumulation_steps
        self.scheduler = ConstantLR(self.optimizer, factor=1, total_iters=total_steps)

    def get_dataloader(self):
        prompt_path = self.args.prompt_path.split("|")
        lrc_path = self.args.lrc_path.split("|")
        latent_path = self.args.latent_path.split("|")
        ldd = LanceDiffusionDataset(
            *LanceDiffusionDataset.init_data(self.args.dataset_path),
            max_frames=self.args.max_frames,
            min_frames=self.args.min_frames,
            align_lyrics=self.args.align_lyrics,
            lyrics_slice=self.args.lyrics_slice,
            use_style_prompt=self.args.use_style_prompt,
            parse_lyrics=self.args.parse_lyrics,
            lyrics_shift=self.args.lyrics_shift,
            downsample_rate=self.args.downsample_rate,
            skip_empty_lyrics=self.args.skip_empty_lyrics,
            tokenizer_type=self.args.tokenizer_type,
            precision=self.precision,
            start_time=time.time(),
            pure_prob=self.args.pure_prob,
        )

        # start_time = time.time()
        self.train_dataloader = DataLoader(
            dataset=ldd,
            batch_size=self.args.batch_size,  # samples per batch
            shuffle=True,  # shuffle the data each epoch
            num_workers=4,  # subprocesses used for data loading
            pin_memory=True,  # speeds up host-to-GPU transfer
            collate_fn=ldd.custom_collate_fn,
            persistent_workers=True,
        )

    @property
    def is_main(self):
        return self.accelerator.is_main_process

    def save_checkpoint(self, step, last=False):
        self.accelerator.wait_for_everyone()
        if self.is_main:
            checkpoint = dict(
                model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
                optimizer_state_dict=self.accelerator.unwrap_model(self.optimizer).state_dict(),
                ema_model_state_dict=self.ema_model.state_dict(),
                scheduler_state_dict=self.scheduler.state_dict(),
                step=step,
            )
            if not os.path.exists(self.checkpoint_path):
                os.makedirs(self.checkpoint_path)
            if last:
                self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
                print(f"Saved last checkpoint at step {step}")
            else:
                self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")

    def load_checkpoint(self):
        if (
            not exists(self.checkpoint_path)
            or not os.path.exists(self.checkpoint_path)
            or not os.listdir(self.checkpoint_path)
        ):
            return 0

        self.accelerator.wait_for_everyone()
        if "model_last.pt" in os.listdir(self.checkpoint_path):
            latest_checkpoint = "model_last.pt"
        else:
            latest_checkpoint = sorted(
                [f for f in os.listdir(self.checkpoint_path) if f.endswith(".pt")],
                key=lambda x: int("".join(filter(str.isdigit, x))),
            )[-1]

        checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location="cpu")

        ### 1. Filter out mismatched `ema_model` parameters
        if self.is_main:
            ema_dict = self.ema_model.state_dict()
            ema_checkpoint_dict = checkpoint["ema_model_state_dict"]

            filtered_ema_dict = {
                k: v for k, v in ema_checkpoint_dict.items()
                if k in ema_dict and ema_dict[k].shape == v.shape  # only load params whose shapes match
            }

            print(f"Loading {len(filtered_ema_dict)} / {len(ema_checkpoint_dict)} ema_model params")
            self.ema_model.load_state_dict(filtered_ema_dict, strict=False)

        ### 2. Filter out mismatched `model` parameters
        model_dict = self.accelerator.unwrap_model(self.model).state_dict()
        checkpoint_model_dict = checkpoint["model_state_dict"]

        filtered_model_dict = {
            k: v for k, v in checkpoint_model_dict.items()
            if k in model_dict and model_dict[k].shape == v.shape  # only load params whose shapes match
        }

        print(f"Loading {len(filtered_model_dict)} / {len(checkpoint_model_dict)} model params")
        self.accelerator.unwrap_model(self.model).load_state_dict(filtered_model_dict, strict=False)

        ### 3. Load optimizer, scheduler, and step count
        if "step" in checkpoint:
            if self.scheduler and not self.reset_lr:
                self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
            step = checkpoint["step"]
        else:
            step = 0

        del checkpoint
        gc.collect()
        print("Checkpoint loaded at step", step)
        return step

    def train(self, resumable_with_seed: int | None = None):
        train_dataloader = self.train_dataloader

        resumable_with_seed = default(resumable_with_seed, 0)  # treat None as "do not resume mid-epoch"

        start_step = self.load_checkpoint()
        global_step = start_step

        if resumable_with_seed > 0:
            orig_epoch_step = len(train_dataloader)
            skipped_epoch = int(start_step // orig_epoch_step)
            skipped_batch = start_step % orig_epoch_step
            skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch)
        else:
            skipped_epoch = 0

        for epoch in range(skipped_epoch, self.epochs):
            self.model.train()
            if resumable_with_seed > 0 and epoch == skipped_epoch:
                progress_bar = tqdm(
                    skipped_dataloader,
                    desc=f"Epoch {epoch+1}/{self.epochs}",
                    unit="step",
                    disable=not self.accelerator.is_local_main_process,
                    initial=skipped_batch,
                    total=orig_epoch_step,
                    smoothing=0.15,
                )
            else:
                progress_bar = tqdm(
                    train_dataloader,
                    desc=f"Epoch {epoch+1}/{self.epochs}",
                    unit="step",
                    disable=not self.accelerator.is_local_main_process,
                    smoothing=0.15,
                )

            for batch in progress_bar:
                with self.accelerator.accumulate(self.model):
                    text_inputs = batch["lrc"]
                    mel_spec = batch["latent"].permute(0, 2, 1)
                    mel_lengths = batch["latent_lengths"]
                    style_prompt = batch["prompt"]
                    style_prompt_lens = batch["prompt_lengths"]
                    start_time = batch["start_time"]

                    loss, cond, pred = self.model(
                        mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler,
                        style_prompt=style_prompt if self.use_style_prompt else None,
                        style_prompt_lens=style_prompt_lens if self.use_style_prompt else None,
                        grad_ckpt=self.grad_ckpt, start_time=start_time,
                    )
                    self.accelerator.backward(loss)

                    if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
                        self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

                    self.optimizer.step()
                    self.scheduler.step()
                    self.optimizer.zero_grad()

                if self.is_main:
                    self.ema_model.update()

                global_step += 1

                if self.accelerator.is_local_main_process:
                    self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)

                progress_bar.set_postfix(step=str(global_step), loss=loss.item())

                if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
                    self.save_checkpoint(global_step)

                if global_step % self.last_per_steps == 0:
                    self.save_checkpoint(global_step, last=True)

        self.save_checkpoint(global_step, last=True)

        self.accelerator.end_training()
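# Minimal usage sketch (illustrative; `args`, its fields, and the CFM instance are
# assumptions, not defined in this file):
#     trainer = Trainer(model=cfm_model, args=args, epochs=10, learning_rate=1e-4,
#                       num_warmup_updates=2000, batch_size=8,
#                       checkpoint_path="ckpts/diffrhythm")
#     trainer.train(resumable_with_seed=666)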
    	
diffrhythm/model/utils.py ADDED
@@ -0,0 +1,182 @@
|
| 1 | 
         
            +
            from __future__ import annotations
         
     | 
| 2 | 
         
            +
             
     | 
| 3 | 
         
            +
            import os
         
     | 
| 4 | 
         
            +
            import random
         
     | 
| 5 | 
         
            +
            from collections import defaultdict
         
     | 
| 6 | 
         
            +
            from importlib.resources import files
         
     | 
| 7 | 
         
            +
             
     | 
| 8 | 
         
            +
            import torch
         
     | 
| 9 | 
         
            +
            from torch.nn.utils.rnn import pad_sequence
         
     | 
| 10 | 
         
            +
             
     | 
| 11 | 
         
            +
             
     | 
| 12 | 
         
            +
            # seed everything
         
     | 
| 13 | 
         
            +
             
     | 
| 14 | 
         
            +
             
     | 
| 15 | 
         
            +
            def seed_everything(seed=0):
         
     | 
| 16 | 
         
            +
                random.seed(seed)
         
     | 
| 17 | 
         
            +
                os.environ["PYTHONHASHSEED"] = str(seed)
         
     | 
| 18 | 
         
            +
                torch.manual_seed(seed)
         
     | 
| 19 | 
         
            +
                torch.cuda.manual_seed(seed)
         
     | 
| 20 | 
         
            +
                torch.cuda.manual_seed_all(seed)
         
     | 
| 21 | 
         
            +
                torch.backends.cudnn.deterministic = True
         
     | 
| 22 | 
         
            +
                torch.backends.cudnn.benchmark = False
         
     | 
| 23 | 
         
            +
             
     | 
| 24 | 
         
            +
             
     | 


# helpers


def exists(v):
    return v is not None


def default(v, d):
    return v if exists(v) else d


# tensor helpers
# note: annotations like int["b"] / bool["b n"] are shape hints for readability,
# not real types -- hence the noqa markers on these lines


def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]:  # noqa: F722 F821
    if not exists(length):
        length = t.amax()

    seq = torch.arange(length, device=t.device)
    return seq[None, :] < t[:, None]
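

# Usage sketch (hypothetical helper, values chosen for illustration):
def _demo_lens_to_mask():
    lens = torch.tensor([2, 3])
    return lens_to_mask(lens)
    # tensor([[ True,  True, False],
    #         [ True,  True,  True]])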


def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"]):  # noqa: F722 F821
    max_seq_len = 2048  # note: masks are built out to a fixed 2048-frame cap; `seq_len` itself is unused here
    seq = torch.arange(max_seq_len, device=start.device).long()
    start_mask = seq[None, :] >= start[:, None]
    end_mask = seq[None, :] < end[:, None]
    return start_mask & end_mask


def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]):  # noqa: F722 F821
    lengths = (frac_lengths * seq_len).long()
    max_start = seq_len - lengths

    rand = torch.rand_like(frac_lengths)
    start = (max_start * rand).long().clamp(min=0)
    end = start + lengths

    return mask_from_start_end_indices(seq_len, start, end)
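

# Usage sketch (hypothetical demo): mask a random contiguous 50% span of each
# sequence, e.g. for span-masked / infilling-style training objectives:
def _demo_mask_from_frac_lengths():
    seq_len = torch.tensor([10, 10])
    frac = torch.tensor([0.5, 0.5])
    mask = mask_from_frac_lengths(seq_len, frac)
    # Each row masks exactly 5 positions at a random offset; note the mask
    # width is the fixed 2048-frame cap, not seq_len.
    return mask.sum(dim=-1)  # tensor([5, 5])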


def maybe_masked_mean(t: float["b n d"], mask: bool["b n"] = None) -> float["b d"]:  # noqa: F722
    if not exists(mask):
        return t.mean(dim=1)

    t = torch.where(mask[:, :, None], t, torch.tensor(0.0, device=t.device))
    num = t.sum(dim=1)
    den = mask.float().sum(dim=1)

    return num / den.clamp(min=1.0)
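

# Usage sketch (hypothetical demo): the masked-out third frame is excluded
# from the average.
def _demo_maybe_masked_mean():
    t = torch.tensor([[[1.0], [3.0], [100.0]]])  # (batch=1, n=3, dim=1)
    mask = torch.tensor([[True, True, False]])
    return maybe_masked_mean(t, mask)  # tensor([[2.]])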


# simple utf-8 tokenizer, since the paper went character-based
def list_str_to_tensor(text: list[str], padding_value=-1) -> int["b nt"]:  # noqa: F722
    list_tensors = [torch.tensor([*bytes(t, "UTF-8")]) for t in text]  # ByT5 style
    text = pad_sequence(list_tensors, padding_value=padding_value, batch_first=True)
    return text
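

# Usage sketch (hypothetical demo): each string becomes its raw UTF-8 bytes,
# right-padded with -1 to the longest entry in the batch.
def _demo_list_str_to_tensor():
    return list_str_to_tensor(["Hi", "Hey"])
    # tensor([[ 72, 105,  -1],
    #         [ 72, 101, 121]])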


# char tokenizer, based on the custom dataset's extracted .txt vocab file
def list_str_to_idx(
    text: list[str] | list[list[str]],
    vocab_char_map: dict[str, int],  # {char: idx}
    padding_value=-1,
) -> int["b nt"]:  # noqa: F722
    list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text]  # pinyin or char style
    text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
    return text
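

# Usage sketch (hypothetical vocab): out-of-vocabulary characters fall back
# to index 0, which is why get_tokenizer below asserts that index 0 is " ".
def _demo_list_str_to_idx():
    vocab = {" ": 0, "a": 1, "b": 2}
    return list_str_to_idx(["ab", "a?"], vocab)
    # tensor([[1, 2],
    #         [1, 0]])  # "?" is unknown, so it maps to 0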


# Get tokenizer


def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
    """
    tokenizer   - "pinyin": g2p applied to Chinese characters only; needs a .txt vocab file
                - "char": character-wise tokenizer; needs a .txt vocab file
                - "byte": utf-8 byte tokenizer
                - "custom": pass the path to the vocab.txt you want to use directly
    vocab_size  - for "pinyin": all available pinyin types, common alphabets (including accented ones) and symbols
                - for "char": derived from the unfiltered character & symbol counts of the custom dataset
                - for "byte": fixed at 256 (the utf-8 byte range)
    """
    if tokenizer in ["pinyin", "char"]:
        tokenizer_path = os.path.join(files("diffrhythm").joinpath("../../data"), f"{dataset_name}_{tokenizer}/vocab.txt")
        with open(tokenizer_path, "r", encoding="utf-8") as f:
            vocab_char_map = {}
            for i, char in enumerate(f):
                vocab_char_map[char[:-1]] = i
        vocab_size = len(vocab_char_map)
        assert vocab_char_map[" "] == 0, "make sure space is of idx 0 in vocab.txt, since 0 is used for unknown chars"

    elif tokenizer == "byte":
        vocab_char_map = None
        vocab_size = 256

    elif tokenizer == "custom":
        with open(dataset_name, "r", encoding="utf-8") as f:
            vocab_char_map = {}
            for i, char in enumerate(f):
                vocab_char_map[char[:-1]] = i
        vocab_size = len(vocab_char_map)

    return vocab_char_map, vocab_size
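

# Usage sketch (hypothetical demo): "byte" mode needs no vocab file on disk,
# which makes it the easiest mode to exercise in isolation.
def _demo_get_tokenizer():
    vocab_char_map, vocab_size = get_tokenizer("unused", tokenizer="byte")
    return vocab_char_map is None and vocab_size == 256  # True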


# convert char to pinyin


def convert_char_to_pinyin(text_list, polyphone=True):
    final_text_list = []
    god_knows_why_en_testset_contains_zh_quote = str.maketrans(
        {"“": '"', "”": '"', "‘": "'", "’": "'"}
    )  # in case librispeech (orig no-pc) test-clean contains Chinese-style quotes
    custom_trans = str.maketrans({";": ","})  # add custom translations here to address OOV symbols
    for text in text_list:
        char_list = []
        text = text.translate(god_knows_why_en_testset_contains_zh_quote)
        text = text.translate(custom_trans)
        for seg in jieba.cut(text):
            seg_byte_len = len(bytes(seg, "UTF-8"))
            if seg_byte_len == len(seg):  # pure alphabets and symbols (1 byte per char in utf-8)
                if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
                    char_list.append(" ")
                char_list.extend(seg)
            elif polyphone and seg_byte_len == 3 * len(seg):  # pure Chinese characters (3 bytes per char in utf-8)
                seg = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
                for c in seg:
                    if c not in "。,、;:?!《》【】—…":
                        char_list.append(" ")
                    char_list.append(c)
            else:  # mixed Chinese characters, alphabets and symbols
                for c in seg:
                    if ord(c) < 256:
                        char_list.extend(c)
                    else:
                        if c not in "。,、;:?!《》【】—…":
                            char_list.append(" ")
                            char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
                        else:  # Chinese punctuation is kept as-is
                            char_list.append(c)
        final_text_list.append(char_list)

    return final_text_list
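

# Usage sketch (hypothetical demo; the pinyin strings are what pypinyin's
# TONE3 style with tone sandhi should produce for these characters, e.g.
# 你好 becomes ni2 hao3 via third-tone sandhi):
def _demo_convert_char_to_pinyin():
    return convert_char_to_pinyin(["你好 ok"])
    # expected: [[' ', 'ni2', ' ', 'hao3', ' ', 'o', 'k']]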


# filter func for dirty data with many repetitions


def repetition_found(text, length=2, tolerance=10):
    # returns True if any substring of the given length occurs more than
    # `tolerance` times in the text
    pattern_count = defaultdict(int)
    for i in range(len(text) - length + 1):
        pattern = text[i : i + length]
        pattern_count[pattern] += 1
    for pattern, count in pattern_count.items():
        if count > tolerance:
            return True
    return False
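

# Usage sketch (hypothetical demo): "la" occurs 12 times here, which exceeds
# the default tolerance of 10.
def _demo_repetition_found():
    return repetition_found("la" * 12)  # True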
    	
prompt/gift_of_the_world.wav ADDED
    Binary file (960 kB)

prompt/little_happiness.wav ADDED
    Binary file (960 kB)

prompt/little_talks.wav ADDED
    Binary file (960 kB)

prompt/ltwyl.wav ADDED
    Binary file (882 kB)

prompt/most_beautiful_expectation.wav ADDED
    Binary file (960 kB)