"""Gradio app that turns a text prompt into a song, cover art and a title.

Audio comes from Stable Audio Open, the cover image from a diffusers
pipeline, and the title from a chat model served through the Hugging Face
Inference API.  Songs that are made public are uploaded to a Hugging Face
repository so they can be shown in a shared feed.
"""
import base64
import os
import tempfile
import uuid
from io import BytesIO
from pathlib import Path
from typing import Dict, List, Optional, Tuple

try:
    # NOTE(review): ZeroGPU Spaces expect ``spaces`` to be imported before
    # torch initialises CUDA -- confirm before reordering these imports.
    import spaces
except ImportError:  # running outside Hugging Face Spaces
    spaces = None

import gradio as gr
import torch
import torchaudio
from diffusers import DiffusionPipeline
from einops import rearrange
from huggingface_hub import HfApi, InferenceClient, hf_hub_download, hf_hub_url
from stable_audio_tools import get_pretrained_model
from stable_audio_tools.inference.generation import generate_diffusion_cond

try:
    # Kept for backward compatibility; recent huggingface_hub releases no
    # longer ship this helper, and nothing in this file calls it.
    from huggingface_hub import cached_download
except ImportError:
    cached_download = None

# Authentication
client = InferenceClient("meta-llama/Meta-Llama-3.1-8B-Instruct", token=os.environ.get("api_key"))

# Load models
device = "cuda" if torch.cuda.is_available() else "cpu"
model, model_config = get_pretrained_model("stabilityai/stable-audio-open-1.0")
sample_rate = model_config["sample_rate"]
sample_size = model_config["sample_size"]
model = model.to(device)

pipeline = DiffusionPipeline.from_pretrained("fluently/Fluently-XL-v2")
pipeline.load_lora_weights("ehristoforu/dalle-3-xl-v2")
# Keep the image pipeline on the same device as the audio model.
pipeline = pipeline.to(device)

# --- Hugging Face Spaces Storage ---
api = HfApi()
repo_id = "kvikontent/suno-ai"  # Replace with your Hugging Face repository ID

# --- Global Variables ---
# song_id -> {"audio": bytes, "image": bytes, "name": str, "public": bool, "user": str}
generated_songs = {}

# File name inside ``songs/<song_id>/`` -> key used in a song dictionary.
_SONG_FILES: Dict[str, str] = {
    "audio.wav": "audio",
    "image.png": "image",
    "song_name.txt": "name",
    "user.txt": "user",
}

# song_id -> path of the cover image written to disk for the galleries.
_cover_paths: Dict[str, str] = {}


def _gpu(fn):
    """Wrap *fn* with ``spaces.GPU`` when the ``spaces`` package is available."""
    return spaces.GPU(fn) if spaces is not None else fn


# Function to generate audio (Requires GPU)
@_gpu
def generate_audio(prompt: str) -> Tuple[bytes, bytes, str]:
    """Generate music, a cover image and a song name for *prompt*.

    Args:
        prompt: Free-text description of the desired song.

    Returns:
        A ``(audio_data, image_data, song_name)`` tuple: WAV bytes, PNG bytes
        and the title suggested by the chat model.
    """
    # --- Audio Generation ---
    # The Stable Audio Open conditioner expects timing keys next to the
    # prompt; request a clip that fills the model's whole sample window.
    conditioning = [{
        "prompt": prompt,
        "seconds_start": 0,
        "seconds_total": sample_size // sample_rate,
    }]
    output = generate_diffusion_cond(
        model,
        conditioning=conditioning,
        sample_size=sample_size,
        device=device,
    )
    output = rearrange(output, "b d n -> d (b n)")

    # Peak normalize, clip and convert to int16.  The peak is floored so an
    # all-zero (silent) result does not divide by zero and produce NaNs.
    output = output.to(torch.float32)
    peak = output.abs().max().clamp(min=1e-8)
    output = output.div(peak).clamp(-1, 1).mul(32767).to(torch.int16).cpu()

    # Save audio to memory; a file-like target needs an explicit format.
    buffer = BytesIO()
    torchaudio.save(buffer, output, sample_rate, format="wav")
    audio_data = buffer.getvalue()

    # --- Image Generation ---
    image = pipeline(prompt).images[0]
    buffer = BytesIO()
    image.save(buffer, format="png")
    image_data = buffer.getvalue()

    # --- Name Generation ---
    # Each streamed chunk carries only a fragment of the answer, so the
    # fragments are concatenated instead of keeping just the last one.
    stream = client.chat_completion(
        messages=[{"role": "user", "content": "Name the song based on this prompt: " + prompt}],
        max_tokens=500,
        stream=True,
    )
    song_name = "".join(
        chunk.choices[0].delta.content or "" for chunk in stream
    ).strip()

    return audio_data, image_data, song_name


