# fusionAI / app.py
# (source: Hugging Face Space "fusionAI", commit 1ceddcb "Update app.py" by gere)
import os
# Cap native BLAS/OpenMP thread pools BEFORE numpy/torch are imported so a
# shared CPU host isn't oversubscribed during separation.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
import io
import uuid
import shutil
import numpy as np
import librosa
import soundfile as sf
import pyloudnorm as pyln
import torch
from flask import Flask, request, send_file, jsonify, make_response
from flask_cors import CORS
from scipy.signal import butter, lfilter
from pedalboard import Pedalboard, Compressor, Limiter, HighpassFilter, LowpassFilter, Gain
import subprocess
from pydub import AudioSegment
from concurrent.futures import ThreadPoolExecutor
app = Flask(__name__)
CORS(app)  # allow browser clients on any origin to call the API
has_gpu = torch.cuda.is_available()
device_type = "cuda" if has_gpu else "cpu"  # passed to demucs --device
sr = 44100  # working sample rate (Hz) for all loading, mixing, and export
# use the heavy transformer for gpu, but a lite model for cpu to stop the 7-minute wait
model_name = "htdemucs_ft" if has_gpu else "htdemucs"
# NOTE(review): looks like a pyloudnorm LUFS target, but it is unused in the
# visible code (as are pyln/butter/lfilter and several pedalboard imports) —
# confirm before removing.
target_loudness = -9.0
def load_mono(file_path, duration=None):
    """Load an audio file as a mono float32 waveform at the module rate `sr`.

    Parameters
    ----------
    file_path : str
        Path to the audio file on disk.
    duration : float or None
        Optional cap, in seconds, on how much audio to decode.

    Returns
    -------
    np.ndarray
        1-D float32 signal. When the file is missing, 5 seconds of silence
        is returned instead so downstream mixing never crashes on a stem
        that failed to materialize.
    """
    if not os.path.exists(file_path):
        # Fix: explicit float32 — np.zeros defaults to float64, which did not
        # match librosa's float32 output and silently upcast later mixes.
        return np.zeros(sr * 5, dtype=np.float32)
    y, _ = librosa.load(file_path, sr=sr, mono=True, duration=duration)
    return y
def separate_stems(input_file, job_id):
    """Run the Demucs CLI on *input_file* and return paths to its stems.

    On CPU hosts the input is first trimmed to 30 s via pydub so the lite
    model finishes in a reasonable time; on GPU the full file is processed.

    Parameters
    ----------
    input_file : str
        Path to the uploaded audio file.
    job_id : str
        Unique token used to namespace output and temp files.

    Returns
    -------
    (dict, str)
        Mapping of stem name -> wav path (drums/bass/other; the vocals stem
        demucs also writes is deliberately not returned — the fusion mix
        never uses it) and the output directory the caller must clean up.

    Raises
    ------
    subprocess.CalledProcessError
        If the demucs CLI exits non-zero (check=True).
    """
    out_dir = f"sep_{job_id}"
    process_file = input_file
    if not has_gpu:
        process_file = f"trim_{job_id}_{os.path.basename(input_file)}"
        audio = AudioSegment.from_file(input_file)
        audio[:30000].export(process_file, format="wav")
    cmd = [
        "demucs", "-n", model_name,
        "--out", out_dir,
        "--device", device_type,
        process_file
    ]
    if not has_gpu:
        # segment 7 is required for htdemucs safety
        # shifts 0 is the fastest possible processing mode
        cmd.extend(["-j", "1", "--segment", "7", "--shifts", "0"])
    try:
        subprocess.run(cmd, check=True)
    finally:
        # Fix: the original removed the trimmed copy only on the success
        # path, leaking a temp wav whenever demucs raised. Always delete it.
        if process_file != input_file and os.path.exists(process_file):
            os.remove(process_file)
    base = os.path.splitext(os.path.basename(process_file))[0]
    stem_dir = os.path.join(out_dir, model_name, base)
    return {
        "drums": os.path.join(stem_dir, "drums.wav"),
        "bass": os.path.join(stem_dir, "bass.wav"),
        "other": os.path.join(stem_dir, "other.wav")
    }, out_dir
@app.route('/', methods=['GET'])
def health():
    """Liveness probe: report hardware tier, active model, and CPU time cap."""
    payload = {
        "status": "ready",
        "hardware": device_type,
        "model": model_name,
    }
    payload["cpu_limit"] = "none" if has_gpu else "30s"
    return jsonify(payload), 200
@app.route('/fuse', methods=['POST'])
def fuse_api():
    """Fuse two uploads: the melody track's harmonic stem over the style
    track's drums and bass.

    Expects multipart form fields 'melody' and 'style'. Responds with a
    mastered WAV attachment on success, 400 when a file is missing, and
    500 (JSON error message) on any processing failure. All temp files and
    separation directories are removed in `finally`.
    """
    job_id = uuid.uuid4().hex[:8]
    temp_files, cleanup_dirs = [], []
    try:
        t_req = request.files.get('melody')
        m_req = request.files.get('style')
        if not t_req or not m_req:
            return jsonify({"error": "missing files"}), 400
        t_path, m_path = f"t_{job_id}.wav", f"m_{job_id}.wav"
        t_req.save(t_path)
        m_req.save(m_path)
        temp_files.extend([t_path, m_path])
        if has_gpu:
            # Two demucs runs overlap on the GPU; on CPU the thread limits
            # make sequential runs faster than contended parallel ones.
            with ThreadPoolExecutor(max_workers=2) as executor:
                f_t = executor.submit(separate_stems, t_path, f"t_{job_id}")
                f_m = executor.submit(separate_stems, m_path, f"m_{job_id}")
                t_stems, t_dir = f_t.result()
                m_stems, m_dir = f_m.result()
        else:
            t_stems, t_dir = separate_stems(t_path, f"t_{job_id}")
            m_stems, m_dir = separate_stems(m_path, f"m_{job_id}")
        cleanup_dirs.extend([t_dir, m_dir])
        limit = 30.0 if not has_gpu else None
        t_other = load_mono(t_stems["other"], duration=limit)
        m_drums = load_mono(m_stems["drums"], duration=limit)
        m_bass = load_mono(m_stems["bass"], duration=limit)
        # Fix: include m_bass in the length reduction. The original min()
        # ignored it, so a bass stem shorter than the other two produced a
        # shorter slice and the weighted sum below raised a numpy
        # broadcasting error.
        target_len = min(len(t_other), len(m_drums), len(m_bass))
        fusion = (1.0 * m_drums[:target_len] + 1.0 * m_bass[:target_len] + 1.2 * t_other[:target_len])
        # Peak-normalize so the compressor sees a predictable input level.
        max_val = np.max(np.abs(fusion))
        if max_val > 0:
            fusion = fusion / max_val
        board = Pedalboard([Compressor(threshold_db=-20), Limiter()])
        fusion_mastered = board(fusion, sr)
        buf = io.BytesIO()
        sf.write(buf, fusion_mastered, sr, format='WAV')
        buf.seek(0)
        return send_file(buf, mimetype="audio/wav", as_attachment=True, download_name="fusion.wav")
    except Exception as e:
        # Top-level boundary: surface the failure to the client as JSON.
        return jsonify({"error": str(e)}), 500
    finally:
        for f in temp_files:
            if os.path.exists(f): os.remove(f)
        for d in cleanup_dirs:
            if os.path.exists(d): shutil.rmtree(d, ignore_errors=True)
if __name__ == "__main__":
    # Bind all interfaces on port 7860 (the Hugging Face Spaces default).
    app.run(host='0.0.0.0', port=7860)