# app.py (kvikontent/suno-ai Space)
import base64
import os
from io import BytesIO
from typing import Tuple

import gradio as gr
import spaces
import torch
import torchaudio
from diffusers import DiffusionPipeline
from einops import rearrange
from huggingface_hub import HfApi, InferenceClient, hf_hub_download
from stable_audio_tools import get_pretrained_model
from stable_audio_tools.inference.generation import generate_diffusion_cond
# Authentication
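# The chat client targets Llama 3.1 through the serverless Inference API;
# the token is read from the Space secret named "api_key".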
client = InferenceClient("meta-llama/Meta-Llama-3.1-8B-Instruct", token=os.environ.get("api_key"))
# Load models
device = "cuda" if torch.cuda.is_available() else "cpu"
model, model_config = get_pretrained_model("stabilityai/stable-audio-open-1.0")
sample_rate = model_config["sample_rate"]
sample_size = model_config["sample_size"]
model = model.to(device)
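# For stable-audio-open-1.0 the config reports a 44.1 kHz sample rate;
# sample_size is the number of audio samples produced per generation pass.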
pipeline = DiffusionPipeline.from_pretrained("fluently/Fluently-XL-v2")
pipeline.load_lora_weights("ehristoforu/dalle-3-xl-v2")
pipeline = pipeline.to(device)
# --- Hugging Face Spaces Storage ---
api = HfApi()
repo_id = "kvikontent/suno-ai" # Replace with your Hugging Face repository ID
# --- Global Variables ---
generated_songs = {}
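# In-memory cache of songs for the current session; it resets when the Space
# restarts, so public songs are also persisted to the repo in make_public().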
# Function to generate audio (Requires GPU)
@spaces.GPU
def generate_audio(prompt: str) -> Tuple[bytes, bytes, str]:
    """Generates music, a cover image, and a name for the song."""
# --- Audio Generation ---
    # Conditioning and sampler settings follow the stable-audio-open-1.0 example.
    conditioning = [{
        "prompt": prompt,
        "seconds_start": 0,
        "seconds_total": 30,
    }]
    output = generate_diffusion_cond(
        model,
        steps=100,
        cfg_scale=7,
        conditioning=conditioning,
        sample_size=sample_size,
        device=device
    )
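    # Fold the batch dimension into time: (batch, channels, samples) -> (channels, batch * samples)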
output = rearrange(output, "b d n -> d (b n)")
# Peak normalize, clip, convert to int16, and save to file
output = output.to(torch.float32).div(torch.max(torch.abs(output))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
# Save audio to memory
    buffer = BytesIO()
    # When saving to a file-like object, torchaudio needs an explicit format.
    torchaudio.save(buffer, output, sample_rate, format="wav")
audio_data = buffer.getvalue()
# --- Image Generation ---
image = pipeline(prompt).images[0]
buffer = BytesIO()
image.save(buffer, format='png')
image_data = buffer.getvalue()
# --- Name Generation ---
    # Accumulate the streamed chunks; delta.content can be None on the final chunk.
    song_name = ""
    for message in client.chat_completion(
        messages=[{"role": "user", "content": "Name the song based on this prompt: " + prompt}],
        max_tokens=500,
        stream=True,
    ):
        song_name += message.choices[0].delta.content or ""
return audio_data, image_data, song_name
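# Hypothetical direct call outside the UI (requires a GPU session):
#   audio_bytes, image_bytes, name = generate_audio("128 BPM tech house drum loop")
#   with open("out.wav", "wb") as f:
#       f.write(audio_bytes)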
# Function to download generated audio and image
def download_audio_image(audio_data, image_data, song_name):
"""Downloads generated audio and image."""
audio_bytes = base64.b64encode(audio_data).decode('utf-8')
image_bytes = base64.b64encode(image_data).decode('utf-8')
audio_url = f"data:audio/wav;base64,{audio_bytes}"
image_url = f"data:image/png;base64,{image_bytes}"
return audio_url, image_url, song_name
# Function to make a song public
def make_public(songs, audio_data, image_data, song_name, user_id):
    """Marks a song public and persists it to the Space repository."""
    # Derive a simple id from the name (hypothetical scheme; any unique key works).
    song_id = song_name.replace(" ", "_") if song_name else str(len(songs))
    songs.setdefault(song_id, {}).update(
        {"audio": audio_data, "image": image_data, "name": song_name, "user": user_id, "public": True}
    )
    # Save the song data to the Space repo; upload_file takes the payload via
    # path_or_fileobj (a local path, raw bytes, or a file object).
    api.upload_file(
        path_or_fileobj=audio_data,
        path_in_repo=f"songs/{song_id}/audio.wav",
        repo_id=repo_id,
        repo_type="space",
    )
    api.upload_file(
        path_or_fileobj=image_data,
        path_in_repo=f"songs/{song_id}/image.png",
        repo_id=repo_id,
        repo_type="space",
    )
    # Save the song name as a text file
    api.upload_file(
        path_or_fileobj=song_name.encode("utf-8"),
        path_in_repo=f"songs/{song_id}/song_name.txt",
        repo_id=repo_id,
        repo_type="space",
    )
    return songs
# Function to fetch songs from Hugging Face Spaces
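# Expected repo layout, one folder per song:
#   songs/<song_id>/audio.wav
#   songs/<song_id>/image.png
#   songs/<song_id>/song_name.txt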
def fetch_songs(user_id=None):
    """Fetches songs stored in the Space repository."""
    songs = {}
    # list_repo_files returns plain path strings.
    files = api.list_repo_files(repo_id=repo_id, repo_type="space")
    for path in files:
        if not path.startswith("songs/"):
            continue
        song_id = path.split("/")[1]
        songs.setdefault(song_id, {})
        # hf_hub_download fetches the file and returns a local cache path.
        local_path = hf_hub_download(repo_id=repo_id, repo_type="space", filename=path)
        if path.endswith("audio.wav"):
            with open(local_path, "rb") as f:
                songs[song_id]["audio"] = f.read()
        elif path.endswith("image.png"):
            with open(local_path, "rb") as f:
                songs[song_id]["image"] = f.read()
        elif path.endswith("song_name.txt"):
            with open(local_path, "r") as f:
                songs[song_id]["name"] = f.read()
        # Extract the public/private status and user ID from the file name (if available)
        # ... (Implement logic here based on how you store this information)
    return songs
# --- User Interface ---
with gr.Blocks() as demo:
gr.Markdown("## Neon Synth Music Generator")
# Input area
prompt_input = gr.Textbox(label="Prompt", placeholder="e.g., 128 BPM tech house drum loop")
generate_button = gr.Button("Generate")
# Output area
generated_audio = gr.Audio(label="Generated Audio", playable=True, source="upload")
generated_image = gr.Image(label="Generated Image")
song_name = gr.Textbox(label="Song Name")
make_public_button = gr.Button("Make Public")
# User authentication
login_button = gr.Button("Login")
logout_button = gr.Button("Logout", visible=False)
user_name = gr.Textbox(label="Username", interactive=False, visible=False)
# Feed area
public_feed = gr.Gallery(label="Public Feed", show_label=False, elem_id="public-feed")
user_feed = gr.Gallery(label="Your Feed", show_label=False, elem_id="user-feed")
# --- Event Handlers ---
generate_button.click(fn=generate_audio, inputs=prompt_input, outputs=[generated_audio, generated_image, song_name])
    make_public_button.click(fn=make_public, inputs=[songs_state, generated_audio, generated_image, song_name, user_name], outputs=[songs_state])
    def do_login():
        # Placeholder login: fill in a fixed username and toggle button visibility.
        return gr.update(value="YourUsername", visible=True), gr.update(visible=False), gr.update(visible=True)
    def do_logout():
        return gr.update(value="", visible=False), gr.update(visible=True), gr.update(visible=False)
    login_button.click(fn=do_login, inputs=[], outputs=[user_name, login_button, logout_button])
    logout_button.click(fn=do_logout, inputs=[], outputs=[user_name, login_button, logout_button])
# --- Update the feed ---
    # Record each newly generated song in the shared state.
    def record_song(songs, audio_data, image_data, name, user):
        song_id = name.replace(" ", "_") if name else str(len(songs))
        songs[song_id] = {"audio": audio_data, "image": image_data, "name": name, "public": False, "user": user}
        return songs
    generated_audio.change(
        fn=record_song,
        inputs=[songs_state, generated_audio, generated_image, song_name, user_name],
        outputs=[songs_state],
    )
    # Refresh the feeds when a new song is added. Galleries take
    # (image, caption) pairs; data: URLs are used so no files are written
    # (assumes the Gallery accepts URL strings as image sources).
    def refresh_feeds(songs, user):
        def entry(s):
            _, image_url, name = download_audio_image(s["audio"], s["image"], s["name"])
            return (image_url, name)
        public = [entry(s) for s in songs.values() if s.get("public")]
        private = [entry(s) for s in songs.values() if not s.get("public") and s.get("user") == user]
        return public, private
    # Note: gr.State change events require a recent Gradio release.
    songs_state.change(
        fn=refresh_feeds,
        inputs=[songs_state, user_name],
        outputs=[public_feed, user_feed],
    )
# Fetch and display the feeds
    demo.load(
        fn=lambda user: refresh_feeds(fetch_songs(), user),
        inputs=[user_name],
        outputs=[public_feed, user_feed],
    )
# --- Layout ---
    with gr.Row():
        with gr.Column():
            prompt_input.render()
            generate_button.render()
            login_button.render()
            logout_button.render()
            user_name.render()
        with gr.Column():
            generated_audio.render()
            generated_image.render()
            song_name.render()
            make_public_button.render()
    with gr.Row():
        with gr.Column():
            public_feed.render()
        with gr.Column():
            user_feed.render()
# Run the Gradio interface
demo.launch()