"""Gradio Space: text-to-music demo driving Muse-0.6b through MuCodec/AudioLDM."""

import os

# Cache locations must be exported BEFORE importing transformers /
# huggingface_hub: both libraries read HF_HOME / TRANSFORMERS_CACHE at
# import time, so setting them after the imports (as the original code did)
# has no effect on where checkpoints are cached.
os.environ["HF_HOME"] = "/tmp/hf"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf/transformers"

import sys
import tempfile

import gradio as gr
import numpy as np
import torch
import torchaudio
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, AutoModelForCausalLM

# LM that emits audio codec tokens as text.
MODEL_ID = "bolshyC/Muse-0.6b"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SAMPLE_RATE = 48000  # output rate of the AudioLDM-48k vocoder
|
# ---------------------------------------------------------------------------
# Language model: Muse-0.6b (generates audio codec tokens as text)
# ---------------------------------------------------------------------------
print("Loading Muse-0.6b...")

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

# Half precision with automatic device placement on GPU; full fp32 on CPU.
use_cuda = DEVICE == "cuda"
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.float16 if use_cuda else torch.float32,
    device_map="auto" if use_cuda else None,
)
model.eval()
print("Muse loaded.")
# ---------------------------------------------------------------------------
# Audio decoder stack: MuCodec (codes -> latents) + AudioLDM VAE / vocoder
# ---------------------------------------------------------------------------
print("Loading MuCodec / AudioLDM...")

sys.path.insert(0, "./MuCodec")
from MuCodec.model import PromptCondAudioDiffusion
from MuCodec.tools.get_melvaehifigan48k import build_pretrained_models

# Fetch the AudioLDM-48k checkpoint backing the mel VAE + HiFi-GAN vocoder.
audioldm_dir = snapshot_download(
    "haoheliu/audioldm_48k",
    local_dir="/tmp/audioldm",
    local_dir_use_symlinks=False,
)
audioldm_path = os.path.join(audioldm_dir, "audioldm_48k.pth")

vae, stft = build_pretrained_models(audioldm_path)
vae = vae.to(DEVICE).eval()

# Diffusion model mapping discrete audio codes back to VAE latents.
mucodec = PromptCondAudioDiffusion(
    num_channels=32,
    unet_model_name=None,
    unet_model_config_path="./MuCodec/configs/models/transformer2D.json",
    snr_gamma=None,
)

# strict=False: checkpoint may carry extra/renamed keys — TODO confirm this
# is intentional and not masking a missing-weight problem.
state = torch.load("./MuCodec/ckpt/mucodec.pt", map_location="cpu")
mucodec.load_state_dict(state, strict=False)
mucodec = mucodec.to(DEVICE).eval()

print("MuCodec loaded.")
def extract_audio_tokens(text: str):
    """Extract the integer audio codes from decoded LM output.

    The model wraps its audio token stream between the literal markers
    ``<|audio_0|>`` (start) and ``<|audio_1|>`` (end).

    Args:
        text: Full decoded generation, special tokens included.

    Returns:
        A long tensor of shape ``(1, 1, n_tokens)``, or ``None`` when the
        start marker is missing or no digit tokens are found.
    """
    start_marker = "<|audio_0|>"
    end_marker = "<|audio_1|>"
    if start_marker not in text:
        return None
    start = text.find(start_marker) + len(start_marker)
    # Search only after the start marker.  If the model never emitted the
    # end marker, take everything to the end of the string.  (The original
    # code used find()'s -1 sentinel directly as a slice bound, which
    # silently dropped the last character — and potentially the last token.)
    end = text.find(end_marker, start)
    if end == -1:
        end = len(text)
    tokens = [int(tok) for tok in text[start:end].split() if tok.isdigit()]
    if not tokens:
        return None
    return torch.tensor(tokens).unsqueeze(0).unsqueeze(0)
def generate(prompt):
    """Run the full text -> audio-token -> waveform pipeline.

    Args:
        prompt: User text from the Gradio textbox.  Gradio may pass ``None``
            as well as ``""`` for an empty box, so both are guarded.

    Returns:
        ``(wav_path, status)`` — path to a 48 kHz WAV file on success
        (``None`` on failure) plus a human-readable status string.
    """
    # Guard None explicitly: the original `prompt.strip()` would raise
    # AttributeError when Gradio delivers None for an empty textbox.
    if not prompt or not prompt.strip():
        return None, "Empty prompt"

    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)

    # Greedy decode: the LM emits audio codec tokens as plain text.
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=1024,
            do_sample=False,
        )

    text = tokenizer.decode(out[0], skip_special_tokens=False)
    codes = extract_audio_tokens(text)
    if codes is None:
        return None, "Failed to parse audio tokens"

    codes = codes.to(DEVICE)

    # Decode codes -> VAE latents via MuCodec diffusion, then latents ->
    # mel -> waveform via the AudioLDM VAE/vocoder.  Run under no_grad:
    # the original built autograd graphs here for no reason, wasting memory.
    # 1024 codes / 512 latent frames appear to be the model's fixed segment
    # size — TODO confirm against the MuCodec inference API.
    with torch.no_grad():
        latents = mucodec.inference_codes(
            [codes[:, :, :1024]],
            torch.zeros([1, 32, 1, 32], device=DEVICE),
            torch.randn(1, 32, 512, 32, device=DEVICE),
            latent_length=512,
            first_latent_length=0,
            additional_feats=[],
            guidance_scale=1.0,
            num_steps=10,
            disable_progress=True,
            scenario="other_seg",
        )
        mel = vae.decode_first_stage(latents.float())
        wav = vae.decode_to_waveform(mel)  # presumably a numpy array — see torch.from_numpy below

    # NamedTemporaryFile only reserves a unique path; torchaudio writes the
    # actual file afterwards.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
        out_path = f.name

    torchaudio.save(out_path, torch.from_numpy(wav), SAMPLE_RATE)
    return out_path, "Done"
# ---------------------------------------------------------------------------
# Gradio front-end
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Muse-0.6b (Experimental HF Space)")
    prompt_box = gr.Textbox(label="Prompt")
    generate_btn = gr.Button("Generate")
    audio_out = gr.Audio(type="filepath")
    status_box = gr.Textbox()

    # Wire the button to the synthesis pipeline defined above.
    generate_btn.click(
        fn=generate,
        inputs=prompt_box,
        outputs=[audio_out, status_box],
    )

demo.launch()