# Spaces: Running  (hosting-page status text captured with the source; not program code)
import base64
import io
import time

import gradio as gr
import requests
import soundfile as sf
# ── Backend config ─────────────────────────────────────────────────────────────
# Root of the tunnel that fronts the generation backend. Every endpoint is
# built from this value, so changing the host only requires editing one line.
BASE_URL = "https://bustled-hertha-unprojective.ngrok-free.dev"

# Endpoint that accepts generation requests (POST, JSON body).
API_URL = f"{BASE_URL}/generate"

HEADERS = {
    # NOTE(review): presumably suppresses ngrok's browser interstitial page so
    # the API returns JSON instead of HTML — confirm against the ngrok docs.
    "ngrok-skip-browser-warning": "true",
    "Content-Type": "application/json",
}

CFG = 7.0  # fixed optimal value from paper — not exposed in UI
# ── Generation function ────────────────────────────────────────────────────────
_POLL_INTERVAL = 5  # seconds to sleep before each poll
_MAX_WAIT = 600  # total polling budget in seconds (10 minutes)


def _decode_audio(audio_b64):
    """Decode a base64-encoded audio file into a ``(sample_rate, samples)`` pair."""
    samples, sample_rate = sf.read(io.BytesIO(base64.b64decode(audio_b64)))
    return sample_rate, samples


def _poll_for_audio(job_id):
    """Poll ``/result/<job_id>`` until audio arrives, the job fails, or time runs out.

    Returns ``(sample_rate, samples)`` on success and ``None`` when the job
    reports an error or the polling budget is exhausted. Transient problems
    (network errors, non-200 replies, unparsable bodies) are logged and retried.
    """
    started = time.monotonic()
    # Wall-clock deadline: slow or timed-out poll requests count against the
    # budget, so the total wait cannot silently grow far past _MAX_WAIT.
    deadline = started + _MAX_WAIT

    while time.monotonic() < deadline:
        time.sleep(_POLL_INTERVAL)
        print(f"🔄 Polling... ({int(time.monotonic() - started)}s elapsed)")

        try:
            poll = requests.get(
                f"{BASE_URL}/result/{job_id}",
                headers=HEADERS,
                timeout=10,
            )
        except requests.exceptions.RequestException as e:
            print(f"⚠️ Poll request failed: {e} — retrying")
            continue

        if poll.status_code != 200:
            print(f"⚠️ Poll returned {poll.status_code} — retrying")
            continue

        # A 200 reply that is not a JSON object (e.g. an HTML page from the
        # tunnel) is treated as transient rather than aborting the whole job.
        try:
            job = poll.json()
        except ValueError:
            job = None
        if not isinstance(job, dict):
            print(f"⚠️ Poll returned unusable JSON: {poll.text[:300]} — retrying")
            continue

        status = job.get("status")
        print(f" status = {status}")

        if status == "error":
            print(f"❌ Job failed: {job.get('error')}")
            return None

        if status == "done" and "audio" in job:
            sr, audio = _decode_audio(job["audio"])
            print(f"✅ Got audio: {len(audio)/sr:.2f}s @ {sr}Hz")
            return sr, audio

        # "done" without audio falls through and keeps polling, as before.
        if status not in ("pending", "processing", "running", "done"):
            print(f"⚠️ Unknown status '{status}' — continuing to poll")

    print(f"❌ Timed out after {_MAX_WAIT // 60} minutes")
    return None


def generate(prompt, duration, steps):
    """Request a beat from the backend and return it for ``gr.Audio(type="numpy")``.

    Args:
        prompt: Text description of the beat to generate.
        duration: Requested length in seconds; sent to the backend as a float.
        steps: Number of diffusion steps; sent to the backend as an int.

    Returns:
        ``(sample_rate, samples)`` on success, or ``None`` on any failure
        (blank prompt, HTTP error, malformed response, failed job, timeout).
        Failures are reported via ``print`` and are not raised to the caller.
    """
    # Reject blank input up front instead of spending a backend job on it.
    if not isinstance(prompt, str) or not prompt.strip():
        print("❌ Empty prompt — nothing to generate")
        return None

    try:
        print(f"📤 Sending request → prompt='{prompt[:60]}...' duration={duration}s steps={steps}")
        res = requests.post(
            API_URL,
            json={
                "prompt": prompt,
                "duration": float(duration),
                "steps": int(steps),
                "cfg_scale": CFG,
            },
            headers=HEADERS,
            timeout=30,
        )
        print(f"STATUS: {res.status_code}")

        if res.status_code != 200:
            print(f"❌ Backend error: {res.text[:500]}")
            return None

        if not res.text:
            print("❌ Empty response from backend")
            return None

        try:
            data = res.json()
        except ValueError:
            print(f"❌ Invalid JSON (got HTML?): {res.text[:300]}")
            return None

        if not isinstance(data, dict):
            print(f"❌ No audio and no job_id in response: {data}")
            return None

        # ── Sync backend: audio returned immediately ───────────────────────────
        if "audio" in data:
            sr, audio = _decode_audio(data["audio"])
            print(f"✅ Got audio instantly: {len(audio)/sr:.2f}s @ {sr}Hz")
            return sr, audio

        # ── Async backend: job_id returned, poll for result ────────────────────
        job_id = data.get("job_id")
        if not job_id:
            print(f"❌ No audio and no job_id in response: {data}")
            return None

        print(f"⏳ Job queued: {job_id} — polling for result...")
        return _poll_for_audio(job_id)

    except requests.exceptions.Timeout:
        print("❌ Initial request timed out — Colab backend may be busy")
        return None
    except Exception as e:
        # Top-level boundary for the UI callback: log and return "no audio"
        # rather than surfacing a traceback.
        print(f"❌ Request failed: {e}")
        return None
# ── UI ─────────────────────────────────────────────────────────────────────────
# Two-column layout: generation controls on the left, audio player on the right.
with gr.Blocks(title="AutoMix AI 🎵") as demo:
    gr.Markdown("# 🎵 AutoMix AI Beat Generator")
    gr.Markdown("Generate AI beats using a diffusion model fine-tuned on trap/rap/R&B 🚀")

    with gr.Row():
        with gr.Column():
            prompt_in = gr.Textbox(
                label="🎧 Prompt",
                placeholder="A dark trap beat at 140 BPM in C minor, featuring 808 bass and synth bells.",
                lines=4,
            )
            duration_in = gr.Slider(
                minimum=5,
                maximum=95,
                value=30,
                step=1,
                label="⏱ Duration (seconds)",
            )
            steps_in = gr.Slider(
                minimum=20,
                maximum=300,
                value=100,
                step=10,
                label="⚙️ Diffusion Steps (more = better quality, slower)",
            )
            gr.Markdown(
                "> ⚠️ Durations above 47s take significantly longer to generate. "
                "Keep steps ≤ 150 for beats over 60s to avoid timeouts."
            )
            btn = gr.Button("🚀 Generate Beat", variant="primary")

        with gr.Column():
            # type="numpy" matches the (sample_rate, samples) tuple that
            # generate() returns; a None return leaves the player empty.
            output = gr.Audio(label="🎵 Generated Beat", type="numpy")

    # Wire the button to the backend call; inputs are passed positionally in
    # the order generate(prompt, duration, steps) expects.
    btn.click(
        fn=generate,
        inputs=[prompt_in, duration_in, steps_in],
        outputs=output,
    )

demo.launch()