Commit d658154 · Parent(s): b25cf95
root committed: add auto prompt and interface
Files changed:
- app.py +142 -107
 - codeclm/models/builders.py +0 -6
 - codeclm/models/codeclm.py +13 -17
 - codeclm/models/lm_levo.py +17 -24
 - codeclm/modules/conditioners.py +0 -167
 - codeclm/tokenizer/audio_tokenizer.py +8 -662
 - codeclm/trainer/codec_song_pl.py +6 -550
 - conf/infer.yaml +0 -152
 - generate.py +52 -31
 - generate.sh +4 -6
 - levo_inference.py +24 -42
 - sample/description/emotion.txt +8 -0
 - sample/description/gender.txt +2 -0
 - sample/description/genre.txt +27 -0
 - sample/description/instrument.txt +40 -0
 - sample/description/timbre.txt +7 -0
 - sample/lyric.jsonl +0 -1
 - sample/lyrics.jsonl +4 -0
 - sample/sample_prompt_audio.wav +3 -0
 
    	
app.py CHANGED
@@ -1,196 +1,231 @@
-import os
 import gradio as gr
 import json
-import numpy as np
 from datetime import datetime
-import os
 import yaml
-import sys
-import librosa
 import time
 import os.path as op
-APP_DIR = op.dirname(op.abspath(__file__))
-
 from download import download_model
+from levo_inference import LeVoInference
+
+
 # download the model
+APP_DIR = op.dirname(op.abspath(__file__))
 download_model(APP_DIR)
 print("Successful downloaded model.")
 
-
-
-MODEL = LeVoInference(op.join(APP_DIR, "conf/infer.yaml"))
+# model initialization
+MODEL = LeVoInference(op.join(APP_DIR, "ckpt/songgeneration_base_zn/"))
 
-EXAMPLE_DESC = """female, dark, pop, sad, piano and drums, the bpm is 125."""
 EXAMPLE_LYRICS = """
 [intro-short]
 
 [verse]
-…
-如今只剩我独自回忆.
-
-[bridge]
-手机屏幕亮起.
-是你发来的消息.
-简单的几个字.
-却让我泪流满面.
-曾经的拥抱温暖.
-如今却变得遥远.
-我多想回到从前.
-重新拥有你的陪伴.
+雪花舞动在无尽的天际
+情缘如同雪花般轻轻逝去
+希望与真挚
+永不磨灭
+你的忧虑
+随风而逝
 
 [chorus]
-…
+我怀抱着守护这片梦境
+在这世界中寻找爱与虚幻
+苦辣酸甜
+我们一起品尝
+在雪的光芒中
+紧紧相拥
+
+[inst-short]
+
+[verse]
+雪花再次在风中飘扬
+情愿如同雪花般消失无踪
+希望与真挚
+永不消失
+在痛苦与喧嚣中
+你找到解脱
+
+[chorus]
+我环绕着守护这片梦境
+在这世界中感受爱与虚假
+苦辣酸甜
+我们一起分享
+在白银的光芒中
+我们同在
 
 [outro-short]
 """.strip()
 
-with open('conf/vocab.yaml', 'r', encoding='utf-8') as file:
+with open(op.join(APP_DIR, 'conf/vocab.yaml'), 'r', encoding='utf-8') as file:
     STRUCTS = yaml.safe_load(file)
 
 
 # simulated song-generation function
-def generate_song(…
+def generate_song(lyric, description=None, prompt_audio=None, genre=None, cfg_coef=None, temperature=None, top_k=None, progress=gr.Progress(track_tqdm=True)):
     global MODEL
     global STRUCTS
     params = {'cfg_coef':cfg_coef, 'temperature':temperature, 'top_k':top_k}
     params = {k:v for k,v in params.items() if v is not None}
     sample_rate = MODEL.cfg.sample_rate
-
-    # generation process
-    print(f"Generating song with description: {description}")
-    print(f"Lyrics provided: {lyric}")
 
     # adapt the lyric format
+    lyric = lyric.replace("[intro]", "[intro-short]").replace("[inst]", "[inst-short]").replace("[outro]", "[outro-short]")
     lyric = lyric.replace("\n\n", " ; ")
     for s in STRUCTS:
         lyric = lyric.replace(f"{s}\n", f"{s} ")
-    lyric = lyric.replace("\n", "")
+    lyric = lyric.replace("\n", ".")
     lyric = lyric.replace(". ; ", " ; ")
 
     # adapt the prompt
     if prompt_audio is not None:
-        …
+        genre = None
+        description = None
+    elif description is not None and description != "":
+        genre = None
 
     progress(0.0, "Start Generation")
     start = time.time()
 
-    audio_data = MODEL(lyric, description, prompt_audio, params).cpu().permute(1, 0).float().numpy()
+    audio_data = MODEL(lyric, description, prompt_audio, genre, op.join(APP_DIR, "ckpt/prompt.pt"), params).cpu().permute(1, 0).float().numpy()
 
     end = time.time()
 
     # build the input-config JSON
    input_config = {
-        "description": description,
         "lyric": lyric,
+        "genre": genre,
         "prompt_audio": prompt_audio,
+        "description": description,
         "params": params,
         "inference_duration": end - start,
         "timestamp": datetime.now().isoformat(),
     }
+    print(input_config)
 
     return (sample_rate, audio_data), json.dumps(input_config, indent=2)
 
+
 # create the Gradio interface
-with gr.Blocks(title="…
-    gr.Markdown("# 🎵 …
-    gr.Markdown("Demo interface for the …
+with gr.Blocks(title="SongGeration Demo Space") as demo:
+    gr.Markdown("# 🎵 SongGeration Demo Space")
+    gr.Markdown("Demo interface for the song generation model. Provide a lyrics, and optionally an audio or text prompt, to generate a custom song.")
 
     with gr.Row():
         with gr.Column():
-            description = gr.Textbox(
-                label="Song Description",
-                placeholder="Describe the style, mood, and characteristics of the song...",
-                lines=1,
-                max_lines=2,
-                value=EXAMPLE_DESC,
-            )
             lyric = gr.Textbox(
                 label="Lyrics",
-                placeholder="Enter the lyrics for the song...",
                 lines=5,
-                max_lines=…
+                max_lines=15,
                 value=EXAMPLE_LYRICS,
+                info="Support lyric structure tags like [verse], [chorus], and [bridge] to separate different parts of the lyrics. Use [intro] [outro] [inst] to generate instrumental music.",
+                placeholder="""Lyric Format
+'''
+[structure tag]
+lyrics
+
+[structure tag]
+lyrics
+'''
+1. One paragraph represents one section, starting with a structure tag and ending with a blank line
+2. One line represents one lyric line, punctuation is not recommended inside the line
+3. Structure tags can be chosen from the following list
+    - '[verse]'
+    - '[chorus]'
+    - '[bridge]'
+    - '[intro-short]'
+    - '[intro-medium]'
+    - '[intro-long]'
+    - '[outro-short]'
+    - '[outro-medium]'
+    - '[outro-long]'
+    - '[inst-short]'
+    - '[inst-medium]'
+    - '[inst-long]'
+    - '[silence]'
+"""
             )
 
             with gr.Tabs(elem_id="extra-tabs"):
+                with gr.Tab("Genre Select"):
+                    genre = gr.Radio(
+                        choices=["Auto", "Pop", "R&B", "Dance", "Jazz", "Folk", "Rock", "Chinese Style", "Chinese Tradition", "Metal", "Reggae", "Chinese Opera"],
+                        label="Genre Select(Optional)",
+                        value="Auto",  # first option selected by default
+                        interactive=True,
+                        elem_id="single-select-radio"  # hook for custom styling
+                    )
                 with gr.Tab("Audio Prompt"):
                     prompt_audio = gr.Audio(
                         label="Prompt Audio (Optional)",
                         type="filepath",
                        elem_id="audio-prompt"
                     )
-                with gr.Tab("…
-                    …
-                        label="…
-                        …
-                        interactive=True,
-                        elem_id="cfg-coef",
-                    )
-                    temperature = gr.Slider(
-                        label="Temperature",
-                        minimum=0.1,
-                        maximum=2.0,
-                        step=0.1,
-                        value=1.0,
-                        interactive=True,
-                        elem_id="temperature",
-                    )
-                    top_k = gr.Slider(
-                        label="Top-K",
-                        minimum=1,
-                        maximum=100,
-                        step=1,
-                        value=50,
-                        interactive=True,
-                        elem_id="top_k",
+                with gr.Tab("Text Prompt"):
+                    description = gr.Textbox(
+                        label="Song Description (Optional)",
+                        info="Describe the gender, timbre, genre, emotion, instrument and bpm of the song",
+                        placeholder="female, dark, pop, sad, piano and drums, the bpm is 125.",
+                        lines=1,
+                        max_lines=2
                     )
+
+            with gr.Accordion("Advanced Config", open=False):
+                cfg_coef = gr.Slider(
+                    label="CFG Coefficient",
+                    minimum=0.1,
+                    maximum=3.0,
+                    step=0.1,
+                    value=1.5,
+                    interactive=True,
+                    elem_id="cfg-coef",
+                )
+                temperature = gr.Slider(
+                    label="Temperature",
+                    minimum=0.1,
+                    maximum=2.0,
+                    step=0.1,
+                    value=0.9,
+                    interactive=True,
+                    elem_id="temperature",
+                )
+                top_k = gr.Slider(
+                    label="Top-K",
+                    minimum=1,
+                    maximum=100,
+                    step=1,
+                    value=50,
+                    interactive=True,
+                    elem_id="top_k",
+                )
             generate_btn = gr.Button("Generate Song", variant="primary")
 
         with gr.Column():
             output_audio = gr.Audio(label="Generated Song", type="numpy")
             output_json = gr.JSON(label="Input Configuration")
 
-
-
-        examples=[
-            …
-        ],
-        inputs=[description],
-        label="…
-
-
-
-        examples=[
-            …
-        ],
-        inputs=[lyric],
-        label="Lyrics examples"
-
+        # # example buttons
+        # examples = gr.Examples(
+        #     examples=[
+        #         ["male, bright, rock, happy, electric guitar and drums, the bpm is 150."],
+        #         ["female, warm, jazz, romantic, synthesizer and piano, the bpm is 100."]
+        #     ],
+        #     inputs=[description],
+        #     label="Text Prompt examples"
+        # )
+
+        # examples = gr.Examples(
+        #     examples=[
+        #     "[intro-medium]\n\n[verse]\n在这个疯狂的世界里\n谁不渴望一点改变\n在爱情面前\n我们都显得那么不安全\n你紧紧抱着我\n告诉我再靠近一点\n别让这璀璨的夜晚白白浪费\n我那迷茫的眼睛\n看不见未来的路\n在情感消散之前\n我们对爱的渴望永不熄灭\n你给我留下一句誓言\n想知道我们的爱是否能持续到永远\n[chorus]\n\n约定在那最后的夜晚\n不管命运如何摆布\n我们的心是否依然如初\n我会穿上红衬衫\n带着摇滚的激情\n回到我们初遇的地方\n约定在那最后的夜晚\n就算全世界都变了样\n我依然坚守诺言\n铭记这一天\n你永远是我心中的爱恋\n\n[outro-medium]\n",
+        #     "[intro-short]\n\n[verse]\nThrough emerald canyons where fireflies dwell\nCerulean berries kiss morning's first swell\nCrystalline dew crowns each Vitamin Dawn's confection dissolves slowly on me\nAmbrosia breezes through honeycomb vines\nNature's own candy in Fibonacci lines\n[chorus] Blueberry fruit so sweet\n takes you higher\n can't be beat\n In your lungs\n it starts to swell\n You're under its spell\n [verse] Resin of sunlight in candied retreat\nMarmalade moonbeams melt under bare feet\nNectar spirals bloom chloroplast champagne\nPhotosynthesis sings through my veins\nChlorophyll rhythms pulse warm in my blood\nThe forest's green pharmacy floods every bud[chorus] Blueberry fruit so sweet\n takes you higher\n can't be beat\n In your lungs\n it starts to swell\n You're under its spell\n feel the buzz\n ride the wave\n Limey me\n blueberry\n your mind's enslaved\n In the haze\n lose all time\n floating free\n feeling fine\n Blueberry\n fruit so sweet\n takes you higher\n can't be beat\n In your lungs\n it starts to swell\n cry\n You're under its spell\n\n[outro-short]\n",
+        #     ],
+        #     inputs=[lyric],
+        #     label="Lyrics examples",
+        # )
 
     # generate-button click handler
-
     generate_btn.click(
         fn=generate_song,
-        inputs=[…
+        inputs=[lyric, description, prompt_audio, genre, cfg_coef, temperature, top_k],
        outputs=[output_audio, output_json]
     )
 
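Note on the lyric handling above: the new generate_song flattens the multi-line lyric format into the single-line string the model consumes. Bare [intro]/[inst]/[outro] tags are mapped to their "-short" variants, a blank line becomes the section separator " ; ", the structure tag and its first lyric line are kept in one segment, and remaining line breaks inside a section become ".". A minimal standalone sketch of that normalization, with a toy STRUCTS list standing in for the tags loaded from conf/vocab.yaml:

    # Minimal sketch of the lyric normalization in generate_song
    # (toy STRUCTS; the app loads the real tag list from conf/vocab.yaml).
    STRUCTS = ["[verse]", "[chorus]", "[bridge]", "[intro-short]", "[outro-short]", "[inst-short]"]

    def normalize_lyric(lyric: str) -> str:
        for tag in ("intro", "inst", "outro"):
            lyric = lyric.replace(f"[{tag}]", f"[{tag}-short]")   # bare tag -> "-short" variant
        lyric = lyric.replace("\n\n", " ; ")                      # blank line -> section separator
        for s in STRUCTS:
            lyric = lyric.replace(f"{s}\n", f"{s} ")              # keep tag and first line together
        lyric = lyric.replace("\n", ".")                          # in-section line breaks -> "."
        return lyric.replace(". ; ", " ; ")                       # drop the dot before a separator

    print(normalize_lyric("[intro-short]\n\n[verse]\nline one\nline two\n\n[outro]"))
    # -> "[intro-short] ; [verse] line one.line two ; [outro-short]"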
    	
codeclm/models/builders.py CHANGED
@@ -16,7 +16,6 @@ from codeclm.modules.conditioners import (
     BaseConditioner,
     QwTokenizerConditioner,
     QwTextConditioner,
-    PhonemeTokenizerConditioner,
     QuantizedEmbeddingConditioner,
     ConditionerProvider,
     ConditionFuser,
@@ -102,11 +101,6 @@ def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> Cond
                 output_dim=output_dim,
                 **model_args
             )
-        elif model_type == 'PhonemeTokenizer':
-            conditioners[str(cond)] = PhonemeTokenizerConditioner(
-                output_dim=output_dim,
-                **model_args
-            )
         elif model_type == "qt_embedding":
             conditioners[str(cond)] = QuantizedEmbeddingConditioner(
                 dim=output_dim,
    	
codeclm/models/codeclm.py CHANGED
@@ -208,29 +208,29 @@ class CodecLM:
             elif melody_tokens.shape[-1] < target_melody_token_len:
                 melody_tokens = torch.cat([melody_tokens, torch.full((1,1,target_melody_token_len - melody_tokens.shape[-1]), 16385, device=self.device).long()], dim=-1)
         if self.seperate_tokenizer is not None:
-            if …
+            if bgm_wavs is None:
+                assert vocal_wavs is None, "vocal_wavs is not None when bgm_wavs is None"
+                bgm_tokens = torch.full((1,1,target_melody_token_len), 16385, device=self.device).long()
+                vocal_tokens = torch.full((1,1,target_melody_token_len), 16385, device=self.device).long()
+            else:
+                assert vocal_wavs is not None, "vocal_wavs is None when bgm_wavs is not None"
                 if type(vocal_wavs) == list:
                     vocal_wavs = torch.stack(vocal_wavs, dim=0)
-                if bgm_wavs …
-                    …
-                    bgm_wavs = torch.zeros_like(vocal_wavs)
-                    bgm_wavs[:, 0] = 1.0
-                    bgm_wavs[:, 1:] = torch.randn_like(bgm_wavs[:, 1:])* 0.0003
-                else:
-                    use_bgm = True
-                    if type(bgm_wavs) == list:
-                        bgm_wavs = torch.stack(bgm_wavs, dim=0)
+                if type(bgm_wavs) == list:
+                    bgm_wavs = torch.stack(bgm_wavs, dim=0)
                 vocal_wavs = vocal_wavs.to(self.device)
                 bgm_wavs = bgm_wavs.to(self.device)
-                …
+                if melody_is_wav:
+                    vocal_tokens, bgm_tokens = self.seperate_tokenizer.encode(vocal_wavs, bgm_wavs)
+                else:
+                    vocal_tokens = vocal_wavs
+                    bgm_tokens = bgm_wavs
                 assert len(vocal_tokens.shape) == len(bgm_tokens.shape) == 3, \
                     f"vocal and bgm tokens should have a shape [B, C, T]! " \
                     f"got vocal len={vocal_tokens.shape}, and bgm len={bgm_tokens.shape}"
                 assert vocal_tokens.shape[-1] == bgm_tokens.shape[-1], \
                     f"vocal and bgm tokens should have the same length! " \
                     f"got vocal len={vocal_tokens.shape[-1]}, and bgm len={bgm_tokens.shape[-1]}"
-                if not use_bgm:
-                    bgm_tokens = torch.full_like(bgm_tokens, 16385)
                 if bgm_tokens.shape[-1] > target_melody_token_len:
                     bgm_tokens = bgm_tokens[...,:target_melody_token_len]
                 elif bgm_tokens.shape[-1] < target_melody_token_len:
@@ -239,10 +239,6 @@ class CodecLM:
                     vocal_tokens = vocal_tokens[...,:target_melody_token_len]
                 elif vocal_tokens.shape[-1] < target_melody_token_len:
                     vocal_tokens = torch.cat([vocal_tokens, torch.full((1,1,target_melody_token_len - vocal_tokens.shape[-1]), 16385, device=self.device).long()], dim=-1)
-            else:
-                bgm_tokens = torch.full((1,1,target_melody_token_len), 16385, device=self.device).long()
-                vocal_tokens = torch.full((1,1,target_melody_token_len), 16385, device=self.device).long()
-
             melody_tokens = torch.cat([melody_tokens, vocal_tokens, bgm_tokens], dim=1)
         assert melody_tokens.shape[-1] == target_melody_token_len
         audio_qt_embs = melody_tokens.long()
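The rewrite in codeclm.py above keeps repeating one pattern: each token stream is truncated, or right-padded with the filler id 16385, until it is exactly target_melody_token_len long, so that melody, vocal and bgm tokens can be concatenated along the channel dimension. A self-contained sketch of that pad-or-truncate step (the helper name fit_to_length is ours, not the repo's):

    import torch

    PAD_ID = 16385  # filler token id used throughout the diff above

    def fit_to_length(tokens: torch.Tensor, target_len: int) -> torch.Tensor:
        """Truncate or right-pad a [B, C, T] token tensor so that T == target_len."""
        if tokens.shape[-1] > target_len:
            return tokens[..., :target_len]
        if tokens.shape[-1] < target_len:
            pad = torch.full((*tokens.shape[:-1], target_len - tokens.shape[-1]),
                             PAD_ID, device=tokens.device).long()
            return torch.cat([tokens, pad], dim=-1)
        return tokens

    bgm = torch.zeros(1, 1, 120).long()
    print(fit_to_length(bgm, 150).shape)  # torch.Size([1, 1, 150])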
    	
        codeclm/models/lm_levo.py
    CHANGED
    
    | 
@@ -66,13 +66,17 @@ class LmModel(StreamingModule):
                      intermediate_size: int = 4096,
                      num_heads: int = 8,
                      norm: str = 'layer_norm', norm_first: bool = False,
-                     bias_proj: bool = True,
                      weight_init: tp.Optional[str] = None, depthwise_init: tp.Optional[str] = None,
                      zero_bias_init: bool = False, cfg_dropout: float = 0, cfg_coef: float = 1.0,
                      attribute_dropout: tp.Dict[str, tp.Dict[str, float]] = {},
-                     lm_type = 'Llama',
                      num_layers=16,
+                     max_position_embeddings: int = 8196,
+                     max_position_embeddings_sub: int = 10000,
+                     rope_theta: float = 100000.0,
+                     rope_theta_sub: float = 500000.0,
+                     num_layers_sub: int = 12,
                      cfg = None,
+                     use_flash_attn_2: bool = True,
                      **kwargs):
         super().__init__()

@@ -89,8 +93,6 @@ class LmModel(StreamingModule):
         self.cfg = cfg
         self.pattern_provider = pattern_provider
         self.emb = nn.ModuleList([nn.Embedding(input_emb_dim, dim)])
-        # if 'activation' in kwargs:
-        #     kwargs['activation'] = get_activation_fn(kwargs['activation'])

         model_cfg = LlamaConfig(
             hidden_size=dim,
@@ -100,12 +102,10 @@ class LmModel(StreamingModule):
             num_key_value_heads = num_heads,
             vocab_size = self.code_size,
             use_cache=False,
-            max_position_embeddings=8196,
-            _flash_attn_2_enabled=True,
+            max_position_embeddings=max_position_embeddings,
             rms_norm_eps= 1e-5,
-            rope_theta=100000.0,
-            use_flash_attn_2=True,
-            attn_implementation="flash_attention_2"
+            rope_theta= rope_theta,
+            _flash_attn_2_enabled=use_flash_attn_2,
         )

         self.transformer = CausalLM(model_cfg)
@@ -114,23 +114,22 @@ class LmModel(StreamingModule):
             nn.GELU(),
             nn.Linear(dim, dim)
         )
-        self.layer2_emb = nn.ModuleList([nn.Embedding(input_emb_dim, dim) 
+        self.layer2_emb = nn.ModuleList([nn.Embedding(input_emb_dim, dim)
                                  for _ in range(self.code_depth)])
         sub_model_cfg = LlamaConfig(
             hidden_size=dim,
             intermediate_size = intermediate_size,
             num_attention_heads = num_heads,
-            num_hidden_layers = 12,
+            num_hidden_layers = num_layers_sub,
             num_key_value_heads = num_heads,
             vocab_size = self.code_size,
             use_cache=False,
-            max_position_embeddings=10000,
+            max_position_embeddings=max_position_embeddings_sub,
             rms_norm_eps= 1e-5,
-            rope_theta=500000.0,
-            _flash_attn_2_enabled=True,
-            use_flash_attn_2=True,
-            attn_implementation="flash_attention_2"
+            rope_theta= rope_theta_sub,
+            _flash_attn_2_enabled=use_flash_attn_2,
         )
+
         self.transformer2 = CausalLM(sub_model_cfg)
         self.out_norm: tp.Optional[nn.Module] = None
         if norm_first:
@@ -208,15 +207,9 @@ class LmModel(StreamingModule):
                     if descriptions is not None:
                         attr["text"]["type_info"] = descriptions[i]
                 attributes.append(attr)
-            # print("before cfg dropout", attributes)
             attributes = self.cfg_dropout(attributes)   # drop ALL conditions
-            # print("after cfg dropout", attributes)
             attributes = self.att_dropout(attributes)   # selectively drop some attributes (text, wav, or more fine-grained)
-            # print("after attribute dropout", attributes)
-            # attribute to discrete tokenized ids
             tokenized = self.condition_provider.tokenize(attributes)
-            # print("after tokenize", attributes)
-            # discrete tokenized ids to continuous embeddings
             condition_tensors = self.condition_provider(tokenized)
         else:
             conditions = []
@@ -418,6 +411,7 @@ class LmModel(StreamingModule):
         assert start_offset_sequence is not None
         is_end = torch.zeros((B, self.code_depth, 1)).bool().to(device)
         ignore_tokens = audio_qt_embs[0][0]
+        ignore_tokens = ignore_tokens[ignore_tokens < 16384]
         # 5) auto-regressive sampling
         with self.streaming():
             gen_sequence_len = gen_sequence.shape[-1]  # gen_sequence shape is [B, K, S]
@@ -457,7 +451,6 @@ class LmModel(StreamingModule):
                 if torch.all(is_end):
                     gen_sequence = gen_sequence[..., :offset+1]
                     break
-
                 prev_offset = offset

         # ensure sequence has been entirely filled
@@ -529,7 +522,7 @@ class LmModel(StreamingModule):
                 logits[:, q, :tmp] /= (1.1 ** q_count[:tmp])

         # Apply softmax for sampling if temp > 0. Else, do greedy sampling to avoid zero division error.
-        if(ignore_tokens is not None):
+        if(ignore_tokens is not None and len(ignore_tokens) > 0):
             logits[0][0][ignore_tokens.to(torch.int)] = float('-inf')
         if use_sampling and temp > 0.0:
             probs = torch.softmax(logits / temp, dim=-1)
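The lm_levo.py diff does two things: it promotes hard-coded LlamaConfig values (context length, RoPE theta, flash-attention flag, sub-decoder depth) to constructor arguments, and it hardens the ignore-token path so that only real codebook ids (below 16384) are ever masked, and only when the filtered list is non-empty. A small self-contained sketch of that masking step (vocabulary size and ids are made up; `.long()` is used for indexing in place of the diff's `.to(torch.int)`):

```python
import torch

vocab_size = 16386
logits = torch.randn(1, 1, vocab_size)                 # one sampling step, [B, K, card]

ignore_tokens = torch.tensor([5, 42, 16385])           # 16385 is a special id, not a real code
ignore_tokens = ignore_tokens[ignore_tokens < 16384]   # the filter the diff adds

# the guarded mask from the diff: skip when nothing real is left to ignore
if ignore_tokens is not None and len(ignore_tokens) > 0:
    logits[0][0][ignore_tokens.long()] = float('-inf')

# -inf logits become probability 0 under softmax, so masked ids are never drawn
probs = torch.softmax(logits / 1.0, dim=-1)
next_token = torch.multinomial(probs[0], num_samples=1)
assert next_token.item() not in ignore_tokens.tolist()
```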
    	
codeclm/modules/conditioners.py CHANGED
@@ -107,173 +107,6 @@ class TextConditioner(BaseConditioner):
     ...


-class PhonemeTokenizerConditioner(TextConditioner):
-    def __init__(self,
-                 output_dim: int,
-                 vocab_list,
-                 max_len = 600,
-                 max_sentence_per_structure = 50,
-                 structure_tokens=None,
-                 structure_split_tokens=[','],
-                 sentence_split_tokens=['.'],
-                 mode='sum',
-                 structure_output_dim = 64,
-                 sentence_output_dim = 64,
-                 max_duration = 120,
-                 ):
-
-        self.vocab_list = vocab_list
-        self.max_len = max_len
-        self.mode = mode
-        self.max_sentence_per_structure = max_sentence_per_structure
-        voc_size = len(self.vocab_list)
-
-        if structure_tokens is None:
-            structure_tokens = [i for i in vocab_list if len(i) > 1 and i[0] == '[' and i[-1] == ']']
-        self.structure_token_ids = [vocab_list.index(i) for i in structure_tokens if i in vocab_list]
-        self.structure_split_token_ids = [vocab_list.index(i) for i in structure_split_tokens]
-        self.sentence_split_token_ids = [vocab_list.index(i) for i in sentence_split_tokens]
-
-        # here initialize a output_proj (nn.Embedding) layer
-        # By default the first vocab is "" (null)
-        if mode == 'sum':
-            content_output_dim = output_dim
-            sentence_output_dim = output_dim
-            structure_output_dim = output_dim
-        else:   # 'concat'
-            raise NotImplementedError("concat mode is not implemented yet")
-            # content_output_dim = output_dim - sentence_output_dim - structure_output_dim   # by default
-
-        super().__init__(voc_size, content_output_dim, input_token=True, padding_idx=0)
-        self.special_emb = nn.Embedding(voc_size, structure_output_dim, padding_idx=0)
-
-        self.blank_emb = nn.Parameter(torch.zeros(1, output_dim), requires_grad=False)
-
-        # the first index is "empty structure" token
-        self.sentence_idx_in_structure_emb = nn.Embedding(max_sentence_per_structure, sentence_output_dim)
-        self.sentence_reidx_in_structure_emb = nn.Embedding(max_sentence_per_structure, sentence_output_dim)
-
-        print("max_len", self.max_len)
-        print(self.structure_token_ids)
-
-        self.resolution = max_duration / max_len    # e.g., 120 / 600 = 0.2s
-        print(self.__class__, f"resolution = {self.resolution}")
-
-    def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch.Tensor]:
-        inputs = []
-        for xx in x:
-            xx = '' if xx is None else xx
-            vocab_id = [self.vocab_list.index(item) for item in xx.split(" ") if item in self.vocab_list]
-            inputs.append(torch.tensor(vocab_id).long()) # [T]
-        return inputs
-
-
-    def forward(self, batch_tokens: tp.List, structure_dur = None) -> ConditionType:
-        """
-        Encode token_id into three types of embeddings:
-        1) content embedding: phoneme only (or meaningful contents to be sung out)
-        2) structure embedding: structure / separation embeddings, including structures (verse/chorus/...), separators (. / ,)
-        The two above share the same embedding layer, can be changed to separate embedding layers.
-        3) sentence_idx embedding (per structure):
-        """
-        embeds_batch = []
-        for b in range(len(batch_tokens)):
-            tokens = batch_tokens[b]
-            content_tokens = torch.zeros_like(tokens)
-            special_tokens = torch.zeros_like(tokens)
-            sentence_idx_in_structure_tokens = torch.zeros_like(tokens)
-            sentence_reidx_in_structure_tokens = torch.zeros_like(tokens)
-
-            current_sentence_in_structure_idx = 1
-            current_structure = 0
-            for i in range(tokens.shape[-1]):
-                token = tokens[i]
-                if token in self.structure_token_ids:       # structure token
-                    # only update structure token, leave content and sentence index token null (default 0)
-                    special_tokens[i] = token
-                    content_tokens[i] = token
-                    current_structure = token
-                    current_sentence_in_structure_idx = 1
-                    sentence_idx_in_structure_tokens[i] = 0
-
-                elif token in self.sentence_split_token_ids:    # utterance split token
-                    # only update structure token, leave content and sentence index token null (default 0)
-                    # add up sentence index
-                    special_tokens[i] = current_structure
-                    content_tokens[i] = token
-                    sentence_idx_in_structure_tokens[i] = min(current_sentence_in_structure_idx, self.max_sentence_per_structure - 1)
-                    current_sentence_in_structure_idx += 1
-
-                elif token in self.structure_split_token_ids:    # structure split token
-                    # update structure token (current structure), content token (current token),
-                    # blank index token
-                    content_tokens[i] = token
-                    special_tokens[i] = current_structure
-                    sentence_idx_in_structure_tokens[i] = sentence_idx_in_structure_tokens[i-1]
-                else:       # content tokens
-                    content_tokens[i] = token
-                    special_tokens[i] = current_structure
-                    sentence_idx_in_structure_tokens[i] = min(current_sentence_in_structure_idx, self.max_sentence_per_structure - 1)
-            # work backwards to fill the reverse sentence index
-            current_sentence_num = sentence_idx_in_structure_tokens[-1]
-            for i in range(tokens.shape[-1]-1,-1,-1):
-                if current_sentence_num != 0:
-                    sentence_reidx_in_structure_tokens[i] = min(current_sentence_num + 1 - sentence_idx_in_structure_tokens[i], self.max_sentence_per_structure - 1)
-                if sentence_idx_in_structure_tokens[i] == 0 and i > 0:
-                    current_sentence_num = sentence_idx_in_structure_tokens[i-1]
-
-            # print("tokens", tokens.max(), tokens.min())
-            # print("special tokens", special_tokens.max(), special_tokens.min())
-            # print("sentence idx in structure", sentence_idx_in_structure_tokens.max(), sentence_idx_in_structure_tokens.min())
-            device = self.output_proj.weight.device
-
-            # import pdb; pdb.set_trace()
-            content_embeds = self.output_proj(content_tokens.to(device))    # [T, N]
-            structure_embeds = self.output_proj(special_tokens.to(device))
-            # sentence_idx_embeds = self.sentence_idx_in_structure_emb(sentence_idx_in_structure_tokens.to(device))
-            sentence_idx_embeds = self.sentence_idx_in_structure_emb(sentence_idx_in_structure_tokens.to(device)) + self.sentence_reidx_in_structure_emb(sentence_reidx_in_structure_tokens.to(device))
-
-            if self.mode == 'sum':
-                embeds = content_embeds + structure_embeds + sentence_idx_embeds
-            else:
-                embeds = torch.cat((content_embeds, structure_embeds, sentence_idx_embeds), -1) # [T, N]
-            embeds_batch.append(embeds)
-
-        # set batch_size = 1, [B, T, N]
-        if self.max_len is not None:
-            max_len = self.max_len
-        else:
-            max_len = max([e.shape[0] for e in embeds_batch])
-        embeds, mask = self.pad_2d_tensor(embeds_batch, max_len)
-
-        return embeds, embeds, mask
-
-
-    def pad_2d_tensor(self, xs, max_len):
-        new_tensor = []
-        new_mask = []
-        for x in xs:
-            seq_len, dim = x.size()
-            pad_len = max_len - seq_len
-
-            if pad_len > 0:
-                pad_tensor = self.blank_emb.repeat(pad_len, 1).to(x.device)  # T, D
-                padded_tensor = torch.cat([x, pad_tensor], dim=0)
-                mask = torch.cat((torch.ones_like(x[:, 0]),
-                                  torch.zeros_like(pad_tensor[:, 0])), 0)   # T
-            elif pad_len < 0:
-                padded_tensor = x[:max_len]
-                mask = torch.ones_like(padded_tensor[:, 0])
-            else:
-                padded_tensor = x
-                mask = torch.ones_like(x[:, 0])
-
-            new_tensor.append(padded_tensor)
-            new_mask.append(mask)
-        # [B, T, D] & [B, T]
-        return torch.stack(new_tensor, 0), torch.stack(new_mask, 0)
-
-
 class QwTokenizerConditioner(TextConditioner):
     def __init__(self, output_dim: int,
                  token_path = "",
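For context on what is removed here: PhonemeTokenizerConditioner built its lyric conditioning in 'sum' mode by adding three aligned embedding lookups, the content tokens, the enclosing structure token, and a per-structure sentence index. A toy sketch of that scheme (vocabulary sizes, dims, and token ids are made up):

```python
import torch
import torch.nn as nn

voc_size, max_sentences, dim = 100, 50, 32
token_emb    = nn.Embedding(voc_size, dim, padding_idx=0)   # shared by content and structure
sentence_emb = nn.Embedding(max_sentences, dim)

content_tokens   = torch.tensor([[3, 7, 1, 9]])    # phonemes and separators
structure_tokens = torch.tensor([[2, 2, 2, 2]])    # e.g. the id of the enclosing [verse]
sentence_idx     = torch.tensor([[1, 1, 2, 2]])    # sentence counter inside the structure

# 'sum' mode: all three streams share output_dim and are simply added
embeds = token_emb(content_tokens) + token_emb(structure_tokens) + sentence_emb(sentence_idx)
assert embeds.shape == (1, 4, dim)                 # [B, T, N], as in the removed forward()
```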
    	
codeclm/tokenizer/audio_tokenizer.py CHANGED
         @@ -92,515 +92,16 @@ class AudioTokenizer(ABC, nn.Module): 
     | 
|
| 92 | 
         
             
                        model_type = name.split('_', 1)[1]
         
     | 
| 93 | 
         
             
                        logger.info("Getting pretrained compression model from semantic model %s", model_type)
         
     | 
| 94 | 
         
             
                        model = Flow1dVAESeparate(model_type, vae_config, vae_model)
         
     | 
| 95 | 
         
            -
                    elif name.split('_')[0] == ' 
     | 
| 96 | 
         
            -
                        model_type = name.split('_', 1)[1]
         
     | 
| 97 | 
         
            -
                        logger.info("Getting pretrained compression model from semantic model %s", model_type)
         
     | 
| 98 | 
         
            -
                        model =  
     | 
| 99 | 
         
            -
                     
     | 
| 100 | 
         
            -
                         
     | 
| 101 | 
         
            -
             
     | 
| 102 | 
         
            -
             
     | 
| 103 | 
         
            -
                    elif name.split('_')[0] == 'FlowVocalAndMusicDecoderStereoLayer11':
         
     | 
| 104 | 
         
            -
                        model_type = name.split('_', 1)[1]
         
     | 
| 105 | 
         
            -
                        logger.info("Getting pretrained compression model from semantic model %s", model_type)
         
     | 
| 106 | 
         
            -
                        model = FlowVocalAndMusicDecoderStereoLayer11(model_type, mode=mode)
         
     | 
| 107 | 
         
            -
                    elif name.split('_')[0] == 'FlowVocalAndMusicDecoderStereoASRTuneLayer7':
         
     | 
| 108 | 
         
            -
                        model_type = name.split('_', 1)[1]
         
     | 
| 109 | 
         
            -
                        logger.info("Getting pretrained compression model from semantic model %s", model_type)
         
     | 
| 110 | 
         
            -
                        model = FlowVocalAndMusicDecoderStereoASRTuneLayer7(model_type, mode=mode)
         
     | 
| 111 | 
         
            -
                    elif name.split('_')[0] == 'FlowVocalAndMusicDecoderStereoASRTuneLayer7Code2':
         
     | 
| 112 | 
         
            -
                        model_type = name.split('_', 1)[1]
         
     | 
| 113 | 
         
            -
                        logger.info("Getting pretrained compression model from semantic model %s", model_type)
         
     | 
| 114 | 
         
            -
                        model = FlowVocalAndMusicDecoderStereoASRTuneLayer7Code2(model_type, mode=mode)
         
     | 
| 115 | 
         
            -
                    elif name.split('_')[0] == 'FlowVocalAndMusicDecoderStereoASRTuneLayer7Code1':
         
     | 
| 116 | 
         
            -
                        model_type = name.split('_', 1)[1]
         
     | 
| 117 | 
         
            -
                        logger.info("Getting pretrained compression model from semantic model %s", model_type)
         
     | 
| 118 | 
         
            -
                        model = FlowVocalAndMusicDecoderStereoASRTuneLayer7Code1(model_type, mode=mode)
         
     | 
| 119 | 
         
            -
                    elif name.split('_')[0] == 'Flow1dVAE2rvq':
         
     | 
| 120 | 
         
            -
                        model_type = name.split('_', 1)[1]
         
     | 
| 121 | 
         
            -
                        logger.info("Getting pretrained compression model from semantic model %s", model_type)
         
     | 
| 122 | 
         
            -
                        model = Flow1dVAE2rvq(model_type)
         
     | 
| 123 | 
         
            -
                    elif name.split('_')[0] == 'Flow1dVAE1rvq':
         
     | 
| 124 | 
         
            -
                        model_type = name.split('_', 1)[1]
         
     | 
| 125 | 
         
            -
                        logger.info("Getting pretrained compression model from semantic model %s", model_type)
         
     | 
| 126 | 
         
            -
                        model = Flow1dVAE1rvq(model_type, vae_config, vae_model)
         
     | 
| 127 | 
         
            -
                    elif name.split('_')[0] == 'Flow1dVAE4rvq':
         
     | 
| 128 | 
         
            -
                        model_type = name.split('_', 1)[1]
         
     | 
| 129 | 
         
            -
                        logger.info("Getting pretrained compression model from semantic model %s", model_type)
         
     | 
| 130 | 
         
            -
                        model = Flow1dVAE4rvq(model_type)
         
     | 
| 131 | 
         
            -
                    else:
         
     | 
| 132 | 
         
            -
                        raise NotImplementedError("{} is not implemented in models/audio_tokenizer.py".format(
         
     | 
| 133 | 
         
            -
                            name))
         
     | 
| 134 | 
         
            -
                    return model.to(device).eval()
         
     | 
| 135 | 
         
            -
                
         
     | 
| 136 | 
         
            -
             
     | 
| 137 | 
         
            -
            class FlowVocalAndMusicDecoderStereo(AudioTokenizer):
         
     | 
| 138 | 
         
            -
                def __init__(
         
     | 
| 139 | 
         
            -
                    self, 
         
     | 
| 140 | 
         
            -
                    model_type: str, 
         
     | 
| 141 | 
         
            -
                    sample_rate=48000, 
         
     | 
| 142 | 
         
            -
                    mode = 'extract',
         
     | 
| 143 | 
         
            -
                    ):
         
     | 
| 144 | 
         
            -
                    super().__init__()
         
     | 
| 145 | 
         
            -
             
     | 
| 146 | 
         
            -
                    from codeclm.tokenizer.FlowVocalAndMusicDecoderStereoV014.generate_stereo import Tango
         
     | 
| 147 | 
         
            -
                    model_path = model_type
         
     | 
| 148 | 
         
            -
                    self.mode = mode
         
     | 
| 149 | 
         
            -
                    if mode == 'extract':
         
     | 
| 150 | 
         
            -
                        self.model = Tango(model_path=model_path, layer_num=3, load_main_model=False, device='cuda')
         
     | 
| 151 | 
         
            -
                        print ("Successfully loaded checkpoint from:", model_path)
         
     | 
| 152 | 
         
            -
                    elif mode == 'inference':
         
     | 
| 153 | 
         
            -
                        self.samplerate = sample_rate
         
     | 
| 154 | 
         
            -
                        self.model = Tango(model_path=model_path, layer_num=3, load_main_model=True, device='cuda')
         
     | 
| 155 | 
         
            -
                        print ("Successfully loaded checkpoint from:", model_path)
         
     | 
| 156 | 
         
            -
                        
         
     | 
| 157 | 
         
            -
                    self.n_quantizers = 1
         
     | 
| 158 | 
         
            -
             
     | 
| 159 | 
         
            -
                def forward(self, x: torch.Tensor) :
         
     | 
| 160 | 
         
            -
                    # We don't support training with this.
         
     | 
| 161 | 
         
            -
                    raise NotImplementedError("Forward and training with DAC not supported.")
         
     | 
| 162 | 
         
            -
                
         
     | 
| 163 | 
         
            -
                @torch.no_grad()
         
     | 
| 164 | 
         
            -
                def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
         
     | 
| 165 | 
         
            -
                    if x.ndim == 2:
         
     | 
| 166 | 
         
            -
                        x = x.unsqueeze(1)
         
     | 
| 167 | 
         
            -
                    codes = self.model.sound2code(x) # [B T] -> [B N T]
         
     | 
| 168 | 
         
            -
                    return codes, None
         
     | 
| 169 | 
         
            -
             
     | 
| 170 | 
         
            -
                
         
     | 
| 171 | 
         
            -
                @torch.no_grad()    
         
     | 
| 172 | 
         
            -
                def decode(self, codes: torch.Tensor, prompt = None, scale: tp.Optional[torch.Tensor] = None, ncodes=9):
         
     | 
| 173 | 
         
            -
                    wav = self.model.code2sound(codes, prompt=prompt, duration=40.96, guidance_scale=1.5, 
         
     | 
| 174 | 
         
            -
                                                num_steps=50, disable_progress=False) # [B,N,T] -> [B,T]
         
     | 
| 175 | 
         
            -
                    return wav[None]
         
     | 
| 176 | 
         
            -
             
     | 
| 177 | 
         
            -
                
         
     | 
| 178 | 
         
            -
                @torch.no_grad()
         
     | 
| 179 | 
         
            -
                def decode_latent(self, codes: torch.Tensor):
         
     | 
| 180 | 
         
            -
                    """Decode from the discrete codes to continuous latent space."""
         
     | 
| 181 | 
         
            -
                    # import pdb; pdb.set_trace()
         
     | 
| 182 | 
         
            -
                    return self.model.quantizer.from_codes(codes.transpose(1,2))[0]
         
     | 
| 183 | 
         
            -
             
     | 
| 184 | 
         
            -
                @property
         
     | 
| 185 | 
         
            -
                def channels(self) -> int:
         
     | 
| 186 | 
         
            -
                    return 2
         
     | 
| 187 | 
         
            -
             
     | 
| 188 | 
         
            -
                @property
         
     | 
| 189 | 
         
            -
                def frame_rate(self) -> float:
         
     | 
| 190 | 
         
            -
                    return 25
         
     | 
| 191 | 
         
            -
             
     | 
| 192 | 
         
            -
                @property
         
     | 
| 193 | 
         
            -
                def sample_rate(self) -> int:
         
     | 
| 194 | 
         
            -
                    return self.samplerate
         
     | 
| 195 | 
         
            -
             
     | 
| 196 | 
         
            -
                @property
         
     | 
| 197 | 
         
            -
                def cardinality(self) -> int:
         
     | 
| 198 | 
         
            -
                    return 10000
         
     | 
| 199 | 
         
            -
             
     | 
| 200 | 
         
            -
                @property
         
     | 
| 201 | 
         
            -
                def num_codebooks(self) -> int:
         
     | 
| 202 | 
         
            -
                    return self.n_quantizers
         
     | 
| 203 | 
         
            -
             
     | 
| 204 | 
         
            -
                @property
         
     | 
| 205 | 
         
            -
                def total_codebooks(self) -> int:
         
     | 
| 206 | 
         
            -
                    # return self.model.RVQ
         
     | 
| 207 | 
         
            -
                    return 1
         
     | 
| 208 | 
         
            -
             
     | 
| 209 | 
         
            -
                def set_num_codebooks(self, n: int):
         
     | 
| 210 | 
         
            -
                    """Set the active number of codebooks used by the quantizer.
         
     | 
| 211 | 
         
            -
                    """
         
     | 
| 212 | 
         
            -
                    assert n >= 1
         
     | 
| 213 | 
         
            -
                    assert n <= self.total_codebooks
         
     | 
| 214 | 
         
            -
                    self.n_quantizers = n
         
     | 
| 215 | 
         
            -
             
     | 
| 216 | 
         
            -
            class FlowVocalAndMusicDecoderStereoLayer7(AudioTokenizer):
         
     | 
| 217 | 
         
            -
                def __init__(
         
     | 
| 218 | 
         
            -
                    self, 
         
     | 
| 219 | 
         
            -
                    model_type: str = "pytorch_model_2.bin", 
         
     | 
| 220 | 
         
            -
                    sample_rate=48000, 
         
     | 
| 221 | 
         
            -
                    mode = 'extract',
         
     | 
| 222 | 
         
            -
                    ):
         
     | 
| 223 | 
         
            -
                    super().__init__()
         
     | 
| 224 | 
         
            -
             
     | 
| 225 | 
         
            -
                    from codeclm.tokenizer.FlowVocalAndMusicDecoderStereoV014.generate_stereo_layer7 import Tango
         
     | 
| 226 | 
         
            -
                    model_path = model_type
         
     | 
| 227 | 
         
            -
                    self.mode = mode
         
     | 
| 228 | 
         
            -
                    if mode == 'extract':
         
     | 
| 229 | 
         
-            self.model = Tango(model_path=model_path, layer_num=7, load_main_model=False, device='cuda')
-            print ("Successfully loaded checkpoint from:", model_path)
-        elif mode == 'inference':
-            self.samplerate = sample_rate
-            self.model = Tango(model_path=model_path, layer_num=7, load_main_model=True, device='cuda')
-            print ("Successfully loaded checkpoint from:", model_path)
-            # print("Successfully loaded inference scheduler from {}".format(scheduler_name))
-
-        self.n_quantizers = 1
-
-    def forward(self, x: torch.Tensor) :
-        # We don't support training with this.
-        raise NotImplementedError("Forward and training with DAC not supported.")
-
-    @torch.no_grad()
-    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
-        if x.ndim == 2:
-            x = x.unsqueeze(1)
-        codes = self.model.sound2code(x) # [B T] -> [B N T]
-        return codes, None
-
-    @torch.no_grad()
-    def decode(self, codes: torch.Tensor, prompt = None, scale: tp.Optional[torch.Tensor] = None, ncodes=9):
-        wav = self.model.code2sound(codes, prompt=prompt, duration=40.96, guidance_scale=1.5,
-                                    num_steps=50, disable_progress=False) # [B,N,T] -> [B,T]
-        return wav[None]
-
-    @torch.no_grad()
-    def decode_latent(self, codes: torch.Tensor):
-        """Decode from the discrete codes to continuous latent space."""
-        # import pdb; pdb.set_trace()
-        return self.model.quantizer.from_codes(codes.transpose(1,2))[0]
-
-    @property
-    def channels(self) -> int:
-        return 2
-
-    @property
-    def frame_rate(self) -> float:
-        return 25
-
-    @property
-    def sample_rate(self) -> int:
-        return self.samplerate
-
-    @property
-    def cardinality(self) -> int:
-        return 10000
-
-    @property
-    def num_codebooks(self) -> int:
-        return self.n_quantizers
-
-    @property
-    def total_codebooks(self) -> int:
-        # return self.model.RVQ
-        return 1
-
-    def set_num_codebooks(self, n: int):
-        """Set the active number of codebooks used by the quantizer.
-        """
-        assert n >= 1
-        assert n <= self.total_codebooks
-        self.n_quantizers = n
-
-class FlowVocalAndMusicDecoderStereoASRTuneLayer7(AudioTokenizer):
-    def __init__(
-        self,
-        model_type: str = "model_layer7_1x4.safetensors",
-        sample_rate=48000,
-        mode = 'extract',
-        ):
-        super().__init__()
-
-        from codeclm.tokenizer.FlowVocalAndMusicDecoderStereoV014.generate_stereo_7_1x4 import Tango
-        model_path = model_type
-        self.mode = mode
-        if mode == 'extract':
-            self.model = Tango(model_path=model_path, layer_num=7, load_main_model=False, device='cuda')
-            print ("Successfully loaded checkpoint from:", model_path)
-        elif mode == 'inference':
-            self.samplerate = sample_rate
-            self.model = Tango(model_path=model_path, layer_num=7, load_main_model=True, device='cuda')
-            print ("Successfully loaded checkpoint from:", model_path)
-            # print("Successfully loaded inference scheduler from {}".format(scheduler_name))
-
-        self.n_quantizers = 1
-
-    def forward(self, x: torch.Tensor) :
-        # We don't support training with this.
-        raise NotImplementedError("Forward and training with DAC not supported.")
-
-    @torch.no_grad()
-    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
-        if x.ndim == 2:
-            x = x.unsqueeze(1)
-        codes = self.model.sound2code(x) # [B T] -> [B N T]
-        return codes, None
-
-    @torch.no_grad()
-    def decode(self, codes: torch.Tensor, prompt = None, scale: tp.Optional[torch.Tensor] = None, ncodes=9):
-        wav = self.model.code2sound(codes, prompt=prompt, duration=40.96, guidance_scale=1.5,
-                                    num_steps=50, disable_progress=False) # [B,N,T] -> [B,T]
-        return wav[None]
-
-    @torch.no_grad()
-    def decode_latent(self, codes: torch.Tensor):
-        """Decode from the discrete codes to continuous latent space."""
-        # import pdb; pdb.set_trace()
-        return self.model.quantizer.from_codes(codes.transpose(1,2))[0]
-
-    @property
-    def channels(self) -> int:
-        return 2
-
-    @property
-    def frame_rate(self) -> float:
-        return 25
-
-    @property
-    def sample_rate(self) -> int:
-        return self.samplerate
-
-    @property
-    def cardinality(self) -> int:
-        return 10000
-
-    @property
-    def num_codebooks(self) -> int:
-        return self.n_quantizers
-
-    @property
-    def total_codebooks(self) -> int:
-        # return self.model.RVQ
-        return 1
-
-    def set_num_codebooks(self, n: int):
-        """Set the active number of codebooks used by the quantizer.
-        """
-        assert n >= 1
-        assert n <= self.total_codebooks
-        self.n_quantizers = n
-class FlowVocalAndMusicDecoderStereoASRTuneLayer7Code2(AudioTokenizer):
-    def __init__(
-        self,
-        model_type: str = "model_layer7_1x2.safetensors",
-        sample_rate=48000,
-        mode = 'extract',
-        ):
-        super().__init__()
-
-        from codeclm.tokenizer.FlowVocalAndMusicDecoderStereoV014.generate_stereo_7_1x2 import Tango
-        model_path = model_type
-        self.mode = mode
-        if mode == 'extract':
-            self.model = Tango(model_path=model_path, layer_num=7, load_main_model=False, device='cuda')
-            print ("Successfully loaded checkpoint from:", model_path)
-        elif mode == 'inference':
-            self.samplerate = sample_rate
-            self.model = Tango(model_path=model_path, layer_num=7, load_main_model=True, device='cuda')
-            print ("Successfully loaded checkpoint from:", model_path)
-            # print("Successfully loaded inference scheduler from {}".format(scheduler_name))
-
-        self.n_quantizers = 1
-
-    def forward(self, x: torch.Tensor) :
-        # We don't support training with this.
-        raise NotImplementedError("Forward and training with DAC not supported.")
-
-    @torch.no_grad()
-    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
-        if x.ndim == 2:
-            x = x.unsqueeze(1)
-        codes = self.model.sound2code(x) # [B T] -> [B N T]
-        return codes, None
-
-    @torch.no_grad()
-    def decode(self, codes: torch.Tensor, prompt = None, scale: tp.Optional[torch.Tensor] = None, ncodes=9):
-        wav = self.model.code2sound(codes, prompt=prompt, duration=40.96, guidance_scale=1.5,
-                                    num_steps=50, disable_progress=False) # [B,N,T] -> [B,T]
-        return wav[None]
-
-    @torch.no_grad()
-    def decode_latent(self, codes: torch.Tensor):
-        """Decode from the discrete codes to continuous latent space."""
-        # import pdb; pdb.set_trace()
-        return self.model.quantizer.from_codes(codes.transpose(1,2))[0]
-
-    @property
-    def channels(self) -> int:
-        return 2
-
-    @property
-    def frame_rate(self) -> float:
-        return 25
-
-    @property
-    def sample_rate(self) -> int:
-        return self.samplerate
-
-    @property
-    def cardinality(self) -> int:
-        return 10000
-
-    @property
-    def num_codebooks(self) -> int:
-        return self.n_quantizers
-
-    @property
-    def total_codebooks(self) -> int:
-        # return self.model.RVQ
-        return 1
-
-    def set_num_codebooks(self, n: int):
-        """Set the active number of codebooks used by the quantizer.
-        """
-        assert n >= 1
-        assert n <= self.total_codebooks
-        self.n_quantizers = n
-class FlowVocalAndMusicDecoderStereoASRTuneLayer7Code1(AudioTokenizer):
-    def __init__(
-        self,
-        model_type: str = "model_layer7_1x1.safetensors",
-        sample_rate=48000,
-        mode = 'extract',
-        ):
-        super().__init__()
-
-        from codeclm.tokenizer.FlowVocalAndMusicDecoderStereoV014.generate_stereo_7_1x1 import Tango
-        model_path = model_type
-        self.mode = mode
-        if mode == 'extract':
-            self.model = Tango(model_path=model_path, layer_num=7, load_main_model=False, device='cuda')
-            print ("Successfully loaded checkpoint from:", model_path)
-        elif mode == 'inference':
-            self.samplerate = sample_rate
-            self.model = Tango(model_path=model_path, layer_num=7, load_main_model=True, device='cuda')
-            print ("Successfully loaded checkpoint from:", model_path)
-            # print("Successfully loaded inference scheduler from {}".format(scheduler_name))
-
-        self.n_quantizers = 1
-
-    def forward(self, x: torch.Tensor) :
-        # We don't support training with this.
-        raise NotImplementedError("Forward and training with DAC not supported.")
-
-    @torch.no_grad()
-    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
-        if x.ndim == 2:
-            x = x.unsqueeze(1)
-        codes = self.model.sound2code(x) # [B T] -> [B N T]
-        return codes, None
-
-    @torch.no_grad()
-    def decode(self, codes: torch.Tensor, prompt = None, scale: tp.Optional[torch.Tensor] = None, ncodes=9):
-        wav = self.model.code2sound(codes, prompt=prompt, duration=40.96, guidance_scale=1.5,
-                                    num_steps=50, disable_progress=False) # [B,N,T] -> [B,T]
-        return wav[None]
-
-    @torch.no_grad()
-    def decode_latent(self, codes: torch.Tensor):
-        """Decode from the discrete codes to continuous latent space."""
-        # import pdb; pdb.set_trace()
-        return self.model.quantizer.from_codes(codes.transpose(1,2))[0]
-
-    @property
-    def channels(self) -> int:
-        return 2
-
-    @property
-    def frame_rate(self) -> float:
-        return 25
-
-    @property
-    def sample_rate(self) -> int:
-        return self.samplerate
-
-    @property
-    def cardinality(self) -> int:
-        return 10000
-
-    @property
-    def num_codebooks(self) -> int:
-        return self.n_quantizers
-
-    @property
-    def total_codebooks(self) -> int:
-        # return self.model.RVQ
-        return 1
-
-    def set_num_codebooks(self, n: int):
-        """Set the active number of codebooks used by the quantizer.
-        """
-        assert n >= 1
-        assert n <= self.total_codebooks
-        self.n_quantizers = n
-class Flow1dVAE2rvq(AudioTokenizer):
-    def __init__(
-        self,
-        model_type: str = "model_2.safetensors",
-        ):
-        super().__init__()
-
-        from codeclm.tokenizer.Flow1dVAE.generate_2rvq import Tango
-        model_path = model_type
-        self.model = Tango(model_path=model_path, rvq_num=2, device='cuda')
-        print ("Successfully loaded checkpoint from:", model_path)
-
-        self.n_quantizers = 1
-
-    def forward(self, x: torch.Tensor) :
-        # We don't support training with this.
-        raise NotImplementedError("Forward and training with DAC not supported.")
-
-    @torch.no_grad()
-    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
-        if x.ndim == 2:
-            x = x.unsqueeze(1)
-        codes = self.model.sound2code(x) # [B T] -> [B N T]
-        return codes, None
-
-    @torch.no_grad()
-    def decode(self, codes: torch.Tensor, prompt = None, scale: tp.Optional[torch.Tensor] = None, ncodes=9):
-        wav = self.model.code2sound(codes, prompt=prompt, guidance_scale=1.5,
-                                    num_steps=50, disable_progress=False) # [B,N,T] -> [B,T]
-        return wav[None]
-
-    @torch.no_grad()
-    def decode_latent(self, codes: torch.Tensor):
-        """Decode from the discrete codes to continuous latent space."""
-        # import pdb; pdb.set_trace()
-        return self.model.quantizer.from_codes(codes.transpose(1,2))[0]
-
-    @property
-    def channels(self) -> int:
-        return 2
-
-    @property
-    def frame_rate(self) -> float:
-        return 25
-
-    @property
-    def sample_rate(self) -> int:
-        return self.samplerate
-
-    @property
-    def cardinality(self) -> int:
-        return 10000
-
-    @property
-    def num_codebooks(self) -> int:
-        return self.n_quantizers
-
-    @property
-    def total_codebooks(self) -> int:
-        # return self.model.RVQ
-        return 1
-
-    def set_num_codebooks(self, n: int):
-        """Set the active number of codebooks used by the quantizer.
-        """
-        assert n >= 1
-        assert n <= self.total_codebooks
-        self.n_quantizers = n
 class Flow1dVAE1rvq(AudioTokenizer):
     def __init__(
         self,
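Every wrapper deleted above implemented the same AudioTokenizer contract that the surviving Flow1dVAE1rvq keeps: encode() turns a waveform [B, T] (or [B, 1, T]) into discrete codes [B, N, T'] at a 25 Hz frame rate, and decode() runs the flow-matching decoder (code2sound) back to stereo audio; note that the 40.96 s decode window at 25 Hz is exactly 1024 code frames. A minimal round-trip sketch of that interface, assuming a CUDA machine with the downloaded checkpoints (all three file names below are hypothetical placeholders, not taken from this commit):

    import torch
    from codeclm.tokenizer.audio_tokenizer import Flow1dVAE1rvq

    # The new Flow1dVAE1rvq takes (model_type, vae_config, vae_model), as the
    # factory call added in this commit shows; these paths are illustrative.
    tokenizer = Flow1dVAE1rvq(
        "model_1rvq.safetensors",   # codec checkpoint (model_type) -- hypothetical
        "vae/config.json",          # VAE config path               -- hypothetical
        "vae/model.safetensors",    # VAE weights path              -- hypothetical
    )

    wav = torch.randn(1, 2, 48000 * 10)   # 10 s of stereo audio at 48 kHz
    codes, _ = tokenizer.encode(wav)      # [B, T] -> [B, N, T'], one codebook
    print(codes.shape)                    # ~25 tokens per second of audio

    # decode() re-synthesizes audio from the codes (guidance_scale=1.5, 50 steps).
    out = tokenizer.decode(codes, prompt=None)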
         @@ -674,78 +175,6 @@ class Flow1dVAE1rvq(AudioTokenizer): 
     | 
|
| 674 | 
         
             
                    assert n >= 1
         
     | 
| 675 | 
         
             
                    assert n <= self.total_codebooks
         
     | 
| 676 | 
         
             
                    self.n_quantizers = n
         
     | 
| 677 | 
         
            -
            class Flow1dVAE4rvq(AudioTokenizer):
         
     | 
| 678 | 
         
            -
                def __init__(
         
     | 
| 679 | 
         
            -
                    self, 
         
     | 
| 680 | 
         
            -
                    model_type: str = "model_2.safetensors",
         
     | 
| 681 | 
         
            -
                    ):
         
     | 
| 682 | 
         
            -
                    super().__init__()
         
     | 
| 683 | 
         
            -
             
     | 
| 684 | 
         
            -
                    from codeclm.tokenizer.Flow1dVAE.generate_4rvq import Tango
         
     | 
| 685 | 
         
            -
                    model_path = model_type
         
     | 
| 686 | 
         
            -
                    self.model = Tango(model_path=model_path, rvq_num=4, device='cuda')
         
     | 
| 687 | 
         
            -
                    print ("Successfully loaded checkpoint from:", model_path)
         
     | 
| 688 | 
         
            -
             
     | 
| 689 | 
         
            -
                        
         
     | 
| 690 | 
         
            -
                    self.n_quantizers = 1
         
     | 
| 691 | 
         
            -
             
     | 
| 692 | 
         
            -
                def forward(self, x: torch.Tensor) :
         
     | 
| 693 | 
         
            -
                    # We don't support training with this.
         
     | 
| 694 | 
         
            -
                    raise NotImplementedError("Forward and training with DAC not supported.")
         
     | 
| 695 | 
         
            -
                
         
     | 
| 696 | 
         
            -
                @torch.no_grad()
         
     | 
| 697 | 
         
            -
                def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
         
     | 
| 698 | 
         
            -
                    if x.ndim == 2:
         
     | 
| 699 | 
         
            -
                        x = x.unsqueeze(1)
         
     | 
| 700 | 
         
            -
                    codes = self.model.sound2code(x) # [B T] -> [B N T]
         
     | 
| 701 | 
         
            -
                    return codes, None
         
     | 
| 702 | 
         
            -
             
     | 
| 703 | 
         
            -
                
         
     | 
| 704 | 
         
            -
                @torch.no_grad()    
         
     | 
| 705 | 
         
            -
                def decode(self, codes: torch.Tensor, prompt = None, scale: tp.Optional[torch.Tensor] = None, ncodes=9):
         
     | 
| 706 | 
         
            -
                    wav = self.model.code2sound(codes, prompt=prompt, guidance_scale=1.5, 
         
     | 
| 707 | 
         
            -
                                                num_steps=50, disable_progress=False) # [B,N,T] -> [B,T]
         
     | 
| 708 | 
         
            -
                    return wav[None]
         
     | 
| 709 | 
         
            -
             
     | 
| 710 | 
         
            -
                
         
     | 
| 711 | 
         
            -
                @torch.no_grad()
         
     | 
| 712 | 
         
            -
                def decode_latent(self, codes: torch.Tensor):
         
     | 
| 713 | 
         
            -
                    """Decode from the discrete codes to continuous latent space."""
         
     | 
| 714 | 
         
            -
                    # import pdb; pdb.set_trace()
         
     | 
| 715 | 
         
            -
                    return self.model.quantizer.from_codes(codes.transpose(1,2))[0]
         
     | 
| 716 | 
         
            -
             
     | 
| 717 | 
         
            -
                @property
         
     | 
| 718 | 
         
            -
                def channels(self) -> int:
         
     | 
| 719 | 
         
            -
                    return 2
         
     | 
| 720 | 
         
            -
             
     | 
| 721 | 
         
            -
                @property
         
     | 
| 722 | 
         
            -
                def frame_rate(self) -> float:
         
     | 
| 723 | 
         
            -
                    return 25
         
     | 
| 724 | 
         
            -
             
     | 
| 725 | 
         
            -
                @property
         
     | 
| 726 | 
         
            -
                def sample_rate(self) -> int:
         
     | 
| 727 | 
         
            -
                    return self.samplerate
         
     | 
| 728 | 
         
            -
             
     | 
| 729 | 
         
            -
                @property
         
     | 
| 730 | 
         
            -
                def cardinality(self) -> int:
         
     | 
| 731 | 
         
            -
                    return 10000
         
     | 
| 732 | 
         
            -
             
     | 
| 733 | 
         
            -
                @property
         
     | 
| 734 | 
         
            -
                def num_codebooks(self) -> int:
         
     | 
| 735 | 
         
            -
                    return self.n_quantizers
         
     | 
| 736 | 
         
            -
             
     | 
| 737 | 
         
            -
                @property
         
     | 
| 738 | 
         
            -
                def total_codebooks(self) -> int:
         
     | 
| 739 | 
         
            -
                    # return self.model.RVQ
         
     | 
| 740 | 
         
            -
                    return 1
         
     | 
| 741 | 
         
            -
             
     | 
| 742 | 
         
            -
                def set_num_codebooks(self, n: int):
         
     | 
| 743 | 
         
            -
                    """Set the active number of codebooks used by the quantizer.
         
     | 
| 744 | 
         
            -
                    """
         
     | 
| 745 | 
         
            -
                    assert n >= 1
         
     | 
| 746 | 
         
            -
                    assert n <= self.total_codebooks
         
     | 
| 747 | 
         
            -
                    self.n_quantizers = n
         
     | 
| 748 | 
         
            -
             
     | 
| 749 | 
         | 
| 750 | 
         | 
| 751 | 
         
             
            class Flow1dVAESeparate(AudioTokenizer):
         
     | 
| 
         @@ -822,86 +251,3 @@ class Flow1dVAESeparate(AudioTokenizer): 
     | 
         assert n >= 1
         assert n <= self.total_codebooks
         self.n_quantizers = n
-
-class FlowVocalAndMusicDecoderStereoLayer11(AudioTokenizer):
-    def __init__(
-        self,
-        model_type: str = "layer11_ckpt.pth",
-        sample_rate=48000,
-        mode = 'extract',
-        ):
-        super().__init__()
-
-        from codeclm.tokenizer.FlowVocalAndMusicDecoderStereoV014.generate_stereo_11 import Tango
-        model_path = model_type
-        self.mode = mode
-        if mode == 'extract':
-            self.model = Tango(model_path=model_path, layer_num=11, load_main_model=False, device='cuda')
-            print ("Successfully loaded checkpoint from:", model_path)
-        elif mode == 'inference':
-            self.samplerate = sample_rate
-            self.model = Tango(model_path=model_path, layer_num=11, load_main_model=True, device='cuda')
-            print ("Successfully loaded checkpoint from:", model_path)
-            # print("Successfully loaded inference scheduler from {}".format(scheduler_name))
-
-        self.n_quantizers = 1
-
-    def forward(self, x: torch.Tensor) :
-        # We don't support training with this.
-        raise NotImplementedError("Forward and training with DAC not supported.")
-
-    @torch.no_grad()
-    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
-        if x.ndim == 2:
-            x = x.unsqueeze(1)
-        codes = self.model.sound2code(x) # [B T] -> [B N T]
-        return codes, None
-
-    @torch.no_grad()
-    def decode(self, codes: torch.Tensor, prompt = None, scale: tp.Optional[torch.Tensor] = None, ncodes=9):
-        wav = self.model.code2sound(codes, prompt=prompt, duration=40.96, guidance_scale=1.5,
-                                    num_steps=50, disable_progress=False) # [B,N,T] -> [B,T]
-        return wav[None]
-
-    @torch.no_grad()
-    def decode_latent(self, codes: torch.Tensor):
-        """Decode from the discrete codes to continuous latent space."""
-        # import pdb; pdb.set_trace()
-        return self.model.quantizer.from_codes(codes.transpose(1,2))[0]
-
-    @property
-    def channels(self) -> int:
-        return 2
-
-    @property
-    def frame_rate(self) -> float:
-        return 25
-
-    @property
-    def sample_rate(self) -> int:
-        return self.samplerate
-
-    @property
-    def cardinality(self) -> int:
-        return 10000
-
-    @property
-    def num_codebooks(self) -> int:
-        return self.n_quantizers
-
-    @property
-    def total_codebooks(self) -> int:
-        # return self.model.RVQ
-        return 1
-
-    def set_num_codebooks(self, n: int):
-        """Set the active number of codebooks used by the quantizer.
-        """
-        assert n >= 1
-        assert n <= self.total_codebooks
-        self.n_quantizers = n
-
-
-
         model_type = name.split('_', 1)[1]
         logger.info("Getting pretrained compression model from semantic model %s", model_type)
         model = Flow1dVAESeparate(model_type, vae_config, vae_model)
+    elif name.split('_')[0] == 'Flow1dVAE1rvq':
+        model_type = name.split('_', 1)[1]
+        logger.info("Getting pretrained compression model from semantic model %s", model_type)
+        model = Flow1dVAE1rvq(model_type, vae_config, vae_model)
+    else:
+        raise NotImplementedError("{} is not implemented in models/audio_tokenizer.py".format(
+            name))
+    return model.to(device).eval()
 
 
 class Flow1dVAE1rvq(AudioTokenizer):
     def __init__(
         self,
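The eight added lines above extend the tokenizer factory's name-based dispatch: the configured name is split on the first underscore, the prefix selects the wrapper class, and the remainder is passed along as the checkpoint (model_type) to load; an unknown prefix now fails loudly instead of falling through silently. A small worked example of that split (the name value is hypothetical):

    name = "Flow1dVAE1rvq_model_1rvq.safetensors"   # hypothetical config value

    prefix = name.split('_')[0]         # 'Flow1dVAE1rvq' -> picks the wrapper class
    model_type = name.split('_', 1)[1]  # 'model_1rvq.safetensors' -> checkpoint to load

    assert (prefix, model_type) == ('Flow1dVAE1rvq', 'model_1rvq.safetensors')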
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
| 103 | 
| 104 | 
| 105 | class Flow1dVAE1rvq(AudioTokenizer):
| 106 |     def __init__(
| 107 |         self,
| 175 |         assert n >= 1
| 176 |         assert n <= self.total_codebooks
| 177 |         self.n_quantizers = n
| 178 | 
| 179 | 
| 180 | class Flow1dVAESeparate(AudioTokenizer):
| 251 |         assert n >= 1
| 252 |         assert n <= self.total_codebooks
| 253 |         self.n_quantizers = n
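Both tokenizer classes keep the same guard when selecting how many residual codebooks stay active. A minimal standalone sketch of that pattern (the total_codebooks value of 8 is hypothetical, since the class internals are collapsed in this diff):

    class AudioTokenizerStub:
        """Sketch of the n_quantizers guard shared by the Flow1dVAE tokenizers."""
        total_codebooks = 8          # hypothetical; the real value comes from the RVQ config

        def set_n_quantizers(self, n: int):
            assert n >= 1                        # at least one codebook must stay active
            assert n <= self.total_codebooks     # cannot exceed what the quantizer was trained with
            self.n_quantizers = n                # downstream encoding reads only the first n books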
codeclm/trainer/codec_song_pl.py
CHANGED

@@ -26,7 +26,7 @@ os.environ['TOKENIZERS_PARALLELISM'] = "false"
 
 
 class CodecLM_PL(pl.LightningModule):
-    def __init__(self, cfg):
         super().__init__()
 
         self.cfg = cfg
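A note on reading the hunk headers that follow: "@@ -46,30 +46,12 @@" means a 30-line span starting at old line 46 collapsed to 12 new lines, so the headers show at a glance how much of this trainer file the commit strips out.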
@@ -46,30 +46,12 @@ class CodecLM_PL(pl.LightningModule):
         # 2) Build LM
         self.audiolm = builders.get_lm_model(self.cfg)
         print(self.audiolm)
-        # print the parameter count
-        print('Number of parameters: ', sum(p.numel() for p in self.audiolm.parameters()))
         # 3) Load pretrained checkpoint (if any)
-
-
-
-
-            print("successfully load deepspeed pretrained model {}".format(self.cfg.pretrained.deepspeed_checkpoint))
-            self.missing = missing
-        else:
-            self.missing = []
-        # if the config has a lora section
-        if hasattr(self.cfg, 'lora'):
-            perf_config = LoraConfig(
-                r = self.cfg.lora.r,
-                lora_alpha = self.cfg.lora.lora_alpha,
-                target_modules = self.cfg.lora.target_modules,
-                lora_dropout = self.cfg.lora.lora_dropout,
-                bias = self.cfg.lora.bias,
-                task_type = self.cfg.lora.task_type,
-            )
-            self.audiolm = get_peft_model(self.audiolm, perf_config)
-
         # 4) Build metrics
         self.val_steps = []
         self.train_slide_acc = []
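The deleted branch wrapped the LM in PEFT adapters before training. A minimal sketch of that pattern outside Lightning (the rank, alpha, and target_modules values are placeholders; the real ones came from cfg.lora, which this commit removes):

    from peft import LoraConfig, get_peft_model

    lora_config = LoraConfig(
        r=8,                                   # placeholder rank; was cfg.lora.r
        lora_alpha=16,                         # placeholder; was cfg.lora.lora_alpha
        target_modules=["q_proj", "v_proj"],   # placeholder module names; was cfg.lora.target_modules
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    # audiolm = get_peft_model(audiolm, lora_config)  # wraps the model, freezing the base weights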
@@ -113,32 +95,6 @@ class CodecLM_PL(pl.LightningModule):
         x = torch.where(mask_3d, x, end_id+1)
         return x, mask_3d
 
-    @torch.no_grad()
-    def preprocess_batch(self, batch):  # this function is usually called during training
-        # unpack the fields returned by the dataloader
-        audio, text_lyric, time_stamp, structure_dur, prompt_audio, structure_labels = batch
-
-        dur, valid_st, valid_et = zip(*time_stamp)
-
-        if self.audio_tokenizer is not None:
-            # only used in inference
-            self.audio_tokenizer.eval()
-            with torch.no_grad():
-                with torch.cuda.amp.autocast(enabled=False):
-                    audio_tokens, scale = self.audio_tokenizer.encode(audio)
-                audio_tokens = audio_tokens[:,:self.cfg.lm.code_depth,:]
-                audio_tokens = audio_tokens.long()
-        else:
-            audio_tokens = audio.long()
-
-        token_dur = (torch.Tensor(dur) * self.cfg.audio_tokenizer_frame_rate).int()
-        audio_tokens, audio_padding_mask = self.generate_mask_and_end_token(audio_tokens, token_dur,
-                                                                            end_id=self.audiolm.eos_token_id)
-        condition_tensors = self.audiolm.prepare_condition_tensors(batch_size=len(text_lyric),
-                                                                   text=text_lyric, audio_qt_emb=prompt_audio)
-
-        return condition_tensors, audio_tokens, audio_padding_mask
-
     def get_time(self):
         # get the current date and time
         now = datetime.now()
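The removed preprocess_batch has a simple shape contract: the tokenizer emits [B, K, T] codes and only the first cfg.lm.code_depth codebooks are kept for the LM. A toy illustration (the sizes are made up for the example):

    import torch

    codes = torch.randint(0, 16384, (2, 8, 7500))    # tokenizer output: [B, K, T]
    code_depth = 3                                   # assumed cfg.lm.code_depth
    audio_tokens = codes[:, :code_depth, :].long()   # the LM only models the first 3 codebooks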
@@ -147,506 +103,6 @@ class CodecLM_PL(pl.LightningModule):
         formatted_now = now.strftime("%Y-%m-%d %H:%M:%S.%f")
         return formatted_now
 
-    def training_step(self, batch, batch_idx):
-        # 1) data processing
-        condition_tensors, audio_tokens, padding_mask = self.preprocess_batch(batch)
-
-        # 2) compute model predictions (model forward)
-        model_output = self.audiolm.compute_predictions(audio_tokens, condition_tensors,
-                                                        training_steps=self.global_step)  # this input can be ignored
-        logits = model_output.logits.float()
-        mask = padding_mask & model_output.mask
-
-        # 3) compute loss (float)
-        with torch.cuda.amp.autocast(enabled=False):
-            ce, ce_per_codebook = self._compute_cross_entropy(logits, audio_tokens, mask)
-
-        total_loss = ce
-        if torch.isnan(total_loss):
-            print(self.trainer.global_rank, ce, padding_mask, batch[1])
-            print('--------------------------------------------------------------')
-            return None
-            # torchaudio.save("error_rank{}.wav".format(self.trainer.global_rank), batch[0][:,0].cpu(), 24000)
-            # import pdb; pdb.set_trace()
-        # 4) compute metrics and log
-        metrics = {}
-        self.log('ce', ce, prog_bar=True)
-        metrics['ppl'] = torch.exp(ce)
-        for k, ce_q in enumerate(ce_per_codebook):
-            metrics[f'ce_q{k + 1}'] = ce_q
-            metrics[f'ppl_q{k + 1}'] = torch.exp(ce_q)
-
-        masked_labels = audio_tokens.masked_fill(~mask, value=self.cfg.lm.code_size)
-        metrics['acc'] = []
-        for k in range(self.audiolm.code_depth):
-            metrics['acc'].append(self.top1_acc_metric[k](logits[:, k].transpose(1,2).detach(),
-                                                          masked_labels[:, k]).item())
-        metrics['acc'] = torch.mean(torch.Tensor(metrics['acc'])).item()
-
-        self.train_steps.append({'ce': ce.detach().cpu().item(), 'acc': metrics['acc']})
-        self.log('train_acc', metrics['acc']+1e-8, prog_bar=True)
-        self.log('lr', self.trainer.optimizers[0].param_groups[0]['lr'], prog_bar=True)
-        self.log_dict(metrics)
-
-        return total_loss
-
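One detail in the removed logging code is how padded positions are excluded from accuracy: labels at invalid timesteps are overwritten with the vocabulary-sized id, and the metric is built with ignore_index set to that same id. A compact sketch of the idea (toy sizes; code_size stands in for cfg.lm.code_size):

    import torch
    from torchmetrics.classification import MulticlassAccuracy

    code_size = 16384                                  # vocabulary size, reused as the ignore label
    logits = torch.randn(2, 100, code_size)            # [B, T, card] for one codebook
    labels = torch.randint(0, code_size, (2, 100))     # [B, T]
    mask = torch.zeros(2, 100, dtype=torch.bool)
    mask[:, :60] = True                                # only the first 60 frames are valid

    masked_labels = labels.masked_fill(~mask, code_size)   # padding becomes the ignored id
    acc = MulticlassAccuracy(code_size, top_k=1, average="micro",
                             multidim_average="global", ignore_index=code_size)
    print(acc(logits.transpose(1, 2), masked_labels))      # padded frames do not count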
         
            -
                @torch.no_grad()
         
     | 
| 194 | 
         
            -
                def validation_step(self, batch, batch_idx):
         
     | 
| 195 | 
         
            -
                    # 1) data processing
         
     | 
| 196 | 
         
            -
                    condition_tensors, audio_tokens, padding_mask = self.preprocess_batch(batch)
         
     | 
| 197 | 
         
            -
             
     | 
| 198 | 
         
            -
                    # 2) compute model predictions
         
     | 
| 199 | 
         
            -
                    model_output = self.audiolm.compute_predictions(audio_tokens, condition_tensors)  
         
     | 
| 200 | 
         
            -
                    logits = model_output.logits
         
     | 
| 201 | 
         
            -
                    mask = padding_mask & model_output.mask
         
     | 
| 202 | 
         
            -
                    
         
     | 
| 203 | 
         
            -
                    # 3) compute loss and metrics
         
     | 
| 204 | 
         
            -
                    ce, ce_per_codebook = self._compute_cross_entropy(logits, audio_tokens, mask)
         
     | 
| 205 | 
         
            -
                    metrics = {}   
         
     | 
| 206 | 
         
            -
                    metrics['val_ce'] = ce
         
     | 
| 207 | 
         
            -
                    metrics['val_ppl'] = torch.exp(ce)
         
     | 
| 208 | 
         
            -
                    for k, ce_q in enumerate(ce_per_codebook):
         
     | 
| 209 | 
         
            -
                        metrics[f'val_ce_q{k + 1}'] = ce_q
         
     | 
| 210 | 
         
            -
                        metrics[f'val_ppl_q{k + 1}'] = torch.exp(ce_q)
         
     | 
| 211 | 
         
            -
                    masked_labels = audio_tokens.masked_fill(~mask, value=self.cfg.lm.code_size)
         
     | 
| 212 | 
         
            -
             
     | 
| 213 | 
         
            -
                    for k in range(self.audiolm.code_depth):
         
     | 
| 214 | 
         
            -
                        self.top1_acc_metric[k].update(logits[:, k].transpose(1,2).detach(), masked_labels[:,k]) #* total_length
         
     | 
| 215 | 
         
            -
                        self.top10_acc_metric[k].update(logits[:, k].transpose(1,2).detach(), masked_labels[:,k])
         
     | 
| 216 | 
         
            -
                    self.val_steps.append(metrics)
         
     | 
| 217 | 
         
            -
                        
         
     | 
| 218 | 
         
            -
                    metrics['acc'] = []
         
     | 
| 219 | 
         
            -
                    metrics['acc_top10'] = []
         
     | 
| 220 | 
         
            -
                    for k in range(self.audiolm.code_depth):
         
     | 
| 221 | 
         
            -
                        metrics['acc'].append(self.top1_acc_metric[k](logits[:, k].transpose(1,2).detach(), masked_labels[:,k]).item())
         
     | 
| 222 | 
         
            -
                        metrics['acc_top10'].append(self.top10_acc_metric[k](logits[:, k].transpose(1,2).detach(), masked_labels[:,k]).item())
         
     | 
| 223 | 
         
            -
                    metrics['acc'] = torch.mean(torch.Tensor(metrics['acc']))
         
     | 
| 224 | 
         
            -
                    metrics['acc_top10'] = torch.mean(torch.Tensor(metrics['acc_top10'])) 
         
     | 
| 225 | 
         
            -
                    
         
     | 
| 226 | 
         
            -
                    return metrics['acc']
         
     | 
| 227 | 
         
            -
             
     | 
| 228 | 
         
            -
             
     | 
| 229 | 
         
            -
                def on_validation_epoch_end(self) -> None:        
         
     | 
| 230 | 
         
            -
                    final_metrics = {}
         
     | 
| 231 | 
         
            -
                    for i in self.val_steps:
         
     | 
| 232 | 
         
            -
                        for k in i:
         
     | 
| 233 | 
         
            -
                            final_metrics[k] = final_metrics.get(k, []) + [i[k]]
         
     | 
| 234 | 
         
            -
                    final_metrics = {k: sum(v) / len(v) for k,v in list(final_metrics.items())}
         
     | 
| 235 | 
         
            -
                    self.log_dict(final_metrics)
         
     | 
| 236 | 
         
            -
             
     | 
| 237 | 
         
            -
                    q_acc = []
         
     | 
| 238 | 
         
            -
                    q_acc10 = []
         
     | 
| 239 | 
         
            -
                    for i in range(self.audiolm.code_depth):
         
     | 
| 240 | 
         
            -
                        q_acc.append(self.top1_acc_metric[i].compute())
         
     | 
| 241 | 
         
            -
                        q_acc10.append(self.top10_acc_metric[i].compute())
         
     | 
| 242 | 
         
            -
                        self.log(f"val_Top1Acc_{i}", q_acc[-1])
         
     | 
| 243 | 
         
            -
                        self.log(f"val_Top10Acc_{i}", q_acc10[-1])
         
     | 
| 244 | 
         
            -
                        self.top1_acc_metric[i].reset()
         
     | 
| 245 | 
         
            -
                        self.top10_acc_metric[i].reset()
         
     | 
| 246 | 
         
            -
                    
         
     | 
| 247 | 
         
            -
                    self.log('val_Top1Acc', sum(q_acc) / self.audiolm.code_depth)
         
     | 
| 248 | 
         
            -
                    self.log('val_Top10Acc', sum(q_acc10) / self.audiolm.code_depth)
         
     | 
| 249 | 
         
            -
             
     | 
| 250 | 
         
            -
                    return super().on_validation_epoch_end()
         
     | 
| 251 | 
         
            -
             
     | 
| 252 | 
         
            -
             
     | 
| 253 | 
         
            -
                def on_validation_epoch_start(self) -> None:
         
     | 
| 254 | 
         
            -
                    self.val_steps = []
         
     | 
| 255 | 
         
            -
                    for i in range(self.audiolm.code_depth):
         
     | 
| 256 | 
         
            -
                        self.top1_acc_metric[i].reset()
         
     | 
| 257 | 
         
            -
                        self.top10_acc_metric[i].reset()
         
     | 
| 258 | 
         
            -
             
     | 
| 259 | 
         
            -
                    if len(self.train_steps) > 0:
         
     | 
| 260 | 
         
            -
                        train_metrics = {}
         
     | 
| 261 | 
         
            -
                        for i in self.train_steps:
         
     | 
| 262 | 
         
            -
                            for k in i:
         
     | 
| 263 | 
         
            -
                                train_metrics[k] = train_metrics.get(k, []) + [i[k]]
         
     | 
| 264 | 
         
            -
                        train_metrics = {k: sum(v) / len(v) for k,v in list(train_metrics.items())}
         
     | 
| 265 | 
         
            -
                        self.log('train_summary_Top1Acc', train_metrics['acc'])
         
     | 
| 266 | 
         
            -
                        self.log('train_summary_ce', train_metrics['ce'])
         
     | 
| 267 | 
         
            -
                        self.train_steps = []
         
     | 
| 268 | 
         
            -
             
     | 
| 269 | 
         
            -
                    return super().on_validation_epoch_start()
         
     | 
| 270 | 
         
            -
             
     | 
| 271 | 
         
            -
             
     | 
| 272 | 
         
            -
                # 定义优化器
         
     | 
| 273 | 
         
            -
                def configure_optimizers(self):
         
     | 
| 274 | 
         
            -
                    total_updates = self.cfg.optim.epochs * self.cfg.optim.updates_per_epoch
         
     | 
| 275 | 
         
            -
                    optim_dict = {}
         
     | 
| 276 | 
         
            -
             
     | 
| 277 | 
         
            -
                    param_groups = []
         
     | 
| 278 | 
         
            -
                    missing_params = []
         
     | 
| 279 | 
         
            -
                    other_params = []
         
     | 
| 280 | 
         
            -
                    cnt = 0
         
     | 
| 281 | 
         
            -
                    # 去掉开头的‘audiolm.'
         
     | 
| 282 | 
         
            -
                    print('before missing len', len(self.missing))
         
     | 
| 283 | 
         
            -
                    self.missing = [name.replace('audiolm.', '') for name in self.missing]
         
     | 
| 284 | 
         
            -
                    print('after missing len', len(self.missing))
         
     | 
| 285 | 
         
            -
                    for name, param in self.audiolm.named_parameters():
         
     | 
| 286 | 
         
            -
                        if name in self.missing:
         
     | 
| 287 | 
         
            -
                            cnt += 1
         
     | 
| 288 | 
         
            -
                            print(name)
         
     | 
| 289 | 
         
            -
                            missing_params.append(param)
         
     | 
| 290 | 
         
            -
                        else:
         
     | 
| 291 | 
         
            -
                            other_params.append(param)
         
     | 
| 292 | 
         
            -
                    print(cnt)
         
     | 
| 293 | 
         
            -
                    assert cnt == len(self.missing)
         
     | 
| 294 | 
         
            -
                    param_groups.append({'params': other_params, 'lr': self.cfg.optim.old_lr})
         
     | 
| 295 | 
         
            -
                    param_groups.append({
         
     | 
| 296 | 
         
            -
                        'params': missing_params,
         
     | 
| 297 | 
         
            -
                        'lr': self.cfg.optim.new_lr  # 为missing参数设置10倍的学习率,你可以调整这个倍数
         
     | 
| 298 | 
         
            -
                    })
         
     | 
| 299 | 
         
            -
             
     | 
| 300 | 
         
            -
                    if self.cfg.optim.optimizer == "adamw":
         
     | 
| 301 | 
         
            -
                        optim_dict['optimizer'] = torch.optim.AdamW(
         
     | 
| 302 | 
         
            -
                            param_groups,  # 使用参数分组替代原来的 self.audiolm.parameters()
         
     | 
| 303 | 
         
            -
                            betas=tuple(self.cfg.optim.adam.betas),
         
     | 
| 304 | 
         
            -
                            weight_decay=self.cfg.optim.adam.weight_decay,
         
     | 
| 305 | 
         
            -
                            eps=self.cfg.optim.adam.eps,
         
     | 
| 306 | 
         
            -
                        )
         
     | 
| 307 | 
         
            -
                    else:
         
     | 
| 308 | 
         
            -
                        raise NotImplementedError
         
     | 
| 309 | 
         
            -
             
     | 
| 310 | 
         
            -
                    if self.cfg.schedule is None:
         
     | 
| 311 | 
         
            -
                        pass
         
     | 
| 312 | 
         
            -
                    elif self.cfg.schedule.lr_scheduler == "cosine":
         
     | 
| 313 | 
         
            -
                        scheduler = CosineLRScheduler(optim_dict['optimizer'], 
         
     | 
| 314 | 
         
            -
                                                      total_steps=total_updates, 
         
     | 
| 315 | 
         
            -
                                                      warmup_steps=self.cfg.schedule.cosine.warmup,
         
     | 
| 316 | 
         
            -
                                                      lr_min_ratio=self.cfg.schedule.cosine.lr_min_ratio,
         
     | 
| 317 | 
         
            -
                                                      cycle_length=self.cfg.schedule.cosine.cycle_length,
         
     | 
| 318 | 
         
            -
                                                      )
         
     | 
| 319 | 
         
            -
                        optim_dict['lr_scheduler'] = {"scheduler": scheduler, "interval": "step"}
         
     | 
| 320 | 
         
            -
                    else:
         
     | 
| 321 | 
         
            -
                        raise NotImplementedError
         
     | 
| 322 | 
         
            -
                    
         
     | 
| 323 | 
         
            -
                    return optim_dict
         
     | 
| 324 | 
         
            -
             
     | 
| 325 | 
         
            -
                
         
     | 
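The removed configure_optimizers split the parameters into two AdamW groups so newly introduced ("missing") weights could train faster than the pretrained ones. The core pattern reduced to plain PyTorch (the model, the missing set, and all hyperparameter values are illustrative stand-ins for the cfg.optim fields):

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 4)            # stand-in for the audiolm
    missing = {"bias"}                 # names that were absent from the checkpoint

    new_params = [p for n, p in model.named_parameters() if n in missing]
    old_params = [p for n, p in model.named_parameters() if n not in missing]

    optimizer = torch.optim.AdamW(
        [{"params": old_params, "lr": 1e-5},    # pretrained weights: small lr (old_lr)
         {"params": new_params, "lr": 1e-4}],   # fresh weights: larger lr (new_lr)
        betas=(0.9, 0.95), weight_decay=0.1, eps=1e-8,
    )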
-    def _compute_cross_entropy(
-        self, logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor
-    ) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]:
-        """Compute cross entropy between multi-codebook targets and the model's logits.
-        The cross entropy is computed per codebook to provide codebook-level cross entropy.
-        Valid timesteps for each codebook are pulled from the mask; invalid
-        timesteps are set to 0.
-
-        Args:
-            logits (torch.Tensor): Model's logits of shape [B, K, T, card].
-            targets (torch.Tensor): Target codes, of shape [B, K, T].
-            mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T].
-        Returns:
-            ce (torch.Tensor): Cross entropy averaged over the codebooks.
-            ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached).
-        """
-        B, K, T = targets.shape
-        assert logits.shape[:-1] == targets.shape
-        assert mask.shape == targets.shape
-        ce = torch.zeros([], device=targets.device)
-        ce_per_codebook: tp.List[torch.Tensor] = []
-        for k in range(K):
-            logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1))  # [B x T, card]
-            targets_k = targets[:, k, ...].contiguous().view(-1)  # [B x T]
-            mask_k = mask[:, k, ...].contiguous().view(-1)  # [B x T]
-            ce_targets = targets_k[mask_k]
-            ce_logits = logits_k[mask_k]
-            q_ce = F.cross_entropy(ce_logits, ce_targets)
-            ce += q_ce
-            ce_per_codebook.append(q_ce.detach())
-        # average cross entropy across codebooks
-        ce = ce / K
-        return ce, ce_per_codebook
-
-
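For intuition, the masked per-codebook loss above behaves like K independent classification problems over the valid frames. A toy call with the docstring's shapes (all sizes invented for the example):

    import torch
    import torch.nn.functional as F

    B, K, T, card = 2, 3, 50, 16384
    logits = torch.randn(B, K, T, card)
    targets = torch.randint(0, card, (B, K, T))
    mask = torch.rand(B, K, T) > 0.2                   # pretend some frames are padding

    per_codebook = [
        F.cross_entropy(logits[:, k][mask[:, k]],      # [valid, card]
                        targets[:, k][mask[:, k]])     # [valid]
        for k in range(K)
    ]
    ce = torch.stack(per_codebook).mean()              # average across codebooks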
-class CodecLM_PL_FT(pl.LightningModule):
-    def __init__(self, cfg):
-        super().__init__()
-
-        self.cfg = cfg
-
-        # 1) Build audio tokenizer (usually None during training)
-        self.audio_tokenizer = builders.get_audio_tokenizer_model(self.cfg)
-        if self.audio_tokenizer is not None:
-            for param in self.audio_tokenizer.parameters():
-                param.requires_grad = False
-
-        # 2) Build LM
-        self.audiolm = builders.get_lm_model(self.cfg)
-
-        # 3) Load pretrained checkpoint (if any)
-        if self.cfg.use_pretrained == 'deepspeed':
-            checkpoint = torch.load(self.cfg.pretrained.deepspeed_checkpoint, map_location='cpu')
-            missing, unexpected = self.load_state_dict(checkpoint, strict=False)
-            print(f'-------------Missing--------------\n{missing}')
-            print(f'-------------Unexpected--------------\n{unexpected}')
-            print("successfully load deepspeed pretrained model {}".format(self.cfg.pretrained.deepspeed_checkpoint))
-
-        # 4) Build metrics
-        self.val_steps = []
-        self.train_slide_acc = []
-        self.train_steps = []
-        self.top1_acc_metric = nn.ModuleList([MulticlassAccuracy(
-            self.audiolm.code_size,
-            top_k=1,
-            average="micro", multidim_average="global",
-            ignore_index=self.cfg.lm.code_size,  # ignore EOS token prediction
-        ) for _ in range(self.audiolm.code_depth)])
-        self.top10_acc_metric = nn.ModuleList([MulticlassAccuracy(
-            self.audiolm.code_size,
-            top_k=10,
-            average="micro", multidim_average="global",
-            ignore_index=self.cfg.lm.code_size,
-        ) for _ in range(self.audiolm.code_depth)])
-
-        self.epoch = 0
-        print("++++++++++++++++ training <song> +++++++++++++++++")
-
-    # TODO: move this part to loader
-    def generate_mask_and_end_token(self, x, sequence_lengths, end_id=16384):
-        batch_size = sequence_lengths.size(0)
-        max_length = x.size(2)
-
-        # pad one frame, if the maximum sequence length is equal to the input length
-        if max_length == sequence_lengths.max():
-            x = F.pad(x, (0, 1), value=end_id)
-        max_length = x.size(2)
-
-        if max_length <= sequence_lengths.max() + 1:
-            sequence_lengths = sequence_lengths - (sequence_lengths.max()+1 - max_length)
-
-        # Add end token to x according to the sequence length
-        x[torch.arange(batch_size), :, sequence_lengths] = end_id
-        sequence_lengths += 1
-
-        mask = torch.arange(max_length).expand(batch_size, max_length) < sequence_lengths.unsqueeze(1)
-        mask = mask.to(x.device)
-        mask_3d = mask.unsqueeze(1).expand(batch_size, x.size(1), max_length)
-        x = torch.where(mask_3d, x, end_id+1)
-        return x, mask_3d
-
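generate_mask_and_end_token is easiest to see on a toy tensor: with one clip of 3 valid frames in a 5-frame buffer, frame 3 becomes the EOS id in every codebook and frame 4 becomes the pad id end_id+1. A minimal replay of its core steps:

    import torch

    end_id = 16384
    x = torch.randint(0, 100, (1, 2, 5))        # [B=1, K=2, T=5] codebook tokens
    seq_len = torch.tensor([3])                 # 3 valid frames

    x[torch.arange(1), :, seq_len] = end_id     # write EOS at frame 3 in both codebooks
    seq_len = seq_len + 1                       # the EOS frame now counts as valid
    mask = torch.arange(5).expand(1, 5) < seq_len.unsqueeze(1)   # [[True, True, True, True, False]]
    mask_3d = mask.unsqueeze(1).expand(1, 2, 5)
    x = torch.where(mask_3d, x, torch.tensor(end_id + 1))        # frame 4 becomes the pad id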
-    @torch.no_grad()
-    def preprocess_batch(self, batch):  # this function is usually called during training
-        # unpack the fields returned by the dataloader
-        audio, text_lyric, time_stamp, lang_type, prompt_audio = batch
-        dur, valid_st, valid_et = zip(*time_stamp)
-
-        if self.audio_tokenizer is not None:
-            # only used in inference
-            self.audio_tokenizer.eval()
-            with torch.no_grad():
-                with torch.cuda.amp.autocast(enabled=False):
-                    audio_tokens, scale = self.audio_tokenizer.encode(audio)
-                audio_tokens = audio_tokens[:,:self.cfg.lm.code_depth,:]
-                audio_tokens = audio_tokens.long()
-        else:
-            audio_tokens = audio.long()
-
-        token_dur = (torch.Tensor(dur) * self.cfg.audio_tokenizer_frame_rate).int()
-
-        audio_tokens, audio_padding_mask = self.generate_mask_and_end_token(audio_tokens, token_dur,
-                                                                            end_id=self.audiolm.eos_token_id)
-        condition_tensors = self.audiolm.prepare_condition_tensors(batch_size=len(text_lyric),
-                                                                   text=text_lyric, audio_qt_emb=prompt_audio)
-
-        return condition_tensors, audio_tokens, audio_padding_mask
-
-    def get_time(self):
-        # get the current date and time
-        now = datetime.now()
-
-        # format the date and time with strftime
-        formatted_now = now.strftime("%Y-%m-%d %H:%M:%S.%f")
-        return formatted_now
-
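The duration bookkeeping here is plain unit conversion: seconds times the tokenizer frame rate gives the number of valid token frames. For example, at an assumed 25 frames per second (the real value is cfg.audio_tokenizer_frame_rate, set in the yaml config):

    import torch

    frame_rate = 25                     # assumed; cfg.audio_tokenizer_frame_rate
    dur = (180.0, 150.5)                # seconds, unpacked from the time_stamp tuples
    token_dur = (torch.Tensor(dur) * frame_rate).int()
    print(token_dur)                    # tensor([4500, 3762], dtype=torch.int32)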
-    def training_step(self, batch, batch_idx):
-        # 1) data processing
-        condition_tensors, audio_tokens, padding_mask = self.preprocess_batch(batch)
-
-        # 2) compute model predictions (model forward)
-        model_output = self.audiolm.compute_predictions(audio_tokens, condition_tensors,
-                                                        training_steps=self.global_step)  # this input can be ignored
-        logits = model_output.logits.float()
-        mask = padding_mask & model_output.mask
-
-        # 3) compute loss (float)
-        with torch.cuda.amp.autocast(enabled=False):
-            ce, ce_per_codebook = self._compute_cross_entropy(logits, audio_tokens, mask)
-
-        total_loss = ce
-        if torch.isnan(total_loss):
-            print(self.trainer.global_rank, ce, padding_mask, batch[1])
-            # print('------------------------------------------------------------------------')
-            torchaudio.save("error_rank{}.wav".format(self.trainer.global_rank), batch[0][:,0].cpu(), 24000)
-            import pdb; pdb.set_trace()
-            return None
-
-        # 4) compute metrics and log
-        metrics = {}
-        self.log('ce', ce, prog_bar=True)
-        metrics['ppl'] = torch.exp(ce)
-        for k, ce_q in enumerate(ce_per_codebook):
-            metrics[f'ce_q{k + 1}'] = ce_q
-            metrics[f'ppl_q{k + 1}'] = torch.exp(ce_q)
-
-        masked_labels = audio_tokens.masked_fill(~mask, value=self.cfg.lm.code_size)
-        metrics['acc'] = []
-        for k in range(self.audiolm.code_depth):
-            metrics['acc'].append(self.top1_acc_metric[k](logits[:, k].transpose(1,2).detach(),
-                                                          masked_labels[:, k]).item())
-        metrics['acc'] = torch.mean(torch.Tensor(metrics['acc'])).item()
-
-        self.train_steps.append({'ce': ce.detach().cpu().item(), 'acc': metrics['acc']})
-        self.log('train_acc', metrics['acc']+1e-8, prog_bar=True)
-        self.log('lr', self.trainer.optimizers[0].param_groups[0]['lr'], prog_bar=True)
-        self.log_dict(metrics)
-
-        return total_loss
-
-    @torch.no_grad()
-    def validation_step(self, batch, batch_idx):
-        # 1) data processing
-        condition_tensors, audio_tokens, padding_mask = self.preprocess_batch(batch)
-
-        # 2) compute model predictions
-        model_output = self.audiolm.compute_predictions(audio_tokens, condition_tensors)
-        logits = model_output.logits
-        mask = padding_mask & model_output.mask
-
-        # 3) compute loss and metrics
-        ce, ce_per_codebook = self._compute_cross_entropy(logits, audio_tokens, mask)
-        metrics = {}
-        metrics['val_ce'] = ce
-        metrics['val_ppl'] = torch.exp(ce)
-        for k, ce_q in enumerate(ce_per_codebook):
-            metrics[f'val_ce_q{k + 1}'] = ce_q
-            metrics[f'val_ppl_q{k + 1}'] = torch.exp(ce_q)
-        masked_labels = audio_tokens.masked_fill(~mask, value=self.cfg.lm.code_size)
-
-        for k in range(self.audiolm.code_depth):
-            self.top1_acc_metric[k].update(logits[:, k].transpose(1,2).detach(), masked_labels[:,k]) #* total_length
-            self.top10_acc_metric[k].update(logits[:, k].transpose(1,2).detach(), masked_labels[:,k])
-        self.val_steps.append(metrics)
-        metrics['acc'] = []
-        metrics['acc_top10'] = []
-        for k in range(self.audiolm.code_depth):
-            metrics['acc'].append(self.top1_acc_metric[k](logits[:, k].transpose(1,2).detach(), masked_labels[:,k]).item())
-            metrics['acc_top10'].append(self.top10_acc_metric[k](logits[:, k].transpose(1,2).detach(), masked_labels[:,k]).item())
-        metrics['acc'] = torch.mean(torch.Tensor(metrics['acc']))
-        metrics['acc_top10'] = torch.mean(torch.Tensor(metrics['acc_top10']))
-
-        return metrics['acc']
-
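The two classes diverge mainly in how training_step reacts to a NaN loss: CodecLM_PL logs and skips the batch, while this fine-tuning variant also dumps the offending audio and drops into pdb. Returning None from training_step is the standard Lightning way to skip a batch; schematically (compute_loss is a hypothetical helper):

    import torch
    import pytorch_lightning as pl

    class SkipNaN(pl.LightningModule):
        def training_step(self, batch, batch_idx):
            loss = self.compute_loss(batch)   # hypothetical helper
            if torch.isnan(loss):
                return None                   # Lightning skips backward/optimizer for this batch
            return loss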
-    def on_validation_epoch_end(self) -> None:
-        final_metrics = {}
-        for i in self.val_steps:
-            for k in i:
-                final_metrics[k] = final_metrics.get(k, []) + [i[k]]
-        final_metrics = {k: sum(v) / len(v) for k,v in list(final_metrics.items())}
-        self.log_dict(final_metrics)
-
-        q_acc = []
-        q_acc10 = []
-        for i in range(self.audiolm.code_depth):
-            q_acc.append(self.top1_acc_metric[i].compute())
-            q_acc10.append(self.top10_acc_metric[i].compute())
-            self.log(f"val_Top1Acc_{i}", q_acc[-1])
-            self.log(f"val_Top10Acc_{i}", q_acc10[-1])
-            self.top1_acc_metric[i].reset()
-            self.top10_acc_metric[i].reset()
-
-        self.log('val_Top1Acc', sum(q_acc) / self.audiolm.code_depth)
-        self.log('val_Top10Acc', sum(q_acc10) / self.audiolm.code_depth)
-
-        return super().on_validation_epoch_end()
-
-    def on_validation_epoch_start(self) -> None:
-        self.val_steps = []
-        for i in range(self.audiolm.code_depth):
-            self.top1_acc_metric[i].reset()
-            self.top10_acc_metric[i].reset()
-
-        if len(self.train_steps) > 0:
-            train_metrics = {}
-            for i in self.train_steps:
-                for k in i:
-                    train_metrics[k] = train_metrics.get(k, []) + [i[k]]
-            train_metrics = {k: sum(v) / len(v) for k,v in list(train_metrics.items())}
-            self.log('train_summary_Top1Acc', train_metrics['acc'])
-            self.log('train_summary_ce', train_metrics['ce'])
-            self.train_steps = []
-
-        return super().on_validation_epoch_start()
-
-    # define the optimizer
-    def configure_optimizers(self):
-        total_updates = self.cfg.optim.epochs * self.cfg.optim.updates_per_epoch
-        optim_dict = {}
-
-        if self.cfg.optim.optimizer == "adamw":
-            optim_dict['optimizer'] = torch.optim.AdamW(
-                self.audiolm.parameters(),
-                lr=self.cfg.optim.lr,
-                betas=tuple(self.cfg.optim.adam.betas),
-                weight_decay=self.cfg.optim.adam.weight_decay,
-                eps=self.cfg.optim.adam.eps,
-            )
-        else:
-            raise NotImplementedError
-
-        if self.cfg.schedule is None:
-            pass
-        elif self.cfg.schedule.lr_scheduler == "cosine":
-            scheduler = CosineLRScheduler(optim_dict['optimizer'],
-                                          total_steps=total_updates,
-                                          warmup_steps=self.cfg.schedule.cosine.warmup,
-                                          lr_min_ratio=self.cfg.schedule.cosine.lr_min_ratio,
-                                          cycle_length=self.cfg.schedule.cosine.cycle_length,
         
     | 
| 607 | 
         
            -
                                                      )
         
     | 
| 608 | 
         
            -
                        optim_dict['lr_scheduler'] = {"scheduler": scheduler, "interval": "step"}
         
     | 
| 609 | 
         
            -
                    else:
         
     | 
| 610 | 
         
            -
                        raise NotImplementedError
         
     | 
| 611 | 
         
            -
                    
         
     | 
| 612 | 
         
            -
                    return optim_dict
         
     | 
| 613 | 
         
            -
             
     | 
| 614 | 
         
            -
                
         
     | 
| 615 | 
         
            -
                def _compute_cross_entropy(
         
     | 
| 616 | 
         
            -
                    self, logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor
         
     | 
| 617 | 
         
            -
                ) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]:
         
     | 
| 618 | 
         
            -
                    """Compute cross entropy between multi-codebook targets and model's logits.
         
     | 
| 619 | 
         
            -
                    The cross entropy is computed per codebook to provide codebook-level cross entropy.
         
     | 
| 620 | 
         
            -
                    Valid timesteps for each of the codebook are pulled from the mask, where invalid
         
     | 
| 621 | 
         
            -
                    timesteps are set to 0.
         
     | 
| 622 | 
         
            -
             
     | 
| 623 | 
         
            -
                    Args:
         
     | 
| 624 | 
         
            -
                        logits (torch.Tensor): Model's logits of shape [B, K, T, card].
         
     | 
| 625 | 
         
            -
                        targets (torch.Tensor): Target codes, of shape [B, K, T].
         
     | 
| 626 | 
         
            -
                        mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T].
         
     | 
| 627 | 
         
            -
                    Returns:
         
     | 
| 628 | 
         
            -
                        ce (torch.Tensor): Cross entropy averaged over the codebooks
         
     | 
| 629 | 
         
            -
                        ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached).
         
     | 
| 630 | 
         
            -
                    """
         
     | 
| 631 | 
         
            -
                    # import pdb; pdb.set_trace()
         
     | 
| 632 | 
         
            -
                    B, K, T = targets.shape
         
     | 
| 633 | 
         
            -
                    assert logits.shape[:-1] == targets.shape
         
     | 
| 634 | 
         
            -
                    assert mask.shape == targets.shape
         
     | 
| 635 | 
         
            -
                    ce = torch.zeros([], device=targets.device)
         
     | 
| 636 | 
         
            -
                    ce_per_codebook: tp.List[torch.Tensor] = []
         
     | 
| 637 | 
         
            -
                    for k in range(K):
         
     | 
| 638 | 
         
            -
                        logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1))  # [B x T, card]
         
     | 
| 639 | 
         
            -
                        targets_k = targets[:, k, ...].contiguous().view(-1)  # [B x T]
         
     | 
| 640 | 
         
            -
                        mask_k = mask[:, k, ...].contiguous().view(-1)  # [B x T]
         
     | 
| 641 | 
         
            -
                        ce_targets = targets_k[mask_k]
         
     | 
| 642 | 
         
            -
                        ce_logits = logits_k[mask_k]
         
     | 
| 643 | 
         
            -
                        q_ce = F.cross_entropy(ce_logits, ce_targets)
         
     | 
| 644 | 
         
            -
                        ce += q_ce
         
     | 
| 645 | 
         
            -
                        ce_per_codebook.append(q_ce.detach())
         
     | 
| 646 | 
         
            -
                    # average cross entropy across codebooks
         
     | 
| 647 | 
         
            -
                    ce = ce / K
         
     | 
| 648 | 
         
            -
                    return ce, ce_per_codebook
         
     | 
| 649 | 
         
            -
             
     | 
| 650 | 
         
             
            class CosineLRScheduler(_LRScheduler):# 
         
     | 
| 651 | 
         
             
                """Cosine LR scheduler.
         
     | 
| 652 | 
         | 
| 
         | 
|
@@ ... @@
 
 
 class CodecLM_PL(pl.LightningModule):
+    def __init__(self, cfg, ckpt_path):
         super().__init__()
 
         self.cfg = cfg
@@ ... @@
         # 2) Build LM
         self.audiolm = builders.get_lm_model(self.cfg)
         print(self.audiolm)
@@ ... @@
         # 3) Load pretrained checkpoint (if any)
+        checkpoint = torch.load(ckpt_path, map_location='cpu')
+        missing, unexpected = self.load_state_dict(checkpoint, strict=False)
+        print(f'-------------Missing--------------\n{missing}')
+        print(f'-------------Unexpected--------------\n{unexpected}')
+        print("Successfully loaded deepspeed pretrained model {}".format(ckpt_path))
@@ ... @@
         # 4) Build metrics
         self.val_steps = []
         self.train_slide_acc = []
@@ ... @@
         x = torch.where(mask_3d, x, end_id+1)
         return x, mask_3d
 
@@ ... @@
     def get_time(self):
         # Get the current date and time
         now = datetime.now()
@@ ... @@
         formatted_now = now.strftime("%Y-%m-%d %H:%M:%S.%f")
         return formatted_now
 
 
 class CosineLRScheduler(_LRScheduler):
     """Cosine LR scheduler.
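Two side notes on the training utilities deleted above, since the math is easy to lose in the diff.

The CosineLRScheduler class kept above is configured by the deleted configure_optimizers with total_steps, warmup_steps, lr_min_ratio and cycle_length: a linear warmup followed by a cosine decay toward lr_min_ratio. A minimal sketch of that schedule (an illustrative re-derivation from the constructor arguments, not the class's exact code):

    import math

    def cosine_lr(step, base_lr, total_steps, warmup_steps, lr_min_ratio=0.0, cycle_length=1.0):
        # Linear warmup from 0 to base_lr over the first warmup_steps updates.
        if step < warmup_steps:
            return base_lr * step / max(1, warmup_steps)
        # Cosine decay from base_lr down to base_lr * lr_min_ratio over one cycle.
        span = max(1.0, total_steps * cycle_length - warmup_steps)
        progress = min(1.0, (step - warmup_steps) / span)
        scale = lr_min_ratio + (1.0 - lr_min_ratio) * 0.5 * (1.0 + math.cos(math.pi * progress))
        return base_lr * scale

The deleted _compute_cross_entropy computed one cross entropy per codebook over the mask-selected positions and averaged the K results. The same computation on toy tensors, as a self-contained sketch:

    import torch
    import torch.nn.functional as F

    B, K, T, card = 2, 3, 8, 16384                 # batch, codebooks, timesteps, vocab size
    logits = torch.randn(B, K, T, card)            # model logits
    targets = torch.randint(card, (B, K, T))       # target codes
    mask = torch.ones(B, K, T, dtype=torch.bool)   # all timesteps valid in this toy case

    ce_per_codebook = []
    for k in range(K):
        logits_k = logits[:, k].reshape(-1, card)  # [B*T, card]
        targets_k = targets[:, k].reshape(-1)      # [B*T]
        mask_k = mask[:, k].reshape(-1)            # [B*T]
        ce_per_codebook.append(F.cross_entropy(logits_k[mask_k], targets_k[mask_k]))
    ce = sum(ce_per_codebook) / K                  # loss averaged over codebooks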
         
conf/infer.yaml
DELETED
@@ -1,152 +0,0 @@
-# ================ Logging ====================== #
-root_dir: exp/song/${get_fname:}
-
-# ================ Checkpoints ================== #
-use_pretrained: deepspeed # ['ddp', 'continue', 'deepspeed']
-pretrained:
-  ddp_checkpoint:
-  deepspeed_checkpoint: ./ckpt/60000_alnew.pt
-  continue_checkpoint:
-
-# ================ Data & loader ================== #
-prompt_select: random
-train_jsonl_list:
-- .jsonl
-val_jsonl_list:
-- .jsonl
-train_scp_list:
-- .scp
-val_scp_list:
-- .scp
-
-lyric_processor:
-max_dur: 150
-min_dur: 30
-batch_size: 2
-prompt_len: 10
-pad_to_max: true
-
-# ================ Training ======================= #
-accelerator: gpu
-devices: 8
-num_nodes: 4
-val_check_interval: 2500
-accumulate_grad_batches: 1
-strategy: 'deepspeed_stage_2' # ['ddp', 'fsdp', 'deepspeed_stage_2', 'ddp_find_unused_parameters_true']
-precision: 'bf16-mixed' # ['16-mixed', 'bf16-mixed']
-
-optim:
-  optimizer: adamw
-  updates_per_epoch: 1000
-  epochs: 100
-  old_lr: 0 # 1e-4
-  new_lr: 1e-4
-  max_norm: 0.5
-  adam:
-    betas:
-    - 0.9
-    - 0.95
-    weight_decay: 0.00001 # 0.1
-    eps: 1e-8
-
-schedule:
-  lr_scheduler: cosine
-  cosine:
-    warmup: 4000
-    lr_min_ratio: 0.0
-    cycle_length: 1.0
-
-# ================ Audio tokenizer ================ #
-audio_tokenizer_checkpoint: Flow1dVAE1rvq_./ckpt/model_1rvq/model_2_fixed.safetensors
-audio_tokenizer_frame_rate: 25
-audio_tokenizer_code_depth: 1
-sample_rate: 48000
-
-audio_tokenizer_checkpoint_sep: Flow1dVAESeparate_./ckpt/model_septoken/model_2.safetensors
-audio_tokenizer_frame_rate_sep: 25
-audio_tokenizer_code_depth_sep: 2
-sample_rate_sep: 48000
-
-# ================ VAE ================ #
-vae_config: ./ckpt/vae/stable_audio_1920_vae.json
-vae_model: ./ckpt/vae/autoencoder_music_1320k.ckpt
-
-# ================== LM =========================== #
-lm:
-  lm_type: Llama # [Llama]
-  dim: 1536
-  intermediate_size: 8960
-  num_heads: 12
-  num_layers: 28
-  code_depth: 3
-  code_size: 16384
-  dropout: 0.0
-  activation: gelu
-  norm_first: true
-  bias_ff: false
-  bias_attn: false
-  bias_proj: false
-  causal: true
-  custom: false
-  memory_efficient: true
-  attention_as_float32: false
-  layer_scale: null
-  positional_embedding: sin
-  xpos: false
-  checkpointing: torch
-  weight_init: gaussian
-  depthwise_init: current
-  zero_bias_init: true
-  norm: layer_norm
-  cross_attention: false
-  qk_layer_norm: false
-  qk_layer_norm_cross: false
-  attention_dropout: null
-  kv_repeat: 1
-
-codebooks_pattern:
-  modeling: delay
-  delay:
-    delays: [ 0, 250, 250 ]
-    flatten_first: 0
-    empty_initial: 0
-
-# ================ Conditioners ===================== #
-classifier_free_guidance:
-  # drop all conditions simultaneously
-  training_dropout: 0.15
-  inference_coef: 1.5
-
-attribute_dropout:
-  # drop each condition separately
-  args:
-    active_on_eval: false
-  text:
-    description: 0.0
-    type_info: 0.5
-  audio:
-    prompt_audio: 0.0
-
-use_text_training: True
-fuser:
-  sum: []
-  prepend: [ description, prompt_audio, type_info ] # this order is the SAME as the input concatenation order
-
-conditioners:
-  prompt_audio:
-    model: qt_embedding
-    qt_embedding:
-      code_size: 16384
-      code_depth: 3
-      max_len: ${eval:${prompt_len}*${audio_tokenizer_frame_rate}+2} # 25*10+2+1
-  description:
-    model: QwTokenizer
-    QwTokenizer:
-      token_path: third_party/Qwen2-7B
-      max_len: 300
-      add_token_list: ${load_yaml:conf/vocab.yaml}
-  type_info:
-    model: QwTextTokenizer
-    QwTextTokenizer:
-      token_path: third_party/Qwen2-7B
-      max_len: 50
                  max_len: 50
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
    	
generate.py
CHANGED
@@ -12,6 +12,7 @@ from codeclm.trainer.codec_song_pl import CodecLM_PL
 from codeclm.models import CodecLM
 from third_party.demucs.models.pretrained import get_model_from_yaml
 
+auto_prompt_type = ['Pop', 'R&B', 'Dance', 'Jazz', 'Folk', 'Rock', 'Chinese Style', 'Chinese Tradition', 'Metal', 'Reggae', 'Chinese Opera', 'Auto']
 
 class Separator:
     def __init__(self, dm_model_path='third_party/demucs/ckpt/htdemucs.pth', dm_config_path='third_party/demucs/ckpt/htdemucs.yaml', gpu_id=0) -> None:
@@ -58,21 +59,25 @@ class Separator:
         return full_audio, vocal_audio, bgm_audio
 
 
-
-def main_sep():
+
+if __name__ == "__main__":
+    torch.backends.cudnn.enabled = False
     OmegaConf.register_new_resolver("eval", lambda x: eval(x))
     OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx])
    OmegaConf.register_new_resolver("get_fname", lambda: os.path.splitext(os.path.basename(sys.argv[1]))[0])
     OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x)))
-
-
-    input_jsonl = sys.argv[
-
+    np.random.seed(int(time.time()))
+    ckpt_path = sys.argv[1]
+    input_jsonl = sys.argv[2]
+    save_dir = sys.argv[3]
+    cfg_path = os.path.join(ckpt_path, 'config.yaml')
+    ckpt_path = os.path.join(ckpt_path, 'model.pt')
+    cfg = OmegaConf.load(cfg_path)
     cfg.mode = 'inference'
     max_duration = cfg.max_dur
 
     # Define model or load pretrained model
-    model_light = CodecLM_PL(cfg)
+    model_light = CodecLM_PL(cfg, ckpt_path)
 
     model_light = model_light.eval().cuda()
     model_light.audiolm.cfg = cfg
@@ -83,9 +88,10 @@ def main_sep():
         seperate_tokenizer = model_light.seperate_tokenizer,
     )
     separator = Separator()
-
+    auto_prompt = torch.load('ckpt/prompt.pt')
+    merge_prompt = [item for sublist in auto_prompt.values() for item in sublist]
     cfg_coef = 1.5 #25
-    temp = 
+    temp = 0.9
     top_k = 50
     top_p = 0.0
     record_tokens = True
@@ -93,7 +99,7 @@ def main_sep():
 
     model.set_generation_params(duration=max_duration, extend_stride=5, temperature=temp, cfg_coef=cfg_coef,
                                 top_k=top_k, top_p=top_p, record_tokens=record_tokens, record_window=record_window)
-    os.makedirs(save_dir
+    os.makedirs(save_dir, exist_ok=True)
     os.makedirs(save_dir + "/audios", exist_ok=True)
     os.makedirs(save_dir + "/jsonl", exist_ok=True)
 
@@ -103,43 +109,58 @@ def main_sep():
     new_items = []
     for line in lines:
         item = json.loads(line)
-
-        target_wav_name = f"{save_dir}/audios/{item['idx']}_s{sidx}.flac"
-        descriptions = item["descriptions"]
+        target_wav_name = f"{save_dir}/audios/{item['idx']}.flac"
         lyric = item["gt_lyric"]
-
-
-
+        descriptions = item["descriptions"] if "descriptions" in item else None
+        # get prompt audio
+        if "prompt_audio_path" in item:
+            assert os.path.exists(item['prompt_audio_path']), f"prompt_audio_path {item['prompt_audio_path']} not found"
+            assert 'auto_prompt_audio_type' not in item, f"auto_prompt_audio_type and prompt_audio_path cannot be used together"
+            pmt_wav, vocal_wav, bgm_wav = separator.run(item['prompt_audio_path'])
+            melody_is_wav = True
+        elif "auto_prompt_audio_type" in item:
+            assert item["auto_prompt_audio_type"] in auto_prompt_type, f"auto_prompt_audio_type {item['auto_prompt_audio_type']} not found"
+            if item["auto_prompt_audio_type"] == "Auto":
+                prompt_token = merge_prompt[np.random.randint(0, len(merge_prompt))]
+            else:
+                prompt_token = auto_prompt[item["auto_prompt_audio_type"]][np.random.randint(0, len(auto_prompt[item["auto_prompt_audio_type"]]))]
+            pmt_wav = prompt_token[:,[0],:]
+            vocal_wav = prompt_token[:,[1],:]
+            bgm_wav = prompt_token[:,[2],:]
+            melody_is_wav = False
+        else:
+            pmt_wav = None
+            vocal_wav = None
+            bgm_wav = None
+            melody_is_wav = True
+
         generate_inp = {
             'lyrics': [lyric.replace("  ", " ")],
             'descriptions': [descriptions],
             'melody_wavs': pmt_wav,
             'vocal_wavs': vocal_wav,
             'bgm_wavs': bgm_wav,
+            'melody_is_wav': melody_is_wav,
         }
-
-        mid_time = time.time()
+        start_time = time.time()
         with torch.autocast(device_type="cuda", dtype=torch.float16):
            tokens = model.generate(**generate_inp, return_tokens=True)
-
-        if tokens.shape[-1] > 3000:
-            tokens = tokens[..., :3000]
+        mid_time = time.time()
 
         with torch.no_grad():
-
+            if melody_is_wav:
+                wav_seperate = model.generate_audio(tokens, pmt_wav, vocal_wav, bgm_wav)
+            else:
+                wav_seperate = model.generate_audio(tokens)
+        end_time = time.time()
         torchaudio.save(target_wav_name, wav_seperate[0].cpu().float(), cfg.sample_rate)
-
-        print(f"process{item['idx']}, demucs cost {mid_time - start_time}s, lm cos {end_time - mid_time}")
+        print(f"process{item['idx']}, lm cost {mid_time - start_time}s, diffusion cost {end_time - mid_time}")
 
-        item["idx"] = f"{item['idx']}
-        item["
+        item["idx"] = f"{item['idx']}"
+        item["wav_path"] = target_wav_name
         new_items.append(item)
 
     src_jsonl_name = os.path.split(input_jsonl)[-1]
-    with open(f"{save_dir}/jsonl/{src_jsonl_name}
+    with open(f"{save_dir}/jsonl/{src_jsonl_name}.jsonl", "w", encoding='utf-8') as fw:
         for item in new_items:
             fw.writelines(json.dumps(item, ensure_ascii=False)+"\n")
-
-
-if __name__ == "__main__":
-    main_sep()
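The rewritten loop above supports three mutually exclusive prompt modes per input line: a reference recording (prompt_audio_path, split into mixture/vocal/BGM stems by the Demucs separator), a style tag (auto_prompt_audio_type, which samples a pre-extracted token tensor from ckpt/prompt.pt and slices its three channels into mixture, vocal and BGM prompts), or no prompt at all. A sketch of JSONL inputs covering each mode; the lyric strings and paths are placeholders:

    import json

    items = [
        {"idx": "text_only", "gt_lyric": "[intro-short] ... [verse] ... [outro-short]",
         "descriptions": "male, bright, rock, energetic, guitar and drums"},
        {"idx": "audio_prompt", "gt_lyric": "[intro-short] ... [verse] ... [outro-short]",
         "prompt_audio_path": "sample/sample_prompt_audio.wav"},
        {"idx": "auto_prompt", "gt_lyric": "[intro-short] ... [verse] ... [outro-short]",
         "auto_prompt_audio_type": "Jazz"},
    ]
    with open("sample/lyrics.jsonl", "w", encoding="utf-8") as fw:
        for item in items:
            fw.write(json.dumps(item, ensure_ascii=False) + "\n")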
         
generate.sh
CHANGED
@@ -4,9 +4,7 @@ export TRANSFORMERS_CACHE="$(pwd)/third_party/hub"
 export NCCL_HOME=/usr/local/tccl
 export PYTHONPATH="$(pwd)/codeclm/tokenizer/":"$(pwd)":"$(pwd)/codeclm/tokenizer/Flow1dVAE/":"$(pwd)/codeclm/tokenizer/":$PYTHONPATH
 
-
-JSONL=$
-SAVE_DIR=$
-
-DEVICE=0
-OMP_NUM_THREADS=1 CUDA_VISIBLE_DEVICES=$DEVICE python3 generate.py $CFG_FILE $SAVE_DIR $JSONL $SIDX
+CKPT_PATH=$1
+JSONL=$2
+SAVE_DIR=$3
+python3 generate.py $CKPT_PATH $JSONL $SAVE_DIR
         
     | 
| 
         | 
|
| 
         | 
    	
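generate.sh is reduced to a thin launcher: it forwards three positional arguments to generate.py in the order checkpoint directory, input JSONL, save directory. A sketch of the same invocation driven from Python; the checkpoint path is a placeholder, assumed (from the levo_inference.py diff below) to be a directory holding config.yaml and model.pt:

    import subprocess

    subprocess.run(
        ["bash", "generate.sh",
         "./ckpt/songgeneration_base",  # CKPT_PATH (placeholder): dir with config.yaml + model.pt
         "sample/lyrics.jsonl",         # JSONL: one generation request per line
         "./output"],                   # SAVE_DIR: receives audios/ and jsonl/
        check=True,
    )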
levo_inference.py
CHANGED
@@ -18,7 +18,7 @@ from separator import Separator
 
 
 class LeVoInference(torch.nn.Module):
-    def __init__(self, cfg_path):
+    def __init__(self, ckpt_path):
         super().__init__()
 
         torch.backends.cudnn.enabled = False 
@@ -27,12 +27,15 @@ class LeVoInference(torch.nn.Module):
         OmegaConf.register_new_resolver("get_fname", lambda: 'default')
         OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x)))
 
+        cfg_path = os.path.join(ckpt_path, 'config.yaml')
+        pt_path = os.path.join(ckpt_path, 'model.pt')
+
         self.cfg = OmegaConf.load(cfg_path)
         self.cfg.mode = 'inference'
         self.max_duration = self.cfg.max_dur
 
         # Define model or load pretrained model
-        model_light = CodecLM_PL(self.cfg)
+        model_light = CodecLM_PL(self.cfg, pt_path)
 
         model_light = model_light.eval().cuda()
         model_light.audiolm.cfg = self.cfg
@@ -63,15 +66,28 @@ class LeVoInference(torch.nn.Module):
 
         self.model.set_generation_params(**self.default_params)
 
-
-    def forward(self, lyric: str, description: str, prompt_audio_path: os.PathLike = None, params = dict()):
+    def forward(self, lyric: str, description: str = None, prompt_audio_path: os.PathLike = None, genre: str = None, auto_prompt_path: os.PathLike = None, params = dict()):
         params = {**self.default_params, **params}
         self.model.set_generation_params(**params)
 
-        if prompt_audio_path is None:
-            pmt_wav, vocal_wav, bgm_wav = None, None, None
-        else:
+        if prompt_audio_path is not None:
             pmt_wav, vocal_wav, bgm_wav = self.separator.run(prompt_audio_path)
+            melody_is_wav = True
+        elif genre is not None and auto_prompt_path is not None:
+            auto_prompt = torch.load(auto_prompt_path)
+            if genre == "Auto": 
+                prompt_token = merge_prompt[np.random.randint(0, len(merge_prompt))]
+            else:
+                prompt_token = auto_prompt[genre][np.random.randint(0, len(auto_prompt[genre]))]
+            pmt_wav = prompt_token[:,[0],:]
+            vocal_wav = prompt_token[:,[1],:]
+            bgm_wav = prompt_token[:,[2],:]
+            melody_is_wav = False
+        else:
+            pmt_wav = None
+            vocal_wav = None
+            bgm_wav = None
+            melody_is_wav = True
 
         generate_inp = {
             'lyrics': [lyric.replace("  ", " ")],
@@ -79,6 +95,7 @@ class LeVoInference(torch.nn.Module):
             'melody_wavs': pmt_wav,
             'vocal_wavs': vocal_wav,
             'bgm_wavs': bgm_wav,
+            'melody_is_wav': melody_is_wav,
         }
 
         with torch.autocast(device_type="cuda", dtype=torch.float16):
@@ -91,38 +108,3 @@ class LeVoInference(torch.nn.Module):
             wav_seperate = self.model.generate_audio(tokens, pmt_wav, vocal_wav, bgm_wav)
 
         return wav_seperate[0]
-
-def build_levo_inference():
-    cfg_path = './conf/infer.yaml'
-    return LeVoInference(cfg_path)
-
-if __name__ == '__main__':
-    import sys
-    import os
-    import time
-    import json
-    import torchaudio
-
-    cfg_path = sys.argv[1]
-    save_dir = sys.argv[2]
-    input_jsonl = sys.argv[3]
-
-    model = LeVoInference(cfg_path)
-
-    os.makedirs(save_dir + "/audios", exist_ok=True)
-
-    with open(input_jsonl, "r") as fp:
-        lines = fp.readlines()
-
-    for line in lines:
-        item = json.loads(line)
-        target_wav_name = f"{save_dir}/audios/{item['idx']}.flac"
-        descriptions = item["descriptions"]
-        lyric = item["gt_lyric"]
-        prompt_audio_path = item['prompt_audio_path']
-
-        wav = model(lyric, descriptions, prompt_audio_path)
-
-        torchaudio.save(target_wav_name, wav.cpu().float(), model.cfg.sample_rate)
-
-
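forward() now chooses between three prompt sources: a reference recording that the separator splits into mixture, vocal and BGM stems; a genre keyword resolved against a precomputed token bank loaded from auto_prompt_path (where the "Auto" branch indexes a merge_prompt that is not defined in the hunks shown here, so it presumably comes from elsewhere in the file); or no prompt at all. melody_is_wav tells the decoder whether the prompt tensors are waveforms or codec tokens. A sketch of the three call patterns; the checkpoint and prompt-bank paths are assumptions:

    import torchaudio
    from levo_inference import LeVoInference

    model = LeVoInference("./ckpt/songgeneration_base")  # placeholder: dir with config.yaml + model.pt
    lyric = "[intro-short] ; [verse] ... ; [outro-short]"

    wav = model(lyric)                                                        # no prompt
    # wav = model(lyric, prompt_audio_path="sample/sample_prompt_audio.wav")  # audio prompt
    # wav = model(lyric, genre="pop", auto_prompt_path="./ckpt/prompt.pt")    # auto prompt (bank path assumed)

    torchaudio.save("out.flac", wav.cpu().float(), model.cfg.sample_rate)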
sample/description/emotion.txt
ADDED
@@ -0,0 +1,8 @@
+sad
+emotional
+angry
+happy
+uplifting
+intense
+romantic
+melancholic
sample/description/gender.txt
ADDED
@@ -0,0 +1,2 @@
+female
+male
sample/description/genre.txt
ADDED
@@ -0,0 +1,27 @@
+pop
+electronic
+hip hop
+rock
+jazz
+blues
+classical
+rap
+country
+classic rock
+hard rock
+folk
+soul
+dance, electronic
+rockabilly
+dance, dancepop, house, pop
+reggae
+experimental
+dance, pop
+dance, deephouse, electronic
+k-pop
+experimental pop
+pop punk
+rock and roll
+R&B
+varies
+pop rock
sample/description/instrument.txt
ADDED
@@ -0,0 +1,40 @@
+synthesizer and piano
+piano and drums
+piano and synthesizer
+synthesizer and drums
+piano and strings
+guitar and drums
+guitar and piano
+piano and double bass
+piano and guitar
+acoustic guitar and piano
+acoustic guitar and synthesizer
+synthesizer and guitar
+piano and saxophone
+saxophone and piano
+piano and violin
+electric guitar and drums
+acoustic guitar and drums
+synthesizer
+guitar and fiddle
+guitar and harmonica
+synthesizer and acoustic guitar
+beats
+piano
+acoustic guitar and fiddle
+brass and piano
+bass and drums
+violin
+acoustic guitar and harmonica
+piano and cello
+saxophone and trumpet
+guitar and banjo
+guitar and synthesizer
+saxophone
+violin and piano
+synthesizer and bass
+synthesizer and electric guitar
+electric guitar and piano
+beats and piano
+synthesizer and
+guitar
sample/description/timbre.txt
ADDED
@@ -0,0 +1,7 @@
+dark
+bright
+warm
+rock
+varies
+soft
+vocal
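The five attribute lists cover the fields of the free-text description format used in the sample JSONLs ("female, dark, pop, sad, piano and drums, the bpm is 125." reads as gender, timbre, genre, emotion, instrument, bpm). A sketch of sampling one value from each file to assemble such a description; the bpm range is an assumption, not something the repo specifies:

    import random

    def load(path):
        with open(path, "r", encoding="utf-8") as f:
            return [line.strip() for line in f if line.strip()]

    parts = [random.choice(load(f"sample/description/{name}.txt"))
             for name in ("gender", "timbre", "genre", "emotion", "instrument")]
    bpm = random.randint(60, 140)  # assumed range
    print(", ".join(parts) + f", the bpm is {bpm}.")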
sample/lyric.jsonl
DELETED
@@ -1 +0,0 @@
-{"idx": "01_节奏蓝调", "descriptions": "female, dark, pop, sad, piano and drums, the bpm is 125.", "gt_lyric": "[intro-short] ; [verse] 夜晚的街灯闪烁.我漫步在熟悉的角落.回忆像潮水般涌来.你的笑容如此清晰.在心头无法抹去.那些曾经的甜蜜.如今只剩我独自回忆 ; [bridge] 手机屏幕亮起.是你发来的消息.简单的几个字.却让我泪流满面.曾经的拥抱温暖.如今却变得遥远.我多想回到从前.重新拥有你的陪伴 ; [chorus] 回忆的温度还在.你却已不在.我的心被爱填满.却又被思念刺痛.R&B的节奏奏响.我的心却在流浪.没有你的日子.我该如何继续向前 ; [outro-short]", "prompt_audio_path": "sample/prompt.wav"}
sample/lyrics.jsonl
ADDED
@@ -0,0 +1,4 @@
+{"idx": "sample_01_autoprompt", "gt_lyric": "[intro-short] ; [verse] 雪花舞动在无尽的天际.情缘如同雪花般轻轻逝去.希望与真挚.永不磨灭.你的忧虑.随风而逝 ; [chorus] 我怀抱着守护这片梦境.在这世界中寻找爱与虚幻.苦辣酸甜.我们一起品尝.在雪的光芒中.紧紧相拥 ; [inst-short] ; [verse] 雪花再次在风中飘扬.情愿如同雪花般消失无踪.希望与真挚.永不消失.在痛苦与喧嚣中.你找到解脱 ; [chorus] 我环绕着守护这片梦境.在这世界中感受爱与虚假.苦辣酸甜.我们一起分享.在白银的光芒中.我们同在 ; [outro-short]", "auto_prompt_audio_type": "Auto"}
+{"idx": "sample_01_noprompt", "gt_lyric": "[intro-short] ; [verse] 雪花舞动在无尽的天际.情缘如同雪花般轻轻逝去.希望与真挚.永不磨灭.你的忧虑.随风而逝 ; [chorus] 我怀抱着守护这片梦境.在这世界中寻找爱与虚幻.苦辣酸甜.我们一起品尝.在雪的光芒中.紧紧相拥 ; [inst-short] ; [verse] 雪花再次在风中飘扬.情愿如同雪花般消失无踪.希望与真挚.永不消失.在痛苦与喧嚣中.你找到解脱 ; [chorus] 我环绕着守护这片梦境.在这世界中感受爱与虚假.苦辣酸甜.我们一起分享.在白银的光芒中.我们同在 ; [outro-short]"}
+{"idx": "sample_01_textprompt", "descriptions": "female, dark, pop, sad, piano and drums, the bpm is 125.", "gt_lyric": "[intro-short] ;  [verse] 雪花舞动在无尽的天际.情缘如同雪花般轻轻逝去.希望与真挚.永不磨灭.你的忧虑.随风而逝 ; [chorus] 我怀抱着守护这片梦境.在这世界中寻找爱与虚幻.苦辣酸甜.我们一起品尝.在雪的光芒中.紧紧相拥 ; [inst-short] ; [verse] 雪花再次在风中飘扬.情愿如同雪花般消失无踪.希望与真挚.永不消失.在痛苦与喧嚣中.你找到解脱 ; [chorus] 我环绕着守护这片梦境.在这世界中感受爱与虚假.苦辣酸甜.我们一起分享.在白银的光芒中.我们同在 ; [outro-short]"}
+{"idx": "sample_01_audioprompt", "gt_lyric": "[intro-short] ; [verse] 雪花舞动在无尽的天际.情缘如同雪花般轻轻逝去.希望与真挚.永不磨灭.你的忧虑.随风而逝 ; [chorus] 我怀抱着守护这片梦境.在这世界中寻找爱与虚幻.苦辣酸甜.我们一起品尝.在雪的光芒中.紧紧相拥 ; [inst-short] ; [verse] 雪花再次在风中飘扬.情愿如同雪花般消失无踪.希望与真挚.永不消失.在痛苦与喧嚣中.你找到解脱 ; [chorus] 我环绕着守护这片梦境.在这世界中感受爱与虚假.苦辣酸甜.我们一起分享.在白银的光芒中.我们同在 ; [outro-short]", "prompt_audio_path": "sample/sample_prompt_audio.wav"}
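Each line of the new sample/lyrics.jsonl is one generation request: idx and gt_lyric are always present, while the optional keys select the prompting mode (auto_prompt_audio_type for the genre bank, descriptions for a text prompt, prompt_audio_path for a reference recording, none for unconditional generation). A sketch of dispatching on whichever key is present; the precedence order here is an assumption, not taken from the repo:

    import json

    with open("sample/lyrics.jsonl", "r", encoding="utf-8") as fp:
        for line in fp:
            item = json.loads(line)
            if "prompt_audio_path" in item:
                mode = f"audio prompt: {item['prompt_audio_path']}"
            elif "auto_prompt_audio_type" in item:
                mode = f"auto prompt: {item['auto_prompt_audio_type']}"
            elif "descriptions" in item:
                mode = f"text prompt: {item['descriptions']}"
            else:
                mode = "no prompt"
            print(item["idx"], "->", mode)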
sample/sample_prompt_audio.wav
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2068592b00263f7c0b0f1d82a882d7738730ace3e04f2d889d06ff983ad6d618
+size 3845542