fusionmodel / app.py
gere's picture
Update app.py
34b4ff9 verified
import io
import os
import shutil
import subprocess
import uuid

import librosa
import numpy as np
import pyloudnorm as pyln
import soundfile as sf
from flask import Flask, request, send_file, jsonify
from flask_cors import CORS
from pedalboard import Pedalboard, Compressor, Limiter, HighpassFilter, LowpassFilter, Gain
from pydub import AudioSegment
from scipy.signal import butter, lfilter, sosfilt
# Audio settings shared by every processing stage.
SR = 44_100  # working sample rate (Hz) for loading, filtering, mastering and output
TARGET_LOUDNESS = -9.0  # integrated-loudness target passed to pyloudnorm (LUFS)

app = Flask(__name__)
CORS(app)  # allow cross-origin requests on every route, using flask_cors defaults
def convert_to_wav(input_path):
    """Return a path to a WAV rendition of ``input_path``.

    Only names ending in ``.mp3`` (case-insensitive) are transcoded: they are
    decoded with pydub and written next to the source as
    ``<name-without-extension>_<random hex>.wav``, and that new path is
    returned (the caller is responsible for deleting it). Every other path is
    returned unchanged.
    """
    if not input_path.lower().endswith(".mp3"):
        return input_path
    stem = input_path.rsplit(".", 1)[0]
    wav_path = f"{stem}_{uuid.uuid4().hex}.wav"
    AudioSegment.from_mp3(input_path).export(wav_path, format="wav")
    return wav_path
def load_mono(file):
    """Decode ``file`` into a mono sample array resampled to the module rate ``SR``."""
    # librosa.load returns (samples, sample_rate); only the samples are needed.
    return librosa.load(file, sr=SR, mono=True)[0]
def normalize_audio(y):
    """Peak-normalize ``y`` so that its largest absolute sample becomes ~1.0.

    A tiny epsilon in the denominator keeps all-zero input from dividing by zero
    (it is returned as zeros).
    """
    peak = np.max(np.abs(y))
    return y / (peak + 1e-9)
def _butter_filter(data, cutoff, btype, sr, order):
    """Apply a causal Butterworth filter of type ``btype`` at ``cutoff`` Hz to ``data``.

    The filter is designed as second-order sections (``output='sos'``): SciPy
    recommends this over the transfer-function (``b, a``) form, which can lose
    precision for very low normalized cut-offs. It is a single forward pass,
    so it adds phase delay, just like ``lfilter``.
    """
    sos = butter(order, cutoff, btype=btype, fs=sr, output='sos')
    return sosfilt(sos, data)


def highpass(data, cutoff, sr=None, order=4):
    """High-pass ``data`` with an ``order``-th order Butterworth filter at ``cutoff`` Hz.

    ``sr`` is the sample rate of ``data`` and defaults to the module rate ``SR``.
    ``cutoff`` must lie strictly between 0 and ``sr / 2``; otherwise
    ``scipy.signal.butter`` raises ``ValueError``.
    """
    return _butter_filter(data, cutoff, 'high', SR if sr is None else sr, order)


def lowpass(data, cutoff, sr=None, order=4):
    """Low-pass ``data`` with an ``order``-th order Butterworth filter at ``cutoff`` Hz.

    ``sr`` is the sample rate of ``data`` and defaults to the module rate ``SR``.
    ``cutoff`` must lie strictly between 0 and ``sr / 2``; otherwise
    ``scipy.signal.butter`` raises ``ValueError``.
    """
    return _butter_filter(data, cutoff, 'low', SR if sr is None else sr, order)
def detect_key(y):
    """Return the index of the pitch class with the most total chroma energy in ``y``.

    This is a dominant-pitch-class estimate in librosa's chroma ordering, not a
    full key detection: no major/minor decision is made. ``y`` is assumed to
    be sampled at ``SR``.
    """
    pitch_class_energy = librosa.feature.chroma_cqt(y=y, sr=SR).sum(axis=1)
    return np.argmax(pitch_class_energy)
def match_key(source, target):
    """Pitch-shift ``source`` so its dominant pitch class matches that of ``target``.

    Pitch classes come from ``detect_key`` (0-11). Pitch classes wrap around the
    octave, so the raw difference is folded into [-6, 5] semitones. The result
    is the smallest shift that reaches the same pitch class (e.g. 0 -> 11
    becomes -1 rather than +11), which keeps resynthesis artifacts from large
    shifts to a minimum. When the two already agree, ``source`` is returned
    unchanged instead of being resynthesized.
    """
    key_s = int(detect_key(source))
    key_t = int(detect_key(target))
    # Python's % is non-negative for a positive modulus, so this lands in [-6, 5].
    shift = (key_t - key_s + 6) % 12 - 6
    if shift == 0:
        return source
    return librosa.effects.pitch_shift(source, sr=SR, n_steps=float(shift))
def beat_sync_warp(source, target):
    """Time-stretch ``source`` to the tempo of ``target`` and fit it to ``target``'s length.

    Tempos are estimated with ``librosa.beat.beat_track``. Only the tempo is
    matched; beat positions (phase) are not aligned. ``librosa.effects.time_stretch``
    speeds audio up when ``rate > 1``, so bringing the source tempo to the
    target tempo needs ``rate = tempo_target / tempo_source`` (not the inverse).
    If either tempo estimate is not a positive number, no stretching happens.
    In every case the result is zero-padded or trimmed to ``len(target)``.
    """
    tempo_t, _ = librosa.beat.beat_track(y=target, sr=SR)
    tempo_s, _ = librosa.beat.beat_track(y=source, sr=SR)
    # Depending on the librosa version, tempo is a scalar or a 1-element array.
    tempo_t = float(np.atleast_1d(tempo_t)[0])
    tempo_s = float(np.atleast_1d(tempo_s)[0])
    # Written as "not (... > 0)" so NaN estimates also take the fallback path.
    if not (tempo_t > 0 and tempo_s > 0):
        return librosa.util.fix_length(source, size=len(target))
    rate = tempo_t / tempo_s
    warped = librosa.effects.time_stretch(source, rate=float(rate))
    return librosa.util.fix_length(warped, size=len(target))
def separate_stems(input_file, job_id):
    """Split ``input_file`` into stems by running the Demucs CLI with the ``htdemucs`` model.

    Output goes to ``separated_<job_id>``. Returns ``(stems, out_dir)``, where
    ``stems`` maps "drums", "bass", "other" and "vocals" to the WAV paths this
    function expects under ``<out_dir>/htdemucs/<input name without extension>/``.
    Raises ``subprocess.CalledProcessError`` if demucs exits with a non-zero status.
    The caller owns ``out_dir`` and should delete it.
    """
    out_dir = f"separated_{job_id}"
    command = ["demucs", "-n", "htdemucs", "--out", out_dir, input_file]
    subprocess.run(command, check=True)
    track_name, _ext = os.path.splitext(os.path.basename(input_file))
    stem_dir = f"{out_dir}/htdemucs/{track_name}"
    stems = {name: f"{stem_dir}/{name}.wav" for name in ("drums", "bass", "other", "vocals")}
    return stems, out_dir
@app.route('/', methods=['GET'])
def health():
    """Health check: report that the service is up and ready to accept jobs."""
    body = {"status": "ready"}
    return jsonify(body), 200
def _save_upload(upload, prefix, job_id, temp_files):
    """Save one uploaded file for this job and return a path to a WAV version of it.

    The saved name keeps the client's file extension, lower-cased, so that
    ``convert_to_wav`` can recognise MP3 uploads. The extension falls back to
    ``.wav`` when it is missing or is not a plain ASCII alphanumeric suffix.
    Each path is appended to ``temp_files`` *before* it is written, so the
    caller's cleanup removes it even if saving or conversion fails.
    """
    ext = os.path.splitext(upload.filename or "")[1].lower()
    suffix = ext[1:]
    if not (suffix.isascii() and suffix.isalnum()):
        ext = ".wav"
    saved_path = f"{prefix}_{job_id}{ext}"
    temp_files.append(saved_path)
    upload.save(saved_path)
    wav_path = convert_to_wav(saved_path)
    if wav_path != saved_path:
        temp_files.append(wav_path)
    return wav_path


def _fuse_stems(t_stems, m_stems):
    """Blend the traditional track's "other"/"bass" stems with the modern track's stems.

    Returns a peak-normalized mono mix at ``SR``.
    """
    t_other = load_mono(t_stems["other"])
    t_bass = load_mono(t_stems["bass"])
    m_drums = load_mono(m_stems["drums"])
    m_bass = load_mono(m_stems["bass"])
    m_other = load_mono(m_stems["other"])

    # Trim every stem to a common length so they can be summed sample-wise.
    target_len = min(len(t_other), len(t_bass), len(m_drums), len(m_bass), len(m_other))
    t_other = t_other[:target_len]
    t_bass = t_bass[:target_len]
    m_drums = m_drums[:target_len]
    m_bass = m_bass[:target_len]
    m_other = m_other[:target_len]

    # Move the traditional parts to the modern track's pitch class and tempo.
    t_other = match_key(t_other, m_other)
    t_bass = match_key(t_bass, m_bass)
    t_other = beat_sync_warp(t_other, m_drums)
    t_bass = beat_sync_warp(t_bass, m_drums)

    # Band-limit each stem before summing.
    t_other = highpass(t_other, 120)
    t_bass = highpass(t_bass, 60)
    m_bass = lowpass(m_bass, 250)
    m_drums = lowpass(m_drums, 12000)
    m_other = highpass(m_other, 150)

    fusion = (1.0 * m_drums + 1.0 * m_bass + 1.2 * t_other + 0.5 * m_other + 0.8 * t_bass)
    return normalize_audio(fusion)


def _master(mix):
    """Master ``mix`` (sampled at ``SR``): bus processing, loudness normalization, peak safety."""
    board = Pedalboard([HighpassFilter(30), LowpassFilter(18000), Compressor(threshold_db=-20, ratio=2), Gain(2.0), Limiter(threshold_db=-0.5)])
    mastered = board(mix, SR)
    meter = pyln.Meter(SR)
    loudness = meter.integrated_loudness(mastered)
    # Loudness can be non-finite (e.g. -inf for silence); normalizing to it
    # would fill the output with inf/nan, so leave the level alone in that case.
    if np.isfinite(loudness):
        mastered = pyln.normalize.loudness(mastered, loudness, TARGET_LOUDNESS)
    # The loudness gain is applied after the limiter above and can push peaks
    # past full scale, so limit again and hard-clip before PCM encoding.
    mastered = Pedalboard([Limiter(threshold_db=-0.5)])(mastered, SR)
    return np.clip(mastered, -1.0, 1.0)


@app.route('/fuse', methods=['POST'])
def fuse_api():
    """Fuse a traditional track with a modern one and return the result as a WAV download.

    Expects two multipart uploads:
    - ``melody``: the traditional source, which contributes its "other" and "bass" stems.
    - ``style``: the modern source, which contributes drums, bass and "other".

    Responds with ``fusion_output.wav`` as an attachment. Returns ``400``
    ``{"error": "missing files"}`` when an upload is absent, and ``500`` with
    the exception message if any step fails. Uploaded/converted files and
    Demucs output directories are deleted before returning.
    """
    job_id = str(uuid.uuid4())
    temp_files = []
    cleanup_dirs = []
    try:
        trad_req = request.files.get('melody')
        modern_req = request.files.get('style')
        if not trad_req or not modern_req:
            return jsonify({"error": "missing files"}), 400

        t_wav = _save_upload(trad_req, "trad", job_id, temp_files)
        m_wav = _save_upload(modern_req, "mod", job_id, temp_files)

        # Register each Demucs output directory as soon as it exists, so a
        # failure in the second separation still removes the first one.
        t_stems, t_dir = separate_stems(t_wav, f"t_{job_id}")
        cleanup_dirs.append(t_dir)
        m_stems, m_dir = separate_stems(m_wav, f"m_{job_id}")
        cleanup_dirs.append(m_dir)

        fusion_mastered = _master(_fuse_stems(t_stems, m_stems))

        buf = io.BytesIO()
        sf.write(buf, fusion_mastered, SR, format='WAV')
        buf.seek(0)
        return send_file(buf, mimetype="audio/wav", as_attachment=True, download_name="fusion_output.wav")
    except Exception as e:
        # Top-level request boundary: keep the traceback in the server log.
        app.logger.exception("fusion job %s failed", job_id)
        return jsonify({"error": str(e)}), 500
    finally:
        for f in temp_files:
            if os.path.exists(f):
                os.remove(f)
        for d in cleanup_dirs:
            if os.path.exists(d):
                shutil.rmtree(d, ignore_errors=True)
if __name__ == "__main__":
    # Development server bound to all interfaces; port 7860 is presumably the
    # port the hosting environment expects. TODO confirm.
    app.run(port=7860, host="0.0.0.0")