| import os |
# Pin OpenMP and MKL to a single thread each. This runs before numpy, librosa
# and torch are imported below, presumably so their native thread pools pick
# the setting up at initialisation.
os.environ.update(OMP_NUM_THREADS="1", MKL_NUM_THREADS="1")
|
|
| import io |
| import uuid |
| import shutil |
| import numpy as np |
| import librosa |
| import soundfile as sf |
| import pyloudnorm as pyln |
| import torch |
| from flask import Flask, request, send_file, jsonify, make_response |
| from flask_cors import CORS |
| from scipy.signal import butter, lfilter |
| from pedalboard import Pedalboard, Compressor, Limiter, HighpassFilter, LowpassFilter, Gain |
| import subprocess |
| from pydub import AudioSegment |
| from concurrent.futures import ThreadPoolExecutor |
|
|
# Hardware detection: the Demucs device and the model variant both follow
# whether CUDA is available.
has_gpu = torch.cuda.is_available()
if has_gpu:
    device_type, model_name = "cuda", "htdemucs_ft"
else:
    device_type, model_name = "cpu", "htdemucs"

# Sample rate (Hz) used when loading stems, running the mastering chain and
# writing the output WAV.
sr = 44100
# NOTE(review): not referenced anywhere in the visible code (pyloudnorm is
# imported but unused too) -- confirm whether loudness normalisation was intended.
target_loudness = -9.0

# Flask application object; the routes below register on it. CORS is enabled
# with flask-cors default options for every route.
app = Flask(__name__)
CORS(app)
|
|
def load_mono(file_path, duration=None):
    """Load ``file_path`` as a mono signal resampled to the module rate ``sr``.

    A missing file does not raise. Five seconds of silence
    (``np.zeros(sr * 5)``) is returned instead, so an absent stem degrades to
    silence in the mix.

    Args:
        file_path: Path of the audio file to read.
        duration: Optional maximum number of seconds to read (forwarded to
            ``librosa.load``); ``None`` reads the whole file.

    Returns:
        A 1-D numpy array of samples.
    """
    if os.path.exists(file_path):
        samples, _rate = librosa.load(file_path, sr=sr, mono=True, duration=duration)
        return samples
    return np.zeros(sr * 5)
|
|
def separate_stems(input_file, job_id):
    """Run the ``demucs`` CLI on ``input_file`` and return the stem paths used here.

    On CPU-only hosts (``has_gpu`` false) the input is first cut to its first
    30 seconds and written to a temporary WAV. Demucs is then run with lighter
    settings (``-j 1 --segment 7 --shifts 0``).

    Args:
        input_file: Path of the audio file to separate. It is never deleted here.
        job_id: Unique tag used to name the output directory (``sep_<job_id>``)
            and the temporary trimmed file.

    Returns:
        A ``(stems, out_dir)`` tuple:
            - ``stems`` maps ``"drums"``, ``"bass"`` and ``"other"`` to the WAV
              paths this function expects Demucs to have written. Their
              existence is not checked.
            - ``out_dir`` is the directory the caller should delete when done.

    Raises:
        subprocess.CalledProcessError: If ``demucs`` exits with a non-zero status.
        Any exception from pydub while decoding or trimming the input.
        On any failure the partial ``out_dir`` and the trimmed copy are removed
        before the exception propagates, because the caller never receives
        ``out_dir`` in that case.
    """
    out_dir = f"sep_{job_id}"
    process_file = input_file

    try:
        if not has_gpu:
            # CPU separation is slow, so only the first 30 s is processed
            # (pydub slices by milliseconds).
            process_file = f"trim_{job_id}_{os.path.basename(input_file)}"
            audio = AudioSegment.from_file(input_file)
            audio[:30000].export(process_file, format="wav")

        cmd = [
            "demucs", "-n", model_name,
            "--out", out_dir,
            "--device", device_type,
            process_file,
        ]
        if not has_gpu:
            # Lighter settings to bound CPU time and memory.
            cmd.extend(["-j", "1", "--segment", "7", "--shifts", "0"])

        subprocess.run(cmd, check=True)
    except BaseException:
        # The caller cannot clean up a directory it was never told about.
        shutil.rmtree(out_dir, ignore_errors=True)
        raise
    finally:
        # Remove the trimmed copy whether or not Demucs succeeded. Never touch
        # the caller's own input file.
        if process_file != input_file and os.path.exists(process_file):
            os.remove(process_file)

    # Expected output layout: <out_dir>/<model_name>/<track name>/<stem>.wav
    base = os.path.splitext(os.path.basename(process_file))[0]
    stem_dir = os.path.join(out_dir, model_name, base)
    stems = {name: os.path.join(stem_dir, f"{name}.wav") for name in ("drums", "bass", "other")}
    return stems, out_dir
|
|
@app.route('/', methods=['GET'])
def health():
    """Report readiness together with the device, Demucs model and CPU trim limit in use."""
    payload = {
        "status": "ready",
        "hardware": device_type,
        "model": model_name,
        "cpu_limit": "none" if has_gpu else "30s",
    }
    return jsonify(payload), 200
|
|
def _separate_both(t_path, m_path, job_id, cleanup_dirs):
    """Separate both uploads, appending each output dir to ``cleanup_dirs`` as soon as it exists.

    Returns ``(t_stems, m_stems)``. Re-raises the first separation failure,
    but only after every successfully produced directory has been registered
    for cleanup.
    """
    if has_gpu:
        with ThreadPoolExecutor(max_workers=2) as executor:
            f_t = executor.submit(separate_stems, t_path, f"t_{job_id}")
            f_m = executor.submit(separate_stems, m_path, f"m_{job_id}")
        # Leaving the with-block waits for both jobs. Register the successful
        # one(s) first so a failure in the other does not leak its directory.
        for fut in (f_t, f_m):
            if fut.exception() is None:
                cleanup_dirs.append(fut.result()[1])
        t_stems, _ = f_t.result()
        m_stems, _ = f_m.result()
    else:
        t_stems, t_dir = separate_stems(t_path, f"t_{job_id}")
        cleanup_dirs.append(t_dir)
        m_stems, m_dir = separate_stems(m_path, f"m_{job_id}")
        cleanup_dirs.append(m_dir)
    return t_stems, m_stems


def _mix_stems(t_other, m_drums, m_bass):
    """Sum style drums + bass with melody 'other' (gain 1.2), trimmed to the shortest input and peak-normalised.

    Raises:
        ValueError: If any input is empty, so there is nothing to mix.
    """
    # All three inputs are sliced, so all three must bound the length.
    target_len = min(len(t_other), len(m_drums), len(m_bass))
    if target_len == 0:
        raise ValueError("separated stems are empty")
    fusion = 1.0 * m_drums[:target_len] + 1.0 * m_bass[:target_len] + 1.2 * t_other[:target_len]

    max_val = np.max(np.abs(fusion))
    if max_val > 0:
        fusion = fusion / max_val
    return fusion


@app.route('/fuse', methods=['POST'])
def fuse_api():
    """Fuse the 'other' stem of the ``melody`` upload with drums + bass of the ``style`` upload.

    Expects multipart form files ``melody`` and ``style``. On CPU-only hosts
    only the first 30 seconds of each is used.

    Returns:
        On success, a WAV attachment named ``fusion.wav``. The mix is passed
        through a Compressor(threshold -20 dB) and a Limiter.
        ``400 {"error": "missing files"}`` if either upload is absent.
        ``500 {"error": <message>}`` on any other failure (also logged).

    All temporary uploads and Demucs output directories are removed in every
    case.

    NOTE(review): no upload size limit (e.g. ``MAX_CONTENT_LENGTH``) is
    configured in the visible code.
    """
    job_id = uuid.uuid4().hex[:8]
    temp_files, cleanup_dirs = [], []
    try:
        t_req = request.files.get('melody')
        m_req = request.files.get('style')
        if not t_req or not m_req:
            return jsonify({"error": "missing files"}), 400

        t_path, m_path = f"t_{job_id}.wav", f"m_{job_id}.wav"
        # Register before saving so that a failed second save still removes the first file.
        temp_files.extend([t_path, m_path])
        t_req.save(t_path)
        m_req.save(m_path)

        t_stems, m_stems = _separate_both(t_path, m_path, job_id, cleanup_dirs)

        # Matches the 30 s trim that separate_stems applies on CPU.
        limit = None if has_gpu else 30.0
        t_other = load_mono(t_stems["other"], duration=limit)
        m_drums = load_mono(m_stems["drums"], duration=limit)
        m_bass = load_mono(m_stems["bass"], duration=limit)

        fusion = _mix_stems(t_other, m_drums, m_bass)

        board = Pedalboard([Compressor(threshold_db=-20), Limiter()])
        fusion_mastered = board(fusion, sr)

        buf = io.BytesIO()
        sf.write(buf, fusion_mastered, sr, format='WAV')
        buf.seek(0)
        return send_file(buf, mimetype="audio/wav", as_attachment=True, download_name="fusion.wav")

    except Exception as e:
        # Keep the traceback server-side; the client only gets the message.
        app.logger.exception("fuse job %s failed", job_id)
        return jsonify({"error": str(e)}), 500
    finally:
        for path in temp_files:
            if os.path.exists(path):
                os.remove(path)
        for d in cleanup_dirs:
            if os.path.exists(d):
                shutil.rmtree(d, ignore_errors=True)
|
|
if __name__ == "__main__":
    # Start Flask's built-in server on all interfaces, port 7860.
    run_options = {"host": "0.0.0.0", "port": 7860}
    app.run(**run_options)