# Function to download generated audio and image
def download_audio_image(audio_data, image_data, song_name):
    """Encode audio and image bytes as base64 ``data:`` URLs.

    Args:
        audio_data: WAV file contents as bytes.
        image_data: PNG file contents as bytes.
        song_name: Title of the song; returned unchanged.

    Returns:
        A ``(audio_url, image_url, song_name)`` tuple.
    """
    audio_bytes = base64.b64encode(audio_data).decode("utf-8")
    image_bytes = base64.b64encode(image_data).decode("utf-8")
    audio_url = f"data:audio/wav;base64,{audio_bytes}"
    image_url = f"data:image/png;base64,{image_bytes}"
    return audio_url, image_url, song_name


# Function to make a song public
def make_public(song_id, audio_data, image_data, song_name, user_id):
    """Upload a song to the repository and mark it as public.

    Args:
        song_id: Identifier used as the folder name under ``songs/``.
        audio_data: WAV file contents as bytes.
        image_data: PNG file contents as bytes.
        song_name: Title of the song.
        user_id: Name of the user publishing the song (may be empty).

    Returns:
        The module-level ``generated_songs`` dictionary.
    """
    uploads = {
        "audio.wav": audio_data,
        "image.png": image_data,
        "song_name.txt": song_name.encode("utf-8"),
    }
    if user_id:
        uploads["user.txt"] = str(user_id).encode("utf-8")

    # NOTE(review): every upload is a commit to the Space repository, which
    # presumably triggers a Space rebuild -- a dataset repo may be a better
    # store; confirm before relying on this in production.
    for filename, payload in uploads.items():
        api.upload_file(
            path_or_fileobj=payload,
            path_in_repo=f"songs/{song_id}/{filename}",
            repo_id=repo_id,
            repo_type="space",
        )

    # Flip the flag only after the uploads succeeded, and register the song
    # if the caller never added it to ``generated_songs``.
    song = generated_songs.setdefault(song_id, {})
    song.update(
        audio=audio_data,
        image=image_data,
        name=song_name,
        user=user_id,
        public=True,
    )
    return generated_songs


# Function to fetch songs from Hugging Face Spaces
def fetch_songs(user_id=None):
    """Download the songs stored under ``songs/`` in the repository.

    Args:
        user_id: When given, only songs whose stored user matches are
            returned.  ``None`` returns every song.

    Returns:
        A dictionary mapping song id to a song dictionary.  Only published
        songs are ever uploaded, so each entry has ``"public": True``.
    """
    songs = {}
    for path in api.list_repo_files(repo_id=repo_id, repo_type="space"):
        parts = path.split("/")
        if len(parts) != 3 or parts[0] != "songs":
            continue
        _, song_id, filename = parts
        field = _SONG_FILES.get(filename)
        if field is None:
            continue

        local_path = hf_hub_download(
            repo_id=repo_id,
            filename=path,
            repo_type="space",
            revision="main",
        )
        payload = Path(local_path).read_bytes()
        song = songs.setdefault(song_id, {"public": True, "user": None})
        song[field] = payload.decode("utf-8") if filename.endswith(".txt") else payload

    if user_id is not None:
        songs = {sid: song for sid, song in songs.items() if song["user"] == user_id}
    return songs


# --- UI helpers ---
def _temp_file(data: bytes, suffix: str) -> str:
    """Write *data* to a new temporary file and return its path."""
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as handle:
        handle.write(data)
    return handle.name


def _cover_path(song_id: str, song: Dict[str, object]) -> str:
    """Return the on-disk cover image for a song, writing it once."""
    if song_id not in _cover_paths:
        _cover_paths[song_id] = _temp_file(song["image"], ".png")
    return _cover_paths[song_id]


def _build_feeds(user: Optional[str]) -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]]]:
    """Build the ``(public_feed, user_feed)`` gallery values for *user*."""
    user = user or ""
    public_items: List[Tuple[str, str]] = []
    user_items: List[Tuple[str, str]] = []
    # Iterate over a snapshot: other sessions may add songs concurrently.
    for song_id, song in list(generated_songs.items()):
        if not song.get("image"):
            continue  # incomplete entry, nothing to display
        item = (_cover_path(song_id, song), song.get("name") or "")
        if song.get("public"):
            public_items.append(item)
        elif (song.get("user") or "") == user:
            user_items.append(item)
    return public_items, user_items


def _on_generate(prompt, user):
    """Generate a song, register it as private and refresh the outputs."""
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt.")
    audio_data, image_data, name = generate_audio(prompt)

    song_id = uuid.uuid4().hex
    song = {
        "audio": audio_data,
        "image": image_data,
        "name": name,
        "public": False,
        "user": user or "",
    }
    generated_songs[song_id] = song

    # The Audio/Image components are fed file paths rather than raw bytes.
    return (
        _temp_file(audio_data, ".wav"),
        _cover_path(song_id, song),
        name,
        song_id,
    ) + _build_feeds(user)


def _on_make_public(song_id, name, user):
    """Publish the most recently generated song and refresh the feeds."""
    song = generated_songs.get(song_id)
    if song is None:
        raise gr.Error("Generate a song before making it public.")
    make_public(song_id, song["audio"], song["image"], name, user or "")
    return _build_feeds(user)


def _on_login():
    """Placeholder login: show the user name and swap the buttons."""
    user = "YourUsername"
    return (
        gr.update(visible=False),
        gr.update(visible=True),
        gr.update(value=user, visible=True),
    ) + _build_feeds(user)


def _on_logout():
    """Clear the user name and restore the login button."""
    return (
        gr.update(visible=True),
        gr.update(visible=False),
        gr.update(value="", visible=False),
    ) + _build_feeds("")


def _on_load():
    """Merge the published songs from the repository into the feeds."""
    for song_id, song in fetch_songs().items():
        generated_songs.setdefault(song_id, song)
    return _build_feeds("")


# --- User Interface ---
with gr.Blocks() as demo:
    gr.Markdown("## Neon Synth Music Generator")

    # Id of the song currently shown in the output area (per session).
    current_song_id = gr.State(None)

    # --- Layout ---
    # Components are created inside their Row/Column so they render there.
    with gr.Row():
        with gr.Column():
            # Input area
            prompt_input = gr.Textbox(label="Prompt", placeholder="e.g., 128 BPM tech house drum loop")
            generate_button = gr.Button("Generate")

            # User authentication
            login_button = gr.Button("Login")
            logout_button = gr.Button("Logout", visible=False)
            user_name = gr.Textbox(label="Username", interactive=False, visible=False)
        with gr.Column():
            # Output area
            generated_audio = gr.Audio(label="Generated Audio")
            generated_image = gr.Image(label="Generated Image")
            song_name = gr.Textbox(label="Song Name")
            make_public_button = gr.Button("Make Public")

    # Feed area
    with gr.Row():
        with gr.Column():
            public_feed = gr.Gallery(label="Public Feed", show_label=False, elem_id="public-feed")
        with gr.Column():
            user_feed = gr.Gallery(label="Your Feed", show_label=False, elem_id="user-feed")

    # --- Event Handlers ---
    generate_button.click(
        fn=_on_generate,
        inputs=[prompt_input, user_name],
        outputs=[generated_audio, generated_image, song_name, current_song_id, public_feed, user_feed],
    )
    make_public_button.click(
        fn=_on_make_public,
        inputs=[current_song_id, song_name, user_name],
        outputs=[public_feed, user_feed],
    )
    login_button.click(
        fn=_on_login,
        inputs=None,
        outputs=[login_button, logout_button, user_name, public_feed, user_feed],
    )
    logout_button.click(
        fn=_on_logout,
        inputs=None,
        outputs=[login_button, logout_button, user_name, public_feed, user_feed],
    )

    # Fetch and display the feeds
    demo.load(fn=_on_load, inputs=None, outputs=[public_feed, user_feed])

# Run the Gradio interface
if __name__ == "__main__":
    demo.launch()