Kvikontent committed on
Commit
76b2ea7
·
verified ·
1 Parent(s): 1ea7e3c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +245 -0
app.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from huggingface_hub import InferenceClient
3
+ import base64
4
+ import torch
5
+ import torchaudio
6
+ from einops import rearrange
7
+ from stable_audio_tools import get_pretrained_model
8
+ from stable_audio_tools.inference.generation import generate_diffusion_cond
9
+ from diffusers import DiffusionPipeline
10
+ from huggingface_hub import InferenceClient, cached_download, hf_hub_url
11
+ from huggingface_hub import HfApi
12
+
13
+ import os
14
+ from typing import List, Dict
15
+
# Authentication
# The Inference API token comes from the "api_key" environment variable;
# os.environ.get returns None when the variable is unset.
client = InferenceClient(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    token=os.environ.get("api_key"),
)

# Load models
device = "cuda" if torch.cuda.is_available() else "cpu"
model, model_config = get_pretrained_model("stabilityai/stable-audio-open-1.0")
sample_rate = model_config["sample_rate"]
sample_size = model_config["sample_size"]
model = model.to(device)

pipeline = DiffusionPipeline.from_pretrained("fluently/Fluently-XL-v2")
pipeline.load_lora_weights("ehristoforu/dalle-3-xl-v2")
# Keep the image pipeline on the same device as the audio model; previously
# only the audio model was moved, leaving image generation on the CPU even
# when CUDA was available.
pipeline = pipeline.to(device)

# --- Hugging Face Spaces Storage ---
api = HfApi()
repo_id = "kvikontent/suno-ai"  # Replace with your Hugging Face repository ID

# --- Global Variables ---
# Maps song_id -> per-song record; make_public() sets the "public" flag on it.
generated_songs = {}
35
+
# Function to generate audio (Requires GPU)
# NOTE(review): `spaces` is not imported anywhere in the visible source. It is
# a third-party package, so the import is not added here -- confirm that
# `import spaces` exists (or add it) before deploying.
# The former `@gr.blocks` decorator was removed: `gradio.blocks` is a module,
# not a callable, so applying it raised TypeError at import time.
@spaces.GPU
def generate_audio(prompt: str) -> tuple:
    """Generate a music clip, a cover image and a song title for *prompt*.

    Relies on the module-level ``model``, ``sample_size``, ``sample_rate``,
    ``device``, ``pipeline`` and ``client`` objects.

    Args:
        prompt: Free-text description used for the audio, the image and the
            title request.

    Returns:
        A 3-tuple ``(audio_data, image_data, song_name)``: WAV-encoded audio
        bytes, PNG-encoded image bytes, and the generated title as ``str``
        (empty when the chat model streams no text).
    """
    # Local import: BytesIO is used below but never imported at module level.
    from io import BytesIO

    # --- Audio Generation ---
    conditioning = [{
        "prompt": prompt,
        # Timing keys as passed in the Stable Audio Open usage example; the
        # clip spans the model's whole sample window.
        "seconds_start": 0,
        "seconds_total": sample_size // sample_rate,
    }]

    output = generate_diffusion_cond(
        model,
        conditioning=conditioning,
        sample_size=sample_size,
        device=device,
    )

    # Fold the batch axis into the sample axis.
    output = rearrange(output, "b d n -> d (b n)")

    # Peak normalize, clip and convert to int16. The peak is floored so that
    # an all-zero clip does not divide by zero and yield NaNs.
    output = output.to(torch.float32)
    peak = torch.max(torch.abs(output)).clamp(min=1e-8)
    output = output.div(peak).clamp(-1, 1).mul(32767).to(torch.int16).cpu()

    # Save audio to memory. A file-like target has no extension to infer the
    # container from, so the format is stated explicitly.
    audio_buffer = BytesIO()
    torchaudio.save(audio_buffer, output, sample_rate, format="wav")
    audio_data = audio_buffer.getvalue()

    # --- Image Generation ---
    image = pipeline(prompt).images[0]
    image_buffer = BytesIO()
    image.save(image_buffer, format="png")
    image_data = image_buffer.getvalue()

    # --- Name Generation ---
    # The reply is streamed chunk by chunk; every chunk is collected instead
    # of keeping only the last one.
    name_parts = []
    for message in client.chat_completion(
        messages=[{"role": "user", "content": "Name the song based on this prompt: " + prompt}],
        max_tokens=500,
        stream=True,
    ):
        if not message.choices:
            continue
        piece = message.choices[0].delta.content
        if piece:
            name_parts.append(piece)
    song_name = "".join(name_parts).strip()

    return audio_data, image_data, song_name
78
+
# Function to expose generated audio and image as inline data URLs
def download_audio_image(audio_data, image_data, song_name):
    """Encode audio and image bytes as base64 ``data:`` URLs.

    Nothing is fetched or written; the payloads are only re-encoded so they
    can be embedded inline.

    Args:
        audio_data: Bytes-like audio payload, labelled ``audio/wav``.
        image_data: Bytes-like image payload, labelled ``image/png``.
        song_name: Passed through unchanged.

    Returns:
        A 3-tuple ``(audio_url, image_url, song_name)``.

    Raises:
        TypeError: If either payload is not bytes-like (for example ``None``).
    """
    def _data_url(mime_type, payload):
        # Base64 output is pure ASCII, so decoding it to str cannot fail.
        encoded = base64.b64encode(payload).decode('utf-8')
        return f"data:{mime_type};base64,{encoded}"

    audio_url = _data_url("audio/wav", audio_data)
    image_url = _data_url("image/png", image_data)
    return audio_url, image_url, song_name
89
+
# Function to make a song public
def make_public(song_id, audio_data, image_data, song_name, user_id):
    """Upload a song to the Space repository and flag it as public.

    Three files are written under ``songs/<song_id>/`` in ``repo_id``:
    ``audio.wav``, ``image.png`` and ``song_name.txt``.

    Args:
        song_id: Key of the song in ``generated_songs``; also used as the
            folder name in the repository.
        audio_data: Audio payload as bytes.
        image_data: Image payload as bytes.
        song_name: Title of the song; ``None`` is stored as an empty file.
        user_id: Currently unused; kept so existing callers keep working.

    Returns:
        The module-level ``generated_songs`` mapping, after the update.
    """
    # upload_file takes the payload via `path_or_fileobj` (a path, raw bytes
    # or a binary file object); it has no `path=` / `data=` parameters.
    # Passing bytes directly also avoids writing a scratch song_name.txt
    # into the working directory.
    uploads = (
        ("audio.wav", audio_data),
        ("image.png", image_data),
        ("song_name.txt", (song_name or "").encode("utf-8")),
    )
    for filename, payload in uploads:
        api.upload_file(
            path_or_fileobj=payload,
            path_in_repo=f"songs/{song_id}/{filename}",
            repo_id=repo_id,
            repo_type="space",
        )

    # Flag the song only after every upload succeeded. setdefault keeps an
    # unregistered song_id from raising KeyError.
    generated_songs.setdefault(song_id, {})["public"] = True

    return generated_songs
121
+
# Function to fetch songs from Hugging Face Spaces
def fetch_songs(user_id=None):
    """Fetch every stored song from the Space repository.

    Files are expected at ``songs/<song_id>/<filename>`` where the filename
    is ``audio.wav``, ``image.png`` or ``song_name.txt``; any other path is
    ignored.

    Args:
        user_id: Currently unused; the repository layout carries no owner
            information.

    Returns:
        A dict mapping ``song_id`` to a record with the keys ``"audio"``
        (bytes), ``"image"`` (bytes) and ``"name"`` (str) for whichever files
        exist, plus ``"public"``.
    """
    # Stored filename -> key in the returned record.
    record_keys = {
        "audio.wav": "audio",
        "image.png": "image",
        "song_name.txt": "name",
    }

    songs = {}
    # list_repo_files yields plain path strings, not dicts.
    for repo_path in api.list_repo_files(repo_id=repo_id, repo_type="space"):
        parts = repo_path.split("/")
        if len(parts) != 3 or parts[0] != "songs":
            continue
        _, song_id, filename = parts

        # NOTE(review): in this file only make_public() writes under songs/,
        # so stored songs are treated as public -- confirm nothing else
        # uploads there.
        # TODO: persist and restore the owning user once it is stored.
        record = songs.setdefault(song_id, {"public": True})

        key = record_keys.get(filename)
        if key is None:
            continue

        # HfApi has no download_file(); hf_hub_download fetches the file and
        # returns a local path to it.
        local_path = api.hf_hub_download(
            repo_id=repo_id,
            filename=repo_path,
            repo_type="space",
            revision="main",
        )
        with open(local_path, "rb") as handle:
            payload = handle.read()

        record[key] = payload.decode("utf-8") if key == "name" else payload

    return songs
156
+
# --- User Interface ---
# Placeholder identity handed out by the stub login button.
_PLACEHOLDER_USER = "YourUsername"


def _write_temp_file(data, suffix):
    """Write *data* to a new temporary file and return its path."""
    import tempfile

    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as handle:
        handle.write(data)
    return handle.name


def _gallery_items(songs, *, public, user=None):
    """Build ``(image_path, caption)`` gallery entries from song records.

    Only records whose "public" flag equals *public* are kept; when *user* is
    given, records owned by another user are skipped too.
    """
    items = []
    for song in songs.values():
        if bool(song.get("public")) is not public:
            continue
        if user is not None and song.get("user") != user:
            continue
        image_data = song.get("image")
        if not image_data:
            continue
        # Cache the temp file on the record so a refresh does not rewrite it.
        image_path = song.get("image_path")
        if image_path is None:
            image_path = song["image_path"] = _write_temp_file(image_data, ".png")
        items.append((image_path, song.get("name") or "Untitled"))
    return items


def _public_feed():
    """Gallery entries for every public song stored in the repository."""
    return _gallery_items(fetch_songs(), public=True)


def _user_feed(user):
    """Gallery entries for *user*'s songs that are not public yet."""
    return _gallery_items(generated_songs, public=False, user=user or "")


def _on_generate(prompt, user):
    """Generate a song, register it, and return values for the output widgets."""
    import uuid

    audio_data, image_data, title = generate_audio(prompt)
    song_id = uuid.uuid4().hex
    # The widgets are fed file paths; the raw bytes stay in the record so
    # make_public() can upload them later.
    image_path = _write_temp_file(image_data, ".png")
    generated_songs[song_id] = {
        "audio": audio_data,
        "image": image_data,
        "image_path": image_path,
        "name": title,
        "public": False,
        "user": user or "",
    }
    audio_path = _write_temp_file(audio_data, ".wav")
    return audio_path, image_path, title, song_id, _user_feed(user)


def _on_make_public(song_id, user):
    """Publish the session's current song and refresh both feeds."""
    song = generated_songs.get(song_id)
    if song is None:
        raise gr.Error("Generate a song before making it public.")
    make_public(song_id, song["audio"], song["image"], song["name"], user or "")
    return _public_feed(), _user_feed(user)


def _login():
    """Stub login: show the user name, swap the buttons, refresh the feed."""
    return (
        gr.update(value=_PLACEHOLDER_USER, visible=True),
        gr.update(visible=False),
        gr.update(visible=True),
        _user_feed(_PLACEHOLDER_USER),
    )


def _logout():
    """Stub logout: clear the user name, swap the buttons, refresh the feed."""
    return (
        gr.update(value="", visible=False),
        gr.update(visible=True),
        gr.update(visible=False),
        _user_feed(""),
    )


def _load_feeds(user):
    """Initial content of the public and the personal feed."""
    return _public_feed(), _user_feed(user)


with gr.Blocks() as demo:
    gr.Markdown("## Neon Synth Music Generator")

    # Id of the song most recently generated in this browser session.
    current_song_id = gr.State(None)

    # --- Layout ---
    # Components render where they are created, so the layout is expressed by
    # creating them inside the rows/columns; naming a component on a line of
    # its own does not place it.
    with gr.Row():
        with gr.Column():
            # Input area
            prompt_input = gr.Textbox(label="Prompt", placeholder="e.g., 128 BPM tech house drum loop")
            generate_button = gr.Button("Generate")

            # User authentication
            login_button = gr.Button("Login")
            logout_button = gr.Button("Logout", visible=False)
            user_name = gr.Textbox(label="Username", interactive=False, visible=False)
        with gr.Column():
            # Output area. The former `playable=`/`source=` arguments were
            # dropped: `playable` is not a gr.Audio parameter and an
            # output-only player needs no upload source.
            generated_audio = gr.Audio(label="Generated Audio")
            generated_image = gr.Image(label="Generated Image")
            song_name = gr.Textbox(label="Song Name")
            make_public_button = gr.Button("Make Public")

    # Feed area
    with gr.Row():
        with gr.Column():
            public_feed = gr.Gallery(label="Public Feed", show_label=False, elem_id="public-feed")
        with gr.Column():
            user_feed = gr.Gallery(label="Your Feed", show_label=False, elem_id="user-feed")

    # --- Event Handlers ---
    # Event listeners take no `show_error` argument, and a plain dict has no
    # `.change` event, so feeds are refreshed from the handlers instead.
    generate_button.click(
        fn=_on_generate,
        inputs=[prompt_input, user_name],
        outputs=[generated_audio, generated_image, song_name, current_song_id, user_feed],
    )
    make_public_button.click(
        fn=_on_make_public,
        inputs=[current_song_id, user_name],
        outputs=[public_feed, user_feed],
    )
    login_button.click(
        fn=_login,
        inputs=None,
        outputs=[user_name, login_button, logout_button, user_feed],
    )
    logout_button.click(
        fn=_logout,
        inputs=None,
        outputs=[user_name, login_button, logout_button, user_feed],
    )

    # Fetch and display the feeds
    demo.load(fn=_load_feeds, inputs=[user_name], outputs=[public_feed, user_feed])

# Run the Gradio interface
demo.launch()