Spaces:
Sleeping
Sleeping
| """ | |
| server.py โ HarmoSplit ใใใฏใจใณใ | |
| Hugging Face Spaces ๅฏพๅฟ๏ผใใผใ 7860๏ผ+ Stripe ๆ้กๆฑบๆธ | |
| """ | |
| import os | |
| import sys | |
| import uuid | |
| import json | |
| import shutil | |
| import secrets | |
| import tempfile | |
| import threading | |
| import subprocess | |
| from pathlib import Path | |
| from datetime import datetime, timezone | |
| from flask import ( | |
| Flask, request, jsonify, send_file, | |
| Response, send_from_directory, redirect, url_for | |
| ) | |
| from flask_cors import CORS | |
| sys.path.insert(0, str(Path(__file__).parent)) | |
| import app as core | |
| # โโ Flask โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ | |
| flask_app = Flask(__name__, static_folder="static", static_url_path="") | |
| CORS(flask_app) | |
| # โโ ่จญๅฎ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ | |
| PORT = int(os.environ.get("PORT", 7860)) | |
| # Stripe ใญใผ๏ผHF Spaces Secrets ใพใใฏ .env ใง่จญๅฎ๏ผ | |
| STRIPE_SECRET_KEY = os.environ.get("STRIPE_SECRET_KEY", "") | |
| STRIPE_WEBHOOK_SECRET = os.environ.get("STRIPE_WEBHOOK_SECRET", "") | |
| STRIPE_PRICE_ID = os.environ.get("STRIPE_PRICE_ID", "") | |
| # ใขใใชใฎๅ ฌ้ URL๏ผWebhook / Checkout success URL ็จ๏ผ | |
| APP_URL = os.environ.get("APP_URL", f"http://localhost:{PORT}") | |
| # ๆฑบๆธไธ่ฆใขใผใ๏ผStripe ใญใผๆช่จญๅฎใชใ็กๆ้ๆพ๏ผ | |
| FREE_MODE = not bool(STRIPE_SECRET_KEY) | |
| # โโ ๆฐธ็ถในใใฌใผใธ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ | |
| # HF Spaces ใงใฏ /data ใๆฐธ็ถใใญใผใซใซใงใฏ ./data ใไฝฟ็จใ | |
| DATA_DIR = Path("/data") if Path("/data").exists() else Path("./data") | |
| DATA_DIR.mkdir(parents=True, exist_ok=True) | |
| TOKENS_FILE = DATA_DIR / "tokens.json" | |
| UPLOAD_DIR = DATA_DIR / "uploads" | |
| UPLOAD_DIR.mkdir(parents=True, exist_ok=True) | |
| # โโ ใใผใฏใณ็ฎก็ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ | |
| tokens_lock = threading.Lock() | |
| def load_tokens() -> dict: | |
| with tokens_lock: | |
| if TOKENS_FILE.exists(): | |
| try: | |
| return json.loads(TOKENS_FILE.read_text("utf-8")) | |
| except Exception: | |
| pass | |
| return {} | |
| def save_tokens(data: dict): | |
| with tokens_lock: | |
| TOKENS_FILE.write_text(json.dumps(data, indent=2, ensure_ascii=False), "utf-8") | |
| def create_token(customer_id: str, subscription_id: str, email: str) -> str: | |
| token = secrets.token_urlsafe(32) | |
| data = load_tokens() | |
| data[token] = { | |
| "customer_id": customer_id, | |
| "subscription_id": subscription_id, | |
| "email": email, | |
| "created_at": datetime.now(timezone.utc).isoformat(), | |
| "active": True, | |
| } | |
| save_tokens(data) | |
| return token | |
| def is_token_valid(token: str) -> bool: | |
| if FREE_MODE: | |
| return True # ็กๆใขใผใใฏๅธธใซๆๅน | |
| data = load_tokens() | |
| entry = data.get(token) | |
| return bool(entry and entry.get("active")) | |
| def deactivate_token_by_subscription(subscription_id: str): | |
| data = load_tokens() | |
| for info in data.values(): | |
| if info.get("subscription_id") == subscription_id: | |
| info["active"] = False | |
| save_tokens(data) | |
| # โโ ใธใงใ็ฎก็ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ | |
| JOBS: dict[str, dict] = {} | |
| JOBS_LOCK = threading.Lock() | |
| def log_progress(job_id: str, message: str, percent: int | None = None): | |
| entry = {"msg": message} | |
| if percent is not None: | |
| entry["pct"] = percent | |
| with JOBS_LOCK: | |
| if job_id in JOBS: | |
| JOBS[job_id]["progress"].append(entry) | |
| def _load_stems_helper(demucs_out: Path, model: str, wav_path: Path): | |
| import soundfile as sf | |
| import numpy as np | |
| stem_dir = demucs_out / model / wav_path.stem | |
| if not stem_dir.exists(): | |
| candidates = list(demucs_out.rglob("*.wav")) | |
| if not candidates: | |
| raise FileNotFoundError("ในใใ ใใกใคใซใ่ฆใคใใใพใใ") | |
| stem_dir = candidates[0].parent | |
| stems = {} | |
| sr = 44100 | |
| for wav_file in sorted(stem_dir.glob("*.wav")): | |
| data, sr = sf.read(str(wav_file), always_2d=True) | |
| stems[wav_file.stem] = data.astype(np.float_()) | |
| return stems, sr | |
| core._load_stems = _load_stems_helper | |
| def process_job(job_id: str, input_path: Path, inst_vol: float, model: str, use_mdx: bool): | |
| tmp_dir = Path(tempfile.mkdtemp(prefix=f"hmsplit_{job_id[:8]}_")) | |
| demucs_out = tmp_dir / "demucs_out" | |
| demucs_out.mkdir(parents=True, exist_ok=True) | |
| try: | |
| with JOBS_LOCK: | |
| JOBS[job_id]["status"] = "processing" | |
| log_progress(job_id, "๐ต ้ณๅฃฐใ่ชญใฟ่พผใฟไธญ...", 5) | |
| wav_path = core.prepare_audio(input_path, tmp_dir) | |
| log_progress(job_id, "๐ค Demucs AI ใง้ณๆบใๅ้ขไธญ...", 15) | |
| cmd = [sys.executable, "-m", "demucs", "-n", model, | |
| "-o", str(demucs_out), str(wav_path)] | |
| proc = subprocess.Popen( | |
| cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, | |
| text=True, encoding="utf-8", errors="replace" | |
| ) | |
| for line in proc.stdout: | |
| line = line.rstrip() | |
| if line.strip(): | |
| log_progress(job_id, line) | |
| proc.wait() | |
| if proc.returncode != 0: | |
| raise RuntimeError("Demucs ๅฆ็ใซๅคฑๆใใพใใ") | |
| log_progress(job_id, "๐ ในใใ ใ่ชญใฟ่พผใฟไธญ...", 55) | |
| stems, sr = core._load_stems(demucs_out, model, wav_path) | |
| if model == "htdemucs_6s" and "vocals" not in stems: | |
| log_progress(job_id, "โ ๏ธ htdemucs ใซใใฉใผใซใใใฏ", 57) | |
| shutil.rmtree(demucs_out); demucs_out.mkdir(parents=True, exist_ok=True) | |
| model = "htdemucs" | |
| subprocess.run([sys.executable, "-m", "demucs", "-n", model, | |
| "-o", str(demucs_out), str(wav_path)], check=True, capture_output=True) | |
| stems, sr = core._load_stems(demucs_out, model, wav_path) | |
| mdx_model_path = None | |
| if use_mdx: | |
| log_progress(job_id, "๐ง UVR MDX-NET ใขใใซใๆบๅไธญ...", 60) | |
| try: | |
| mdx_cache = DATA_DIR / "models" | |
| mdx_cache.mkdir(parents=True, exist_ok=True) | |
| mdx_model_path = core.download_mdx_model(mdx_cache) | |
| log_progress(job_id, "โ UVR MDX-NET ๆบๅๅฎไบ", 65) | |
| except Exception as e: | |
| log_progress(job_id, f"โ ๏ธ MDX ๅๅพๅคฑๆใMid/Side ใซใใฉใผใซใใใฏ: {e}", 65) | |
| log_progress(job_id, "๐๏ธ L/R ใใณใใณใฐ & ใใใฏในๅฆ็ไธญ...", 68) | |
| if mdx_model_path: | |
| log_progress(job_id, "๐ฌ ใชใผใ / ใใใญใณใฐ AI ๅ้ขไธญ... (ๆฐๅใใใใพใ)", 70) | |
| mixed = core.mix_stems(stems, model, inst_vol, mdx_model_path=mdx_model_path, sr=sr) | |
| log_progress(job_id, "๐พ WAV ใๆธใๅบใไธญ...", 95) | |
| import soundfile as sf | |
| output_path = UPLOAD_DIR / f"{job_id}_panned.wav" | |
| sf.write(str(output_path), mixed, sr, subtype="PCM_16") | |
| with JOBS_LOCK: | |
| JOBS[job_id]["status"] = "done" | |
| JOBS[job_id]["output_path"] = str(output_path) | |
| log_progress(job_id, "โ ๅฆ็ๅฎไบ๏ผใใฆใณใญใผใใใฟใณใใฏใชใใฏใใฆใใ ใใใ", 100) | |
| except Exception as e: | |
| import traceback | |
| with JOBS_LOCK: | |
| JOBS[job_id]["status"] = "error" | |
| JOBS[job_id]["error"] = str(e) | |
| log_progress(job_id, f"โ {e}") | |
| finally: | |
| shutil.rmtree(tmp_dir, ignore_errors=True) | |
| try: | |
| input_path.unlink(missing_ok=True) | |
| except Exception: | |
| pass | |
| # โโ ใซใผใใฃใณใฐ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ | |
| def index(): | |
| return send_from_directory("static", "index.html") | |
| def pricing(): | |
| return send_from_directory("static", "pricing.html") | |
| def success(): | |
| return send_from_directory("static", "success.html") | |
| def legal(): | |
| return send_from_directory("static", "legal.html") | |
| def auth_mode(): | |
| """ใใญใณใใจใณใใ็กๆ/ๆๆใขใผใใ่ญๅฅใใใใใฎใจใณใใใคใณใ""" | |
| return jsonify({"free_mode": FREE_MODE}) | |
| def pricing_info(): | |
| """ๆ้ใใผใธ็จ: Stripe ใใ Price ๆ ๅ ฑใๅๅพใใฆ่ฟใ""" | |
| if FREE_MODE: | |
| return jsonify({"price": 0, "currency": "jpy", "free_mode": True}) | |
| try: | |
| import stripe | |
| stripe.api_key = STRIPE_SECRET_KEY | |
| price = stripe.Price.retrieve(STRIPE_PRICE_ID) | |
| return jsonify({ | |
| "price": price.unit_amount, | |
| "currency": price.currency, | |
| "interval": price.recurring.interval if price.recurring else "month", | |
| }) | |
| except Exception as e: | |
| return jsonify({"error": str(e)}), 500 | |
| # โโ Stripe: Checkout ใปใใทใงใณไฝๆ โโโโโโโโโโโโโโโโโโโโโโโโโโ | |
| def create_checkout(): | |
| if FREE_MODE: | |
| return jsonify({"error": "Stripe ๆช่จญๅฎ๏ผ้็บใขใผใ๏ผ"}), 400 | |
| try: | |
| import stripe | |
| stripe.api_key = STRIPE_SECRET_KEY | |
| session = stripe.checkout.Session.create( | |
| payment_method_types=["card"], | |
| line_items=[{"price": STRIPE_PRICE_ID, "quantity": 1}], | |
| mode="subscription", | |
| success_url=f"{APP_URL}/success?session_id={{CHECKOUT_SESSION_ID}}", | |
| cancel_url=f"{APP_URL}/pricing", | |
| ) | |
| return jsonify({"url": session.url}) | |
| except Exception as e: | |
| return jsonify({"error": str(e)}), 500 | |
| # โโ Stripe: Webhook โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ | |
| def stripe_webhook(): | |
| if FREE_MODE: | |
| return "", 200 | |
| try: | |
| import stripe | |
| stripe.api_key = STRIPE_SECRET_KEY | |
| payload = request.get_data() | |
| sig_header = request.headers.get("Stripe-Signature", "") | |
| event = stripe.Webhook.construct_event(payload, sig_header, STRIPE_WEBHOOK_SECRET) | |
| except Exception as e: | |
| return jsonify({"error": str(e)}), 400 | |
| etype = event["type"] | |
| obj = event["data"]["object"] | |
| if etype == "checkout.session.completed": | |
| sub_id = obj.get("subscription") | |
| cust_id = obj.get("customer") | |
| email = obj.get("customer_email") or obj.get("customer_details", {}).get("email", "") | |
| token = create_token(cust_id, sub_id, email) | |
| print(f"[WEBHOOK] ๆฐใตใในใฏใชใใทใงใณ: {email} โ token={token[:8]}...") | |
| elif etype in ("customer.subscription.deleted", "customer.subscription.paused"): | |
| sub_id = obj.get("id") | |
| deactivate_token_by_subscription(sub_id) | |
| print(f"[WEBHOOK] ใตใในใฏใชใใทใงใณๅๆญข: {sub_id}") | |
| return "", 200 | |
| # โโ Stripe: ๆๅๅพใใผใฏใณๅๅพ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ | |
| def get_token(): | |
| """ๆฑบๆธๅฎไบๅพใซ Stripe Session ID ใใใใผใฏใณใ่ฟใ""" | |
| if FREE_MODE: | |
| return jsonify({"token": "FREE_MODE"}) | |
| session_id = request.args.get("session_id", "") | |
| if not session_id: | |
| return jsonify({"error": "session_id ใๅฟ ่ฆใงใ"}), 400 | |
| try: | |
| import stripe | |
| stripe.api_key = STRIPE_SECRET_KEY | |
| session = stripe.checkout.Session.retrieve(session_id) | |
| sub_id = session.get("subscription") | |
| cust_id = session.get("customer") | |
| email = session.get("customer_details", {}).get("email", "") | |
| # ๆขๅญใใผใฏใณใๆขใ๏ผWebhook ใๅ ใซๅฆ็ใใฆใใๅ ดๅ๏ผ | |
| data = load_tokens() | |
| for tok, info in data.items(): | |
| if info.get("subscription_id") == sub_id: | |
| return jsonify({"token": tok, "email": email}) | |
| # Webhook ใใพใ ใชใไฝๆ | |
| token = create_token(cust_id, sub_id, email) | |
| return jsonify({"token": token, "email": email}) | |
| except Exception as e: | |
| return jsonify({"error": str(e)}), 500 | |
| # โโ ใใผใฏใณๆค่จผ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ | |
| def verify_token(): | |
| token = request.json.get("token", "") | |
| return jsonify({"valid": is_token_valid(token)}) | |
| # โโ ใใกใคใซใขใใใญใผใ & ๅฆ็ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ | |
| def upload(): | |
| # ใใผใฏใณ่ช่จผ | |
| token = request.form.get("token", "") | |
| if not is_token_valid(token): | |
| return jsonify({"error": "็กๅนใชใใผใฏใณใงใใๆ้ใใผใธใใ็ป้ฒใใฆใใ ใใใ"}), 401 | |
| if "file" not in request.files: | |
| return jsonify({"error": "ใใกใคใซใใใใพใใ"}), 400 | |
| file = request.files["file"] | |
| if not file.filename: | |
| return jsonify({"error": "ใใกใคใซๅใ็ฉบใงใ"}), 400 | |
| inst_vol = float(request.form.get("inst_vol", 0.15)) | |
| model = request.form.get("model", "htdemucs_6s") | |
| use_mdx = request.form.get("use_mdx", "true").lower() == "true" | |
| job_id = str(uuid.uuid4()) | |
| suffix = Path(file.filename).suffix | |
| input_path = UPLOAD_DIR / f"{job_id}_input{suffix}" | |
| file.save(str(input_path)) | |
| with JOBS_LOCK: | |
| JOBS[job_id] = { | |
| "status": "queued", "progress": [], | |
| "output_path": None, "error": None, | |
| "filename": file.filename, | |
| } | |
| t = threading.Thread( | |
| target=process_job, | |
| args=(job_id, input_path, inst_vol, model, use_mdx), | |
| daemon=True, | |
| ) | |
| t.start() | |
| return jsonify({"job_id": job_id}) | |
| # โโ ้ฒๆ SSE โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ | |
| def progress(job_id: str): | |
| def generate(): | |
| import time | |
| sent = 0 | |
| while True: | |
| with JOBS_LOCK: | |
| if job_id not in JOBS: | |
| yield 'data: {"error":"not found"}\n\n'; return | |
| job = JOBS[job_id] | |
| new_ents = job["progress"][sent:] | |
| sent += len(new_ents) | |
| status = job["status"] | |
| for e in new_ents: | |
| yield f"data: {json.dumps(e, ensure_ascii=False)}\n\n" | |
| if status in ("done", "error"): | |
| yield f"data: {json.dumps({'status': status})}\n\n"; return | |
| time.sleep(0.5) | |
| return Response(generate(), mimetype="text/event-stream", | |
| headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}) | |
| def status(job_id: str): | |
| with JOBS_LOCK: | |
| job = JOBS.get(job_id) | |
| if not job: | |
| return jsonify({"error": "not found"}), 404 | |
| return jsonify({"status": job["status"], "error": job.get("error")}) | |
| # โโ ใใฆใณใญใผใ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ | |
| def download(job_id: str): | |
| with JOBS_LOCK: | |
| job = JOBS.get(job_id) | |
| if not job or job["status"] != "done": | |
| return jsonify({"error": "ใใกใคใซๆชๆบๅ"}), 404 | |
| output_path = Path(job["output_path"]) | |
| if not output_path.exists(): | |
| return jsonify({"error": "ใใกใคใซใ่ฆใคใใใพใใ"}), 404 | |
| download_name = Path(job.get("filename", "audio")).stem + "_panned.wav" | |
| return send_file(str(output_path), as_attachment=True, | |
| download_name=download_name, mimetype="audio/wav") | |
| # โโ ่ตทๅ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ | |
| if __name__ == "__main__": | |
| mode = "FREE๏ผStripe ๆช่จญๅฎ๏ผ" if FREE_MODE else "ๆๆ๏ผStripe ๆๅน๏ผ" | |
| print("=" * 52) | |
| print(f"๐ต HarmoSplit ่ตทๅไธญ... ใขใผใ: {mode}") | |
| print(f" http://localhost:{PORT}") | |
| print("=" * 52) | |
| flask_app.run(host="0.0.0.0", port=PORT, debug=False, threaded=True) | |