Update app.py

app.py CHANGED
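This commit restructures the YuE Gradio Space for ZeroGPU-style serving: the stage-1 language model and the codec model are now loaded once at module import (with progress prints), the helpers BlockTokenRangeProcessor, load_audio_mono, and split_lyrics are hoisted out of generate_music to module level, and the model.generate call is factored into a model_inference function decorated with @spaces.GPU(duration=120). Docstrings and clarifying comments are added, low_cpu_mem_usage=True is dropped from model loading, and a commented-out example prompt is removed from the Gradio examples list.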
@@ -67,18 +67,19 @@ import time
 import copy
 from collections import Counter
 from models.soundstream_hubert_new import SoundStream
-
-#from post_process_audio import replace_low_freq_with_energy_matched # removed post process
+
 
 device = "cuda:0"
 
+# Load model and tokenizer outside the generation function (load once)
+print("Loading model...")
 model = AutoModelForCausalLM.from_pretrained(
-    "m-a-p/YuE-s1-7B-anneal-en-
+    "m-a-p/YuE-s1-7B-anneal-en-cot", # "m-a-p/YuE-s1-7B-anneal-en-icl",
     torch_dtype=torch.float16,
     attn_implementation="flash_attention_2",
-    low_cpu_mem_usage=True,
 ).to(device)
 model.eval()
+print("Model loaded.")
 
 basic_model_config = './xcodec_mini_infer/final_ckpt/config.yaml'
 resume_path = './xcodec_mini_infer/final_ckpt/ckpt_00360000.pth'
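The hunk above hoists the expensive from_pretrained call out of the request path, so the 7B checkpoint is loaded once when the Space boots rather than on every generation. A minimal sketch of the same load-once / GPU-decorated split, under the assumption that any small causal LM stands in for the YuE weights (gpt2 here is illustrative only):

    # Load-once pattern: heavy weights are created at import time; on ZeroGPU,
    # a GPU is attached only while the @spaces.GPU-decorated call runs.
    import spaces
    import torch
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float16).to("cuda")
    model.eval()

    @spaces.GPU(duration=120)
    def run(prompt_ids: torch.Tensor) -> torch.Tensor:
        # Every Gradio request reuses the already-loaded weights.
        with torch.inference_mode():
            return model.generate(prompt_ids.to("cuda"), max_new_tokens=8)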
@@ -92,9 +93,61 @@ codec_model = eval(model_config.generator.name)(**model_config.generator.config)
 parameter_dict = torch.load(resume_path, map_location='cpu')
 codec_model.load_state_dict(parameter_dict['codec_model'])
 codec_model.eval()
+print("Codec model loaded.")
+
+
+class BlockTokenRangeProcessor(LogitsProcessor):
+    def __init__(self, start_id, end_id):
+        self.blocked_token_ids = list(range(start_id, end_id))
+
+    def __call__(self, input_ids, scores):
+        scores[:, self.blocked_token_ids] = -float("inf")
+        return scores
+
+def load_audio_mono(filepath, sampling_rate=16000):
+    audio, sr = torchaudio.load(filepath)
+    # Convert to mono
+    audio = torch.mean(audio, dim=0, keepdim=True)
+    # Resample if needed
+    if sr != sampling_rate:
+        resampler = Resample(orig_freq=sr, new_freq=sampling_rate)
+        audio = resampler(audio)
+    return audio
+
+def split_lyrics(lyrics: str):
+    pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
+    segments = re.findall(pattern, lyrics, re.DOTALL)
+    structured_lyrics = [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]
+    return structured_lyrics
 
 
 @spaces.GPU(duration=120)
+def model_inference(input_ids, max_new_tokens, top_p, temperature, repetition_penalty, guidance_scale):
+    """
+    Performs model inference to generate music tokens.
+    This function is decorated with @spaces.GPU for GPU usage in Gradio Spaces.
+    """
+    with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=torch.float16):
+        output_seq = model.generate(
+            input_ids=input_ids,
+            max_new_tokens=max_new_tokens,
+            min_new_tokens=100, # Keep min_new_tokens to avoid short generations
+            do_sample=True,
+            top_p=top_p,
+            temperature=temperature,
+            repetition_penalty=repetition_penalty,
+            eos_token_id=mmtokenizer.eoa,
+            pad_token_id=mmtokenizer.eoa,
+            logits_processor=LogitsProcessorList([BlockTokenRangeProcessor(0, 32002), BlockTokenRangeProcessor(32016, 32016)]),
+            guidance_scale=guidance_scale,
+            use_cache=True
+        )
+        if output_seq[0][-1].item() != mmtokenizer.eoa:
+            tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
+            output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
+    return output_seq
+
+
 def generate_music(
         max_new_tokens=5,
         run_n_segments=2,
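The new module-level BlockTokenRangeProcessor masks a contiguous id range (here 0-32001, presumably the non-codec part of the vocabulary) by setting those logits to -inf before sampling. A standalone check on a toy 8-token vocabulary; the transformers LogitsProcessor base class is omitted to keep the sketch dependency-light, and all values are illustrative:

    import torch

    class BlockTokenRangeProcessor:
        def __init__(self, start_id, end_id):
            self.blocked_token_ids = list(range(start_id, end_id))

        def __call__(self, input_ids, scores):
            scores[:, self.blocked_token_ids] = -float("inf")
            return scores

    scores = torch.zeros(1, 8)                          # toy logits for 8 token ids
    out = BlockTokenRangeProcessor(2, 5)(None, scores)  # block ids 2, 3, 4
    print(out)  # columns 2-4 are -inf, so sampling can never pick those ids

Note that range(32016, 32016) is empty, so the second processor in the LogitsProcessorList above blocks nothing as written.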
@@ -107,6 +160,11 @@ def generate_music(
         cuda_idx=0,
         rescale=False,
 ):
+    """
+    Generates music based on given genre and lyrics, optionally using an audio prompt.
+    This function orchestrates the music generation process, including prompt formatting,
+    model inference, and audio post-processing.
+    """
     if use_audio_prompt and not audio_prompt_path:
         raise FileNotFoundError("Please offer audio prompt filepath using '--audio_prompt_path', when you enable 'use_audio_prompt'!")
     cuda_idx = cuda_idx
@@ -116,31 +174,7 @@
         stage1_output_dir = os.path.join(output_dir, f"stage1")
         os.makedirs(stage1_output_dir, exist_ok=True)
 
-        class BlockTokenRangeProcessor(LogitsProcessor):
-            def __init__(self, start_id, end_id):
-                self.blocked_token_ids = list(range(start_id, end_id))
-
-            def __call__(self, input_ids, scores):
-                scores[:, self.blocked_token_ids] = -float("inf")
-                return scores
-
-        def load_audio_mono(filepath, sampling_rate=16000):
-            audio, sr = torchaudio.load(filepath)
-            # Convert to mono
-            audio = torch.mean(audio, dim=0, keepdim=True)
-            # Resample if needed
-            if sr != sampling_rate:
-                resampler = Resample(orig_freq=sr, new_freq=sampling_rate)
-                audio = resampler(audio)
-            return audio
-
-        def split_lyrics(lyrics: str):
-            pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
-            segments = re.findall(pattern, lyrics, re.DOTALL)
-            structured_lyrics = [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]
-            return structured_lyrics
-
-        # Call the function and print the result
+
         stage1_output_set = []
 
         genres = genre_txt.strip()
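split_lyrics, now hoisted to module level, segments lyrics on [section] headers with a single regex. A quick demo on made-up input:

    import re

    def split_lyrics(lyrics: str):
        pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
        segments = re.findall(pattern, lyrics, re.DOTALL)
        return [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]

    sample = "[verse]\nline one\nline two\n[chorus]\nhook line\n"
    print(split_lyrics(sample))
    # ['[verse]\nline one\nline two\n\n', '[chorus]\nhook line\n\n']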
@@ -151,16 +185,15 @@
         prompt_texts += lyrics
 
         random_id = uuid.uuid4()
-
-
+        raw_output = None
+
+        # Decoding config (moved here for better readability)
         top_p = 0.93
         temperature = 1.0
         repetition_penalty = 1.2
-        # special tokens
         start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
         end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
 
-        raw_output = None
 
         # Format text prompt
         run_n_segments = min(run_n_segments + 1, len(lyrics))
@@ -169,7 +202,7 @@
 
         for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
             section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
-            guidance_scale = 1.5 if i <= 1 else 1.2
+            guidance_scale = 1.5 if i <= 1 else 1.2 # Guidance scale adjusted based on segment index
             if i == 0:
                 continue
             if i == 1:
@@ -196,30 +229,17 @@
 
             prompt_ids = torch.as_tensor(prompt_ids).unsqueeze(0).to(device)
             input_ids = torch.cat([raw_output, prompt_ids], dim=1) if i > 1 else prompt_ids
+
             # Use window slicing in case output sequence exceeds the context of model
             max_context = 16384 - max_new_tokens - 1
             if input_ids.shape[-1] > max_context:
                 print(
                     f'Section {i}: output length {input_ids.shape[-1]} exceeding context length {max_context}, now using the last {max_context} tokens.')
                 input_ids = input_ids[:, -(max_context):]
-            with torch.no_grad():
-                output_seq = model.generate(
-                    input_ids=input_ids,
-                    max_new_tokens=max_new_tokens,
-                    min_new_tokens=100,
-                    do_sample=True,
-                    top_p=top_p,
-                    temperature=temperature,
-                    repetition_penalty=repetition_penalty,
-                    eos_token_id=mmtokenizer.eoa,
-                    pad_token_id=mmtokenizer.eoa,
-                    logits_processor=LogitsProcessorList([BlockTokenRangeProcessor(0, 32002), BlockTokenRangeProcessor(32016, 32016)]),
-                    guidance_scale=guidance_scale,
-                    use_cache=True
-                )
-                if output_seq[0][-1].item() != mmtokenizer.eoa:
-                    tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
-                    output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
+
+            output_seq = model_inference(input_ids, max_new_tokens, top_p, temperature, repetition_penalty, guidance_scale)
+
+
             if i > 1:
                 raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
             else:
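Because every segment's output is concatenated back into raw_output, the accumulated prompt can outgrow the model's 16384-token context; the slice keeps only the most recent max_context tokens. A toy illustration with the window shrunk to 16 so the effect is visible:

    import torch

    max_new_tokens = 5
    max_context = 16 - max_new_tokens - 1      # 10; mirrors 16384 - max_new_tokens - 1

    input_ids = torch.arange(25).unsqueeze(0)  # pretend accumulated history, shape (1, 25)
    if input_ids.shape[-1] > max_context:
        input_ids = input_ids[:, -max_context:]  # keep only the 10 most recent tokens

    print(input_ids)  # tensor([[15, 16, 17, 18, 19, 20, 21, 22, 23, 24]])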
@@ -240,7 +260,7 @@
             codec_ids = ids[soa_idx[i] + 1:eoa_idx[i]]
             if codec_ids[0] == 32016:
                 codec_ids = codec_ids[1:]
-            codec_ids = codec_ids[:2 * (codec_ids.shape[0] // 2)]
+            codec_ids = codec_ids[:2 * (codec_ids.shape[0] // 2)] # Ensure even length for reshape
             vocals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[0])
             vocals.append(vocals_ids)
             instrumentals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[1])
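The stage-1 token stream interleaves the two tracks id-by-id, which is why it must be truncated to even length before the "(n b) -> b n" rearrange splits it into a vocal row and an instrumental row. A toy demo (ids are made up):

    import numpy as np
    from einops import rearrange

    codec_ids = np.array([10, 20, 11, 21, 12, 22, 13])     # v0 i0 v1 i1 v2 i2 + dangling v3
    codec_ids = codec_ids[:2 * (codec_ids.shape[0] // 2)]  # drop the dangling id -> even length

    pair = rearrange(codec_ids, "(n b) -> b n", b=2)
    print(pair[0])  # vocals:        [10 11 12]
    print(pair[1])  # instrumentals: [20 21 22]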
@@ -282,7 +302,7 @@
         decoded_waveform = decoded_waveform.cpu().squeeze(0)
         decodec_rlt.append(torch.as_tensor(decoded_waveform))
         decodec_rlt = torch.cat(decodec_rlt, dim=-1)
-        save_path = os.path.join(recons_output_dir, os.path.splitext(os.path.basename(npy))[0] + ".mp3")
+        save_path = os.path.join(recons_output_dir, os.path.splitext(os.path.basename(npy))[0] + ".mp3") # Save as mp3 for gradio
         tracks.append(save_path)
         save_audio(decodec_rlt, save_path, 16000)
     # mix tracks
@@ -306,7 +326,11 @@
 
 
 def infer(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=15):
-
+    """
+    Gradio interface function to generate music.
+    This function takes genre, lyrics, and generation parameters from Gradio inputs,
+    calls the music generation pipeline, and returns the audio outputs.
+    """
     try:
         mixed_audio_data, vocal_audio_data, instrumental_audio_data = generate_music(genre_txt=genre_txt_content, lyrics_txt=lyrics_txt_content, run_n_segments=num_segments,
                                cuda_idx=0, max_new_tokens=max_new_tokens)
@@ -315,10 +339,10 @@ def infer(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=
         gr.Warning("An Error Occured: " + str(e))
         return None, None, None
     finally:
-        print("Temporary files deleted.")
+        print("Temporary files deleted.") # This message is printed regardless of success/failure
 
 
-# Gradio
+# Gradio Interface
 with gr.Blocks() as demo:
     with gr.Column():
         gr.Markdown("# YuE: Open Music Foundation Models for Full-Song Generation")
@@ -352,19 +376,6 @@ with gr.Blocks() as demo:
 
         gr.Examples(
             examples=[
-#                 ["Rap-Rock Hybrid Punk basslines Scream-rap fusion Crowd chant vocals Distorted turntable scratches Rebel male vocal",
-#                     """[verse]
-# I'm the glitch in the algorithm's perfect face
-# Spit code red in 8-bit, corrupt the marketplace
-# Leather jacket pixels in a digital storm
-# Got meme knives that go viral, keep the normies warm
-
-# [chorus]
-# BREAK-CORE! (Break-core!)
-# Code-slicin' through the mainframe's bore
-# FAKE WAR! (Fake war!)
-# Trend-detonate, I'm the feedback roar
-#                     """],
                 [
                     "rap piano street tough piercing vocal hip-hop synthesizer clear vocal male",
                     """[verse]
@@ -415,5 +426,5 @@ Locked inside my mind, hot flame.
         outputs=[music_out, vocal_out, instrumental_out]
     )
     gr.Markdown("## Call for Contributions\nIf you find this space interesting please feel free to contribute.")
-
+
 demo.queue().launch(show_error=True)