"""Gradio app for a two-stage lyrics-to-music pipeline.

Stage 1 generates interleaved vocal/instrumental codec tokens from genre tags
and lyrics, stage 2 expands each stage-1 track into 8 codes per frame, and the
vocoder helpers render the result to MP3 files.
"""
import copy
import os
import subprocess

# Install flash attention.
# The extra variable is merged into the current environment: passing a bare
# dict as ``env`` would drop PATH/HOME and the rest of the process environment.
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)

import spaces
import torch
import numpy as np
from omegaconf import OmegaConf
import torchaudio
from torchaudio.transforms import Resample
import soundfile as sf
import uuid
from tqdm import tqdm
from einops import rearrange
import gradio as gr
import re
from collections import Counter
from codecmanipulator import CodecManipulator
from mmtokenizer import _MMSentencePieceTokenizer
from transformers import AutoModelForCausalLM, LogitsProcessor, LogitsProcessorList
# from models.soundstream_hubert_new import SoundStream
from vocoder import build_codec_model, process_audio
from post_process_audio import replace_low_freq_with_energy_matched

# Initialize global variables and models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
codectool = CodecManipulator("xcodec", 0, 1)
codectool_stage2 = CodecManipulator("xcodec", 0, 8)


class BlockTokenRangeProcessor(LogitsProcessor):
    """Logits processor that forbids every token id in ``[start_id, end_id)``."""

    def __init__(self, start_id, end_id):
        self.start_id = start_id
        self.end_id = end_id

    def __call__(self, input_ids, scores):
        # -inf logits give the blocked ids zero probability after softmax.
        scores[:, self.start_id:self.end_id] = -float("inf")
        return scores


# Load models once at startup
def load_models():
    """Load the stage-1 LM, the stage-2 LM and the codec model onto ``device``.

    Returns:
        Tuple ``(stage1_model, stage2_model, codec_model)``, all in eval mode.
    """
    # Stage 1 Model
    stage1_model = AutoModelForCausalLM.from_pretrained(
        "m-a-p/YuE-s1-7B-anneal-en-cot",
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    ).to(device)
    stage1_model.eval()

    # Stage 2 Model
    stage2_model = AutoModelForCausalLM.from_pretrained(
        "m-a-p/YuE-s2-1B-general",
        torch_dtype=torch.float16,
        attn_implementation="flash_attention_2",
    ).to(device)
    stage2_model.eval()

    # Codec Model
    model_config = OmegaConf.load('./xcodec_mini_infer/final_ckpt/config.yaml')
    # NOTE(review): eval() resolves the generator class by name from this
    # module's globals. The only candidate import (SoundStream) is commented
    # out above, so this presumably raises NameError unless the configured name
    # is otherwise in scope -- confirm against config.yaml. The config is a
    # local file; never point this at untrusted input.
    codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(device)
    parameter_dict = torch.load('./xcodec_mini_infer/final_ckpt/ckpt_00360000.pth', map_location='cpu')
    codec_model.load_state_dict(parameter_dict['codec_model'])
    codec_model.eval()

    return stage1_model, stage2_model, codec_model


stage1_model, stage2_model, codec_model = load_models()


# Helper functions
def split_lyrics(lyrics):
    """Split ``[tag]``-delimited lyrics into normalized section strings.

    Each returned item has the form ``"[tag]\\n<stripped body>\\n\\n"``.
    Returns an empty list when no ``[tag]`` section is found.
    """
    # The pattern requires a newline after each section body; terminate the
    # text so a final section without a trailing newline is not dropped.
    if not lyrics.endswith("\n"):
        lyrics += "\n"
    pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
    segments = re.findall(pattern, lyrics, re.DOTALL)
    return [f"[{tag}]\n{body.strip()}\n\n" for tag, body in segments]


def load_audio_mono(filepath, sampling_rate=16000):
    """Load an audio file, average its channels to mono and resample if needed.

    Returns:
        Tensor of shape ``(1, num_samples)`` at ``sampling_rate`` Hz.
    """
    audio, sr = torchaudio.load(filepath)
    audio = torch.mean(audio, dim=0, keepdim=True)  # Convert to mono
    if sr != sampling_rate:
        audio = Resample(orig_freq=sr, new_freq=sampling_rate)(audio)
    return audio


def save_audio(wav: torch.Tensor, path, sample_rate: int, rescale: bool = False):
    """Write ``wav`` to ``path`` as 16-bit PCM, creating parent folders.

    With ``rescale`` the signal is scaled down so its peak is at most 0.99;
    otherwise samples are clamped to ``[-0.99, 0.99]``.
    """
    folder_path = os.path.dirname(path)
    if folder_path:  # a bare filename has no folder to create
        os.makedirs(folder_path, exist_ok=True)
    limit = 0.99
    max_val = wav.abs().max()
    wav = wav * min(limit / max_val, 1) if rescale else wav.clamp(-limit, limit)
    torchaudio.save(str(path), wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)


def _encode_audio_prompt(audio_prompt_path, prompt_start_time, prompt_end_time):
    """Encode the reference audio and wrap the selected window in soa/eoa tokens."""
    audio_prompt = load_audio_mono(audio_prompt_path)
    audio_prompt.unsqueeze_(0)
    with torch.no_grad():
        raw_codes = codec_model.encode(audio_prompt.to(device), target_bw=0.5)
    raw_codes = raw_codes.transpose(0, 1).cpu().numpy().astype(np.int16)
    # Times are converted to ids with a factor of 50 (presumably 50 codec
    # frames per second -- TODO confirm against the codec config).
    audio_prompt_codec = codectool.npy2ids(raw_codes[0])[
        int(prompt_start_time * 50): int(prompt_end_time * 50)
    ]
    return [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [mmtokenizer.eoa]


# Stage 1 Generation
def stage1_generate(genres, lyrics_text, use_audio_prompt, audio_prompt_path, prompt_start_time, prompt_end_time):
    """Generate stage-1 vocal and instrumental codes for every lyrics section.

    Args:
        genres: Genre tag string inserted into the instruction prompt.
        lyrics_text: Lyrics made of ``[tag]`` sections (see ``split_lyrics``).
        use_audio_prompt: Whether to condition the first section on reference audio.
        audio_prompt_path: Path of the reference audio file.
        prompt_start_time: Start of the reference window, in seconds.
        prompt_end_time: End of the reference window, in seconds.

    Returns:
        ``(stage1_output_set, output_dir)`` where ``stage1_output_set`` is
        ``[vocal_npy_path, instrumental_npy_path]``.

    Raises:
        ValueError: If the lyrics contain no sections, the audio prompt is
            requested without a file, or the generated soa/eoa markers do not pair up.
    """
    structured_lyrics = split_lyrics(lyrics_text)
    if not structured_lyrics:
        raise ValueError("No lyrics sections found; expected '[verse]'-style tags each followed by text.")
    if use_audio_prompt and not audio_prompt_path:
        raise ValueError("Audio prompt is enabled but no reference audio was provided.")

    full_lyrics = "\n".join(structured_lyrics)
    instruction = f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"

    random_id = str(uuid.uuid4())
    output_dir = os.path.join("./output", random_id)
    os.makedirs(output_dir, exist_ok=True)

    instruction_ids = mmtokenizer.tokenize(instruction)
    segment_start_ids = mmtokenizer.tokenize("[start_of_segment]")

    generated = []
    for seg_no, section in enumerate(tqdm(structured_lyrics), start=1):
        section_text = section.replace('[start_of_segment]', '').replace('[end_of_segment]', '')

        head_id = list(instruction_ids)
        if seg_no == 1 and use_audio_prompt:
            # Only the first section is conditioned on the reference audio.
            head_id += (
                mmtokenizer.tokenize("[start_of_reference]")
                + _encode_audio_prompt(audio_prompt_path, prompt_start_time, prompt_end_time)
                + mmtokenizer.tokenize("[end_of_reference]")
            )

        prompt_ids = (
            head_id
            + segment_start_ids
            + mmtokenizer.tokenize(section_text)
            + [mmtokenizer.soa]
            + codectool.sep_ids
        )
        prompt_ids = torch.as_tensor(prompt_ids).unsqueeze(0).to(device)

        with torch.no_grad():
            output_seq = stage1_model.generate(
                input_ids=prompt_ids,
                max_new_tokens=3000,
                min_new_tokens=100,
                do_sample=True,
                top_p=0.93,
                temperature=1.0,
                repetition_penalty=1.2,
                eos_token_id=mmtokenizer.eoa,
                pad_token_id=mmtokenizer.eoa,
            )

        # Generation that stops at max_new_tokens has no closing eoa; add one
        # so every soa opened by the prompt has a matching eoa below.
        if output_seq[0, -1].item() != mmtokenizer.eoa:
            eoa = torch.as_tensor([[mmtokenizer.eoa]], dtype=output_seq.dtype, device=output_seq.device)
            output_seq = torch.cat([output_seq, eoa], dim=1)

        # output_seq already starts with prompt_ids, so appending it whole
        # keeps the per-section prompt followed by its generated tokens.
        generated.append(output_seq)

    raw_output = torch.cat(generated, dim=1)

    # Save Stage 1 outputs
    ids = raw_output[0].cpu().numpy()
    soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
    eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()
    if len(soa_idx) != len(eoa_idx):
        raise ValueError(
            f"invalid pairs of soa and eoa, Num of soa: {len(soa_idx)}, Num of eoa: {len(eoa_idx)}"
        )

    # The first soa/eoa pair wraps the reference audio when one was supplied;
    # it is prompt material, not generated music, so it is skipped.
    first_pair = 1 if use_audio_prompt else 0

    vocals = []
    instrumentals = []
    for begin, end in zip(soa_idx[first_pair:], eoa_idx[first_pair:]):
        codec_ids = ids[begin + 1:end]
        # 32016 is presumably the xcodec separator id (codectool.sep_ids) -- TODO confirm.
        if codec_ids[0] == 32016:
            codec_ids = codec_ids[1:]
        codec_ids = codec_ids[:2 * (codec_ids.shape[0] // 2)]  # keep an even count
        # Tokens alternate between the two tracks: even positions -> vocals,
        # odd positions -> instrumentals.
        tracks = rearrange(codec_ids, "(n b) -> b n", b=2)
        vocals.append(codectool.ids2npy(tracks[0]))
        instrumentals.append(codectool.ids2npy(tracks[1]))

    vocals = np.concatenate(vocals, axis=1)
    instrumentals = np.concatenate(instrumentals, axis=1)

    vocal_save_path = os.path.join(output_dir, f"vocal_{random_id}.npy")
    inst_save_path = os.path.join(output_dir, f"instrumental_{random_id}.npy")
    np.save(vocal_save_path, vocals)
    np.save(inst_save_path, instrumentals)

    return [vocal_save_path, inst_save_path], output_dir


# Stage 2 Generation
def stage2_generate(model, prompt, batch_size=16):
    """Expand stage-1 codes by generating 7 extra tokens after each frame.

    Args:
        model: Stage-2 causal LM.
        prompt: Stage-1 code array; when ``batch_size > 1`` it is cut into
            ``batch_size`` chunks of 300 frames that are processed in parallel.
        batch_size: Number of 300-frame chunks contained in ``prompt``.

    Returns:
        1-D numpy array of generated token ids (8 per input frame).

    Raises:
        RuntimeError: If a generation step does not yield exactly 7 new tokens.
    """
    codec_ids = codectool.unflatten(prompt, n_quantizer=1)
    codec_ids = codectool.offset_tok_ids(
        codec_ids,
        global_offset=codectool.global_offset,
        codebook_size=codectool.codebook_size,
        num_codebooks=codectool.num_codebooks,
    ).astype(np.int32)

    if batch_size > 1:
        # Stack consecutive 300-frame chunks along the batch axis.
        codec_ids = np.concatenate(
            [codec_ids[:, k * 300:(k + 1) * 300] for k in range(batch_size)],
            axis=0,
        )
        prompt_ids = np.concatenate(
            [
                np.tile([mmtokenizer.soa, mmtokenizer.stage_1], (batch_size, 1)),
                codec_ids,
                np.tile([mmtokenizer.stage_2], (batch_size, 1)),
            ],
            axis=1,
        )
    else:
        prompt_ids = np.concatenate([
            np.array([mmtokenizer.soa, mmtokenizer.stage_1]),
            codec_ids.flatten(),
            np.array([mmtokenizer.stage_2]),
        ]).astype(np.int32)
        prompt_ids = prompt_ids[np.newaxis, ...]

    codec_ids = torch.as_tensor(codec_ids).to(device)
    prompt_ids = torch.as_tensor(prompt_ids).to(device)
    len_prompt = prompt_ids.shape[-1]

    # Only ids in [46358, 53526) remain selectable during generation.
    block_list = LogitsProcessorList([
        BlockTokenRangeProcessor(0, 46358),
        BlockTokenRangeProcessor(53526, mmtokenizer.vocab_size),
    ])

    for frames_idx in range(codec_ids.shape[1]):
        # Teacher-force the stage-1 token of this frame, then generate 7 more.
        cb0 = codec_ids[:, frames_idx:frames_idx + 1]
        prompt_ids = torch.cat([prompt_ids, cb0], dim=1)

        with torch.no_grad():
            stage2_output = model.generate(
                input_ids=prompt_ids,
                min_new_tokens=7,
                max_new_tokens=7,
                eos_token_id=mmtokenizer.eoa,
                pad_token_id=mmtokenizer.eoa,
                logits_processor=block_list,
            )

        new_tokens = stage2_output.shape[1] - prompt_ids.shape[1]
        if new_tokens != 7:
            raise RuntimeError(f"output new tokens={new_tokens}")
        prompt_ids = stage2_output

    if batch_size > 1:
        # Row-major flatten == concatenating the per-chunk rows in order.
        output = prompt_ids.cpu().numpy()[:, len_prompt:].reshape(-1)
    else:
        output = prompt_ids[0].cpu().numpy()[len_prompt:]
    return output


def stage2_inference(model, stage1_output_set, output_dir, batch_size=4):
    """Run stage 2 on every stage-1 file and save the results.

    Results are written to ``<output_dir>/stage2/`` under the stage-1 basename,
    so they never collide with the stage-1 files stored in ``output_dir``.

    Args:
        model: Stage-2 causal LM.
        stage1_output_set: Paths of stage-1 ``.npy`` files.
        output_dir: Run directory.
        batch_size: Maximum number of 300-frame chunks per generation call.

    Returns:
        List of stage-2 ``.npy`` paths, in the order of ``stage1_output_set``.
    """
    stage2_dir = os.path.join(output_dir, "stage2")
    os.makedirs(stage2_dir, exist_ok=True)

    stage2_result = []
    for stage1_path in tqdm(stage1_output_set):
        output_filename = os.path.join(stage2_dir, os.path.basename(stage1_path))
        if os.path.exists(output_filename):
            # Already produced; reuse it instead of regenerating.
            stage2_result.append(output_filename)
            continue

        prompt = np.load(stage1_path).astype(np.int32)

        # Whole multiples of 6 "seconds" at 50 frames each (300 frames) are
        # batched; the remainder is handled separately below.
        output_duration = prompt.shape[-1] // 50 // 6 * 6
        num_batch = output_duration // 6
        batched_frames = output_duration * 50

        if num_batch <= batch_size:
            output = stage2_generate(model, prompt[:, :batched_frames], batch_size=num_batch)
        else:
            segments = []
            for start_idx in range(0, batched_frames, batch_size * 300):
                end_idx = min(start_idx + batch_size * 300, batched_frames)
                # The final slice may hold fewer than batch_size chunks.
                current_batch_size = (end_idx - start_idx) // 300
                segments.append(
                    stage2_generate(model, prompt[:, start_idx:end_idx], batch_size=current_batch_size)
                )
            output = np.concatenate(segments, axis=0)

        if batched_frames != prompt.shape[-1]:
            ending = stage2_generate(model, prompt[:, batched_frames:], batch_size=1)
            output = np.concatenate([output, ending], axis=0)

        output = codectool_stage2.ids2npy(output)

        # Replace out-of-range codes with the most frequent value of their row.
        fixed_output = copy.deepcopy(output)
        for row_idx, row in enumerate(output):
            invalid_cols = [col for col, value in enumerate(row) if value < 0 or value > 1023]
            if not invalid_cols:
                continue
            most_frequent = Counter(row).most_common(1)[0][0]
            for col in invalid_cols:
                fixed_output[row_idx, col] = most_frequent

        np.save(output_filename, fixed_output)
        stage2_result.append(output_filename)

    return stage2_result


# Main Gradio function
@spaces.GPU()
def generate_music(genres, lyrics_text, use_audio_prompt, audio_prompt, start_time, end_time, progress=gr.Progress()):
    """Run the full pipeline for one request.

    Returns:
        ``[instrumental_mp3_path, vocal_mp3_path]`` -- in that order.

    Raises:
        gr.Error: If the inputs are rejected by ``stage1_generate``.
    """
    progress(0.1, "Running Stage 1 Generation...")
    try:
        stage1_output_set, output_dir = stage1_generate(
            genres, lyrics_text, use_audio_prompt, audio_prompt, start_time, end_time
        )
    except ValueError as err:
        # Show input problems in the UI instead of a generic failure.
        raise gr.Error(str(err)) from err

    progress(0.6, "Running Stage 2 Refinement...")
    stage2_result = stage2_inference(stage2_model, stage1_output_set, output_dir)

    progress(0.8, "Processing Audio...")
    vocal_decoder, inst_decoder = build_codec_model(
        './xcodec_mini_infer/decoders/config.yaml',
        './xcodec_mini_infer/decoders/decoder_131000.pth',
        './xcodec_mini_infer/decoders/decoder_151000.pth',
    )
    vocoder_output_dir = os.path.join(output_dir, "vocoder")
    os.makedirs(vocoder_output_dir, exist_ok=True)

    instrumental_mp3 = os.path.join(vocoder_output_dir, 'instrumental.mp3')
    vocal_mp3 = os.path.join(vocoder_output_dir, 'vocal.mp3')

    for npy in stage2_result:
        # Track type is encoded in the file name written by stage 1.
        if 'instrumental' in os.path.basename(npy):
            process_audio(npy, instrumental_mp3, False, None, inst_decoder, codec_model)
        else:
            process_audio(npy, vocal_mp3, False, None, vocal_decoder, codec_model)

    return [instrumental_mp3, vocal_mp3]


# Gradio UI
with gr.Blocks(title="AI Music Generation") as demo:
    gr.Markdown("# 🎵 AI Music Generation Pipeline")

    with gr.Row():
        with gr.Column():
            genre_input = gr.Textbox(label="Genre Tags", placeholder="e.g., Pop, Happy, Female Vocal")
            lyrics_input = gr.Textbox(label="Lyrics", lines=10, placeholder="Enter lyrics with segments...")
            use_audio_prompt = gr.Checkbox(label="Use Audio Prompt")
            audio_input = gr.Audio(label="Reference Audio", type="filepath", visible=False)
            start_time = gr.Number(label="Start Time (sec)", value=0.0, visible=False)
            end_time = gr.Number(label="End Time (sec)", value=30.0, visible=False)
            generate_btn = gr.Button("Generate Music", variant="primary")

        with gr.Column():
            vocal_output = gr.Audio(label="Vocal Track", interactive=False)
            inst_output = gr.Audio(label="Instrumental Track", interactive=False)

    # The audio-prompt controls are only shown while the checkbox is ticked.
    use_audio_prompt.change(
        lambda x: [gr.update(visible=x), gr.update(visible=x), gr.update(visible=x)],
        inputs=use_audio_prompt,
        outputs=[audio_input, start_time, end_time],
    )

    generate_btn.click(
        generate_music,
        inputs=[genre_input, lyrics_input, use_audio_prompt, audio_input, start_time, end_time],
        # Must follow generate_music's return order: instrumental first, then vocal.
        outputs=[inst_output, vocal_output],
    )

if __name__ == "__main__":
    demo.launch()