import os
import base64
from io import BytesIO  # needed for the in-memory audio/image buffers below
from typing import Dict, Tuple

import gradio as gr
import torch
import torchaudio
from einops import rearrange
from diffusers import DiffusionPipeline
from huggingface_hub import HfApi, InferenceClient, hf_hub_download
from stable_audio_tools import get_pretrained_model
from stable_audio_tools.inference.generation import generate_diffusion_cond

# Authentication
client = InferenceClient("meta-llama/Meta-Llama-3.1-8B-Instruct", token=os.environ.get("api_key"))
# Load models
device = "cuda" if torch.cuda.is_available() else "cpu"
model, model_config = get_pretrained_model("stabilityai/stable-audio-open-1.0")
sample_rate = model_config["sample_rate"]
sample_size = model_config["sample_size"]
model = model.to(device)

pipeline = DiffusionPipeline.from_pretrained("fluently/Fluently-XL-v2")
pipeline.load_lora_weights("ehristoforu/dalle-3-xl-v2")
pipeline = pipeline.to(device)  # keep the image pipeline on the same device as the audio model
# --- Hugging Face Spaces Storage ---
api = HfApi()
repo_id = "kvikontent/suno-ai"  # Replace with your Hugging Face repository ID

# --- Global Variables ---
generated_songs = {}
# Function to generate audio (Requires GPU)
def generate_audio(prompt: str) -> Tuple[bytes, bytes, str]:
    """Generates music, a cover image, and a name for the song."""
    # --- Audio Generation ---
    # stable-audio-open conditioning also expects clip-timing keys; 30 s is an arbitrary default here
    conditioning = [{
        "prompt": prompt,
        "seconds_start": 0,
        "seconds_total": 30,
    }]
    output = generate_diffusion_cond(
        model,
        steps=100,
        cfg_scale=7,
        conditioning=conditioning,
        sample_size=sample_size,
        device=device,
    )
    output = rearrange(output, "b d n -> d (b n)")

    # Peak normalize, clip, convert to int16, and keep the result in memory
    output = output.to(torch.float32).div(torch.max(torch.abs(output))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()

    # Save audio to an in-memory WAV buffer (format is required when saving to a file-like object)
    buffer = BytesIO()
    torchaudio.save(buffer, output, sample_rate, format="wav")
    audio_data = buffer.getvalue()

    # --- Image Generation ---
    image = pipeline(prompt).images[0]
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    image_data = buffer.getvalue()

    # --- Name Generation ---
    # Accumulate the streamed completion instead of keeping only the last chunk
    song_name = ""
    for message in client.chat_completion(
        messages=[{"role": "user", "content": "Name the song based on this prompt: " + prompt}],
        max_tokens=500,
        stream=True,
    ):
        song_name += message.choices[0].delta.content or ""
    return audio_data, image_data, song_name
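# gr.Audio renders a filepath or a (sample_rate, numpy array) tuple more reliably than raw
# WAV bytes. A minimal adapter sketch (a hypothetical helper, not part of the original wiring),
# in case the bytes returned above need to be fed to the Audio component directly:
def audio_bytes_to_tmpfile(audio_data: bytes) -> str:
    """Hypothetical helper: write WAV bytes to a temporary file so gr.Audio can play them."""
    import tempfile
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    tmp.write(audio_data)
    tmp.close()
    return tmp.name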
# Function to encode generated audio and image for display/download
def download_audio_image(audio_data, image_data, song_name):
    """Encodes the generated audio and image as base64 data URLs."""
    audio_bytes = base64.b64encode(audio_data).decode("utf-8")
    image_bytes = base64.b64encode(image_data).decode("utf-8")
    audio_url = f"data:audio/wav;base64,{audio_bytes}"
    image_url = f"data:image/png;base64,{image_bytes}"
    return audio_url, image_url, song_name
# Function to make a song public
def make_public(song_id, audio_data, image_data, song_name, user_id):
    """Makes a song public and persists it to the Space repository."""
    generated_songs.setdefault(song_id, {})["public"] = True
    # Save the song data to the Hugging Face Space
    # (HfApi.upload_file takes path_or_fileobj, which accepts raw bytes directly)
    api.upload_file(
        path_or_fileobj=audio_data,
        path_in_repo=f"songs/{song_id}/audio.wav",
        repo_id=repo_id,
        repo_type="space",
    )
    api.upload_file(
        path_or_fileobj=image_data,
        path_in_repo=f"songs/{song_id}/image.png",
        repo_id=repo_id,
        repo_type="space",
    )
    # Save the song name as a text file
    with open("song_name.txt", "w") as f:
        f.write(song_name)
    api.upload_file(
        path_or_fileobj="song_name.txt",
        path_in_repo=f"songs/{song_id}/song_name.txt",
        repo_id=repo_id,
        repo_type="space",
    )
    return generated_songs
# Function to fetch songs from Hugging Face Spaces
def fetch_songs(user_id=None):
    """Fetches songs stored in the Space repository."""
    songs = {}
    # list_repo_files returns plain path strings, e.g. "songs/<id>/audio.wav"
    files = api.list_repo_files(repo_id=repo_id, repo_type="space")
    for file in files:
        if not file.startswith("songs/"):
            continue
        song_id = file.split("/")[1]
        if song_id not in songs:
            songs[song_id] = {}
        # hf_hub_download returns a local cache path to the requested file
        if file.endswith("audio.wav"):
            local_path = hf_hub_download(repo_id=repo_id, repo_type="space", filename=file)
            with open(local_path, "rb") as f:
                songs[song_id]["audio"] = f.read()
        elif file.endswith("image.png"):
            local_path = hf_hub_download(repo_id=repo_id, repo_type="space", filename=file)
            with open(local_path, "rb") as f:
                songs[song_id]["image"] = f.read()
        elif file.endswith("song_name.txt"):
            local_path = hf_hub_download(repo_id=repo_id, repo_type="space", filename=file)
            with open(local_path, "r") as f:
                songs[song_id]["name"] = f.read()
        # Extract the public/private status and user ID from the file name (if available)
        # ... (Implement logic here based on how you store this information; see the
        #      hypothetical metadata helpers sketched after this function)
    return songs
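# --- Hypothetical metadata helpers (sketch) ---
# The TODO in fetch_songs needs the public/private flag and owner stored somewhere. One
# possible approach (an assumption, not part of the original app) is a small JSON sidecar
# per song, uploaded next to the audio and read back when building the feeds:
import json

def save_song_metadata(song_id: str, public: bool, user_id: str) -> None:
    """Hypothetical helper: upload songs/<song_id>/meta.json describing visibility and owner."""
    meta = json.dumps({"public": public, "user": user_id})
    api.upload_file(
        path_or_fileobj=meta.encode("utf-8"),
        path_in_repo=f"songs/{song_id}/meta.json",
        repo_id=repo_id,
        repo_type="space",
    )

def load_song_metadata(filename: str) -> Dict:
    """Hypothetical helper: download and parse a songs/<song_id>/meta.json file."""
    local_path = hf_hub_download(repo_id=repo_id, repo_type="space", filename=filename)
    with open(local_path, "r") as f:
        return json.load(f)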
# --- User Interface ---
with gr.Blocks() as demo:
    gr.Markdown("## Neon Synth Music Generator")

    # Session-level copy of the songs dict, shared between event handlers
    songs_state = gr.State(generated_songs)

    with gr.Row():
        with gr.Column():
            # Input area
            prompt_input = gr.Textbox(label="Prompt", placeholder="e.g., 128 BPM tech house drum loop")
            generate_button = gr.Button("Generate")
            # User authentication
            login_button = gr.Button("Login")
            logout_button = gr.Button("Logout", visible=False)
            user_name = gr.Textbox(label="Username", interactive=False, visible=False)
        with gr.Column():
            # Output area
            generated_audio = gr.Audio(label="Generated Audio")  # output component; no upload source needed
            generated_image = gr.Image(label="Generated Image")
            song_name = gr.Textbox(label="Song Name")
            make_public_button = gr.Button("Make Public")

    # Feed area
    with gr.Row():
        with gr.Column():
            public_feed = gr.Gallery(label="Public Feed", show_label=False, elem_id="public-feed")
        with gr.Column():
            user_feed = gr.Gallery(label="Your Feed", show_label=False, elem_id="user-feed")
    # --- Event Handlers ---
    generate_button.click(fn=generate_audio, inputs=prompt_input, outputs=[generated_audio, generated_image, song_name])
    # The song name doubles as the song identifier; the original wiring did not define a separate song_id
    make_public_button.click(
        fn=make_public,
        inputs=[song_name, generated_audio, generated_image, song_name, user_name],
        outputs=[songs_state],
    )
    # Log in: set the username, hide the Login button, show the Logout button
    login_button.click(
        fn=lambda: (gr.update(value="YourUsername", visible=True), gr.update(visible=False), gr.update(visible=True)),
        outputs=[user_name, login_button, logout_button],
    )
    # Log out: clear and hide the username, show the Login button, hide the Logout button
    logout_button.click(
        fn=lambda: (gr.update(value="", visible=False), gr.update(visible=True), gr.update(visible=False)),
        outputs=[user_name, login_button, logout_button],
    )
    # --- Update the feed ---
    generated_audio.change(
        fn=download_audio_image,
        inputs=[generated_audio, generated_image, song_name],
        outputs=[generated_audio, generated_image, song_name],
    )
    # Register the newly generated song in the session's song dict (keyed by its name)
    generated_audio.change(
        fn=lambda songs, audio_data, image_data, name, user: {
            **songs,
            name: {"audio": audio_data, "image": image_data, "name": name, "public": False, "user": user},
        },
        inputs=[songs_state, generated_audio, generated_image, song_name, user_name],
        outputs=[songs_state],
    )
    # Refresh the feeds when the song dict changes; gr.Gallery expects images or
    # (image, caption) pairs, so the base64 image URL is paired with the song name
    # (change events on gr.State require a recent Gradio release)
    songs_state.change(
        fn=lambda songs, user: (
            [(download_audio_image(s["audio"], s["image"], s["name"])[1], s["name"])
             for s in songs.values() if s["public"]],
            [(download_audio_image(s["audio"], s["image"], s["name"])[1], s["name"])
             for s in songs.values() if not s["public"] and s["user"] == user],
        ),
        inputs=[songs_state, user_name],
        outputs=[public_feed, user_feed],
    )
    # Fetch and display the feeds on startup
    demo.load(
        fn=lambda: (
            [(download_audio_image(s["audio"], s["image"], s["name"])[1], s["name"])
             for s in fetch_songs().values() if s.get("public")],
            [(download_audio_image(s["audio"], s["image"], s["name"])[1], s["name"])
             for s in fetch_songs().values() if not s.get("public")],
        ),
        outputs=[public_feed, user_feed],
    )
    # Layout is defined above, where the components are created inside gr.Row/gr.Column blocks.

# Run the Gradio interface
demo.launch()