# StableBeatAudio / app.py
# Last commit 044def4 ("Update app.py", verified) by bharatverse11.
import gradio as gr
import requests
import base64
import io
import time
import soundfile as sf
# ── Backend config ─────────────────────────────────────────────────────────────
# Generation endpoint, reached through an ngrok tunnel.
API_URL = "https://bustled-hertha-unprojective.ngrok-free.dev/generate"

# Backend root, used for the /result/<job_id> polling endpoint. Only a trailing
# "/generate" is stripped; str.replace would also mangle any other occurrence of
# that substring in the URL.
_GENERATE_PATH = "/generate"
BASE_URL = (
    API_URL[: -len(_GENERATE_PATH)] if API_URL.endswith(_GENERATE_PATH) else API_URL
)

HEADERS = {
    # Asks ngrok to skip its browser-warning interstitial so the backend's JSON
    # (not an HTML page) comes back.
    "ngrok-skip-browser-warning": "true",
    "Content-Type": "application/json",
}

CFG = 7.0  # sent as "cfg_scale"; fixed optimal value from paper — not exposed in UI
# ── Generation function ────────────────────────────────────────────────────────
_POLL_INTERVAL = 5  # seconds to wait between result polls
_MAX_WAIT = 600  # wall-clock polling budget in seconds (10 minutes)
_POLL_TIMEOUT = 10  # HTTP timeout for each individual poll request, in seconds


def _decode_audio(b64_audio):
    """Decode a base64-encoded audio file into ``(samples, sample_rate)``.

    The decoded bytes are handed to ``soundfile.read``, so they must be a file
    format soundfile can detect from the content (the backend's actual format is
    not visible here).
    """
    return sf.read(io.BytesIO(base64.b64decode(b64_audio)))


def _poll_for_result(job_id):
    """Poll ``{BASE_URL}/result/<job_id>`` until the job finishes or time runs out.

    Transient problems (network errors, non-200 responses, non-JSON or non-object
    bodies) are logged and retried.

    Returns:
        ``(sample_rate, samples)`` on success, or ``None`` when the job reports
        ``"error"``, reports ``"done"`` without audio, or ``_MAX_WAIT`` seconds
        of wall-clock time elapse.
    """
    started = time.monotonic()
    # Measure real elapsed time: each poll may itself block for up to
    # _POLL_TIMEOUT seconds, so counting only the sleeps would overshoot.
    while time.monotonic() - started < _MAX_WAIT:
        time.sleep(_POLL_INTERVAL)
        print(f"🔄 Polling... ({time.monotonic() - started:.0f}s elapsed)")
        try:
            poll = requests.get(
                f"{BASE_URL}/result/{job_id}",
                headers=HEADERS,
                timeout=_POLL_TIMEOUT,
            )
        except requests.exceptions.RequestException as e:
            print(f"⚠️ Poll request failed: {e} — retrying")
            continue
        if poll.status_code != 200:
            print(f"⚠️ Poll returned {poll.status_code} — retrying")
            continue
        try:
            job = poll.json()
        except ValueError:
            # e.g. an HTML error page from the tunnel; treat as transient.
            print(f"⚠️ Poll returned invalid JSON: {poll.text[:300]} — retrying")
            continue
        if not isinstance(job, dict):
            print(f"⚠️ Unexpected poll payload: {job!r} — retrying")
            continue

        status = job.get("status")
        print(f" status = {status}")
        if status == "error":
            print(f"❌ Job failed: {job.get('error')}")
            return None
        if status == "done":
            if "audio" not in job:
                # A finished job presumably won't gain audio later; stop now
                # instead of polling until the whole time budget is spent.
                print(f"❌ Job reported 'done' but returned no audio (keys: {sorted(job)})")
                return None
            audio, sr = _decode_audio(job["audio"])
            print(f"✅ Got audio: {len(audio)/sr:.2f}s @ {sr}Hz")
            return sr, audio
        if status not in ("pending", "processing", "running"):
            print(f"⚠️ Unknown status '{status}' — continuing to poll")

    print(f"❌ Timed out after {_MAX_WAIT // 60} minutes")
    return None


def generate(prompt, duration, steps):
    """Request a beat for ``prompt`` from the remote backend (Gradio click handler).

    Two backend styles are supported. A synchronous backend returns base64
    audio in the first response. An asynchronous backend returns a ``job_id``,
    which is then polled at ``{BASE_URL}/result/<job_id>``.

    Args:
        prompt: Text description of the beat.
        duration: Requested length in seconds (sent as a float).
        steps: Number of diffusion steps (sent as an int).

    Returns:
        ``(sample_rate, samples)`` for ``gr.Audio(type="numpy")``, or ``None``
        on any failure. The reason for a failure is printed to the server log.
    """
    try:
        print(f"📤 Sending request → prompt='{prompt[:60]}...' duration={duration}s steps={steps}")
        res = requests.post(
            API_URL,
            json={
                "prompt": prompt,
                "duration": float(duration),
                "steps": int(steps),
                "cfg_scale": CFG,
            },
            headers=HEADERS,
            timeout=30,
        )
        print(f"STATUS: {res.status_code}")
        if res.status_code != 200:
            print(f"❌ Backend error: {res.text[:500]}")
            return None
        if not res.text:
            print("❌ Empty response from backend")
            return None
        try:
            data = res.json()
        except ValueError:
            print(f"❌ Invalid JSON (got HTML?): {res.text[:300]}")
            return None
        if not isinstance(data, dict):
            print(f"❌ Unexpected response payload: {data!r}")
            return None

        # Sync backend: audio returned immediately.
        if "audio" in data:
            audio, sr = _decode_audio(data["audio"])
            print(f"✅ Got audio instantly: {len(audio)/sr:.2f}s @ {sr}Hz")
            return sr, audio

        # Async backend: job_id returned, poll for the result.
        job_id = data.get("job_id")
        if not job_id:
            print(f"❌ No audio and no job_id in response: {data}")
            return None
        print(f"⏳ Job queued: {job_id} — polling for result...")
        return _poll_for_result(job_id)
    except requests.exceptions.Timeout:
        # Only the initial POST can land here; poll timeouts are retried inside
        # _poll_for_result.
        print("❌ Initial request timed out — Colab backend may be busy")
        return None
    except Exception as e:
        # Top-level boundary of the click handler (e.g. an undecodable audio
        # payload): log it and return no audio.
        print(f"❌ Request failed: {e}")
        return None
# ── UI ─────────────────────────────────────────────────────────────────────────
# Layout: prompt / duration / steps inputs in the left column, generated audio
# in the right column. The guidance scale (CFG) is intentionally not exposed.
with gr.Blocks(title="AutoMix AI 🎵") as demo:
    gr.Markdown("# 🎵 AutoMix AI Beat Generator")
    gr.Markdown("Generate AI beats using a diffusion model fine-tuned on trap/rap/R&B 🚀")
    with gr.Row():
        with gr.Column():
            prompt_in = gr.Textbox(
                label="🎧 Prompt",
                placeholder="A dark trap beat at 140 BPM in C minor, featuring 808 bass and synth bells.",
                lines=4,
            )
            duration_in = gr.Slider(
                minimum=5,
                maximum=95,
                value=30,
                step=1,
                label="⏱ Duration (seconds)",
            )
            steps_in = gr.Slider(
                minimum=20,
                maximum=300,
                value=100,
                step=10,
                label="⚙️ Diffusion Steps (more = better quality, slower)",
            )
            gr.Markdown(
                "> ⚠️ Durations above 47s take significantly longer to generate. "
                "Keep steps ≤ 150 for beats over 60s to avoid timeouts."
            )
            btn = gr.Button("🚀 Generate Beat", variant="primary")
        with gr.Column():
            # type="numpy" matches generate()'s (sample_rate, samples) return value.
            output = gr.Audio(label="🎵 Generated Beat", type="numpy")
    # Event listeners are registered inside the Blocks context.
    btn.click(
        fn=generate,
        inputs=[prompt_in, duration_in, steps_in],
        outputs=output,
    )
demo.launch()