import os import torch import gradio as gr from einops import rearrange, repeat from diffusers import AutoencoderKL from transformers import SpeechT5HifiGan from scipy.io import wavfile import glob import random import numpy as np import re import requests import time import gc # Import necessary functions and classes from utils import load_t5, load_clap from train import RF from constants import build_model # Global variables to store loaded models and resources global_model = None global_t5 = None global_clap = None global_vae = None global_vocoder = None global_diffusion = None current_model_name = None # Set the models directory MODELS_DIR = os.path.join(os.path.dirname(__file__), "models") GENERATIONS_DIR = os.path.join(os.path.dirname(__file__), "generations") def prepare(t5, clip, img, prompt): bs, c, h, w = img.shape if bs == 1 and not isinstance(prompt, str): bs = len(prompt) img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) if img.shape[0] == 1 and bs > 1: img = repeat(img, "1 ... -> bs ...", bs=bs) img_ids = torch.zeros(h // 2, w // 2, 3) img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) if isinstance(prompt, str): prompt = [prompt] # Generate text embeddings txt = t5(prompt) if txt.shape[0] == 1 and bs > 1: txt = repeat(txt, "1 ... -> bs ...", bs=bs) txt_ids = torch.zeros(bs, txt.shape[1], 3) vec = clip(prompt) if vec.shape[0] == 1 and bs > 1: vec = repeat(vec, "1 ... -> bs ...", bs=bs) return img, { "img_ids": img_ids.to(img.device), "txt": txt.to(img.device), "txt_ids": txt_ids.to(img.device), "y": vec.to(img.device), } def unload_current_model(): global global_model, current_model_name if global_model is not None: del global_model global_model = None current_model_name = None torch.cuda.empty_cache() gc.collect() def load_model(model_name, device, model_url=None): global global_model, current_model_name unload_current_model() if model_url: print(f"Downloading model from URL: {model_url}") response = requests.get(model_url) if response.status_code == 200: model_path = os.path.join(MODELS_DIR, "downloaded_model.pt") with open(model_path, 'wb') as f: f.write(response.content) model_name = "downloaded_model.pt" else: return f"Failed to download model from URL: {model_url}" else: model_path = os.path.join(MODELS_DIR, model_name) if not os.path.exists(model_path): return f"Model file not found: {model_path}" # Determine model size from filename if 'musicflow_b' in model_name: model_size = "base" elif 'musicflow_g' in model_name: model_size = "giant" elif 'musicflow_l' in model_name: model_size = "large" elif 'musicflow_s' in model_name: model_size = "small" else: model_size = "base" # Default to base if unrecognized print(f"Loading {model_size} model: {model_name}") try: start_time = time.time() global_model = build_model(model_size).to(device) state_dict = torch.load(model_path, map_location=device, weights_only=True) global_model.load_state_dict(state_dict['ema'], strict=False) global_model.eval() global_model.model_path = model_path current_model_name = model_name end_time = time.time() load_time = end_time - start_time return f"Successfully loaded model: {model_name} in {load_time:.2f} seconds" except Exception as e: unload_current_model() print(f"Error loading model {model_name}: {str(e)}") return f"Failed to load model: {model_name}. Error: {str(e)}" def load_resources(device): global global_t5, global_clap, global_vae, global_vocoder, global_diffusion try: start_time = time.time() print("Loading T5 and CLAP models...") global_t5 = load_t5(device, max_length=256) global_clap = load_clap(device, max_length=256) print("Loading VAE and vocoder...") global_vae = AutoencoderKL.from_pretrained('cvssp/audioldm2', subfolder="vae").to(device) global_vocoder = SpeechT5HifiGan.from_pretrained('cvssp/audioldm2', subfolder="vocoder").to(device) print("Initializing diffusion...") global_diffusion = RF() end_time = time.time() load_time = end_time - start_time print(f"Base resources loaded successfully in {load_time:.2f} seconds!") return f"Resources loaded successfully in {load_time:.2f} seconds!" except Exception as e: print(f"Error loading resources: {str(e)}") return f"Failed to load resources. Error: {str(e)}" def generate_music(prompt, seed, cfg_scale, steps, duration, device, batch_size=1, progress=gr.Progress()): global global_model, global_t5, global_clap, global_vae, global_vocoder, global_diffusion if global_model is None: return "Please select and load a model first.", None if global_t5 is None or global_clap is None or global_vae is None or global_vocoder is None or global_diffusion is None: return "Resources not properly loaded. Please reload the page and try again.", None if seed == 0: seed = random.randint(1, 1000000) print(f"Using seed: {seed}") torch.manual_seed(seed) torch.set_grad_enabled(False) # Ensure we're using CPU if CUDA is not available if device == "cuda" and not torch.cuda.is_available(): print("CUDA is not available. Falling back to CPU.") device = "cpu" # Calculate the number of segments needed for the desired duration segment_duration = 10 # Each segment is 10 seconds num_segments = int(np.ceil(duration / segment_duration)) all_waveforms = [] for i in range(num_segments): progress(i / num_segments, desc=f"Generating segment {i+1}/{num_segments}") # Use the same seed for all segments torch.manual_seed(seed + i) # Add i to slightly vary each segment while maintaining consistency latent_size = (256, 16) conds_txt = [prompt] unconds_txt = ["low quality, gentle"] L = len(conds_txt) init_noise = torch.randn(L, 8, latent_size[0], latent_size[1]).to(device) img, conds = prepare(global_t5, global_clap, init_noise, conds_txt) _, unconds = prepare(global_t5, global_clap, init_noise, unconds_txt) # Implement batching for inference images = [] for batch_start in range(0, img.shape[0], batch_size): batch_end = min(batch_start + batch_size, img.shape[0]) batch_img = img[batch_start:batch_end] batch_conds = {k: v[batch_start:batch_end] for k, v in conds.items()} batch_unconds = {k: v[batch_start:batch_end] for k, v in unconds.items()} with torch.no_grad(): batch_images = global_diffusion.sample_with_xps( global_model, batch_img, conds=batch_conds, null_cond=batch_unconds, sample_steps=steps, cfg=cfg_scale ) images.append(batch_images[-1]) images = torch.cat(images, dim=0) images = rearrange( images, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=128, w=8, ph=2, pw=2,) latents = 1 / global_vae.config.scaling_factor * images mel_spectrogram = global_vae.decode(latents).sample x_i = mel_spectrogram[0] if x_i.dim() == 4: x_i = x_i.squeeze(1) waveform = global_vocoder(x_i) waveform = waveform[0].cpu().float().detach().numpy() all_waveforms.append(waveform) # Clear some memory after each segment del images, latents, mel_spectrogram, x_i torch.cuda.empty_cache() gc.collect() # Concatenate all waveforms final_waveform = np.concatenate(all_waveforms) # Trim to exact duration sample_rate = 16000 final_waveform = final_waveform[:int(duration * sample_rate)] progress(0.9, desc="Saving audio file") # Create 'generations' folder os.makedirs(GENERATIONS_DIR, exist_ok=True) # Generate filename prompt_part = re.sub(r'[^\w\s-]', '', prompt)[:10].strip().replace(' ', '_') model_name = os.path.splitext(os.path.basename(global_model.model_path))[0] model_suffix = '_mf_b' if model_name == 'musicflow_b' else f'_{model_name}' base_filename = f"{prompt_part}_{seed}{model_suffix}" output_path = os.path.join(GENERATIONS_DIR, f"{base_filename}.wav") # Check if file exists and add numerical suffix if needed counter = 1 while os.path.exists(output_path): output_path = os.path.join(GENERATIONS_DIR, f"{base_filename}_{counter}.wav") counter += 1 wavfile.write(output_path, sample_rate, final_waveform) progress(1.0, desc="Audio generation complete") return f"Generated with seed: {seed}", output_path # Get list of .pt files in the models directory model_files = glob.glob(os.path.join(MODELS_DIR, "*.pt")) model_choices = [os.path.basename(f) for f in model_files] # Ensure we have at least one model if not model_choices: print(f"No models found in the models directory: {MODELS_DIR}") print("Available files in the directory:") print(os.listdir(MODELS_DIR)) model_choices = ["No models available"] # Set default model default_model = 'musicflow_b.pt' if 'musicflow_b.pt' in model_choices else model_choices[0] # Set up dark grey theme theme = gr.themes.Monochrome( primary_hue="gray", secondary_hue="gray", neutral_hue="gray", radius_size=gr.themes.sizes.radius_sm, ) # Gradio Interface with gr.Blocks(theme=theme) as iface: gr.Markdown( """
Generate music based on text prompts using FluxMusic model.
Feel free to clone this space and run on GPU locally or on Hugging Face.