# neural-mathrock / app.py
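# Gradio Space for multimodal math-rock analysis: a WavLM + CLAP fusion model
# scores the audio, optional lyrics are scored by two XLM-RoBERTa classifiers,
# and the two modalities are late-fused before rule-based DSP calibration.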
import torch
import torch.nn as nn
import torch.nn.functional as F
import librosa
import numpy as np
import gradio as gr
from transformers import (
XLMRobertaForSequenceClassification,
XLMRobertaTokenizer,
WavLMModel,
ClapAudioModel,
ClapProcessor
)
from huggingface_hub import hf_hub_download
import warnings
warnings.filterwarnings('ignore')
# -- CONFIGURATION --
REPO_AUDIO = "anggars/neural-mathrock"
REPO_TEXT_MBTI = "anggars/xlm-mbti"
REPO_TEXT_EMO = "anggars/xlm-emotion"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# -- GLOBAL LABELS --
MBTI_LABELS = sorted(["INTJ", "INTP", "ENTJ", "ENTP", "INFJ", "INFP", "ENFJ", "ENFP", "ISTJ", "ISFJ", "ESTJ", "ESFJ", "ISTP", "ISFP", "ESTP", "ESFP"])
EMO_LABELS = sorted(['admiration', 'amusement', 'anger', 'annoyance', 'approval', 'caring', 'confusion', 'curiosity', 'desire', 'disappointment', 'disapproval', 'disgust', 'embarrassment', 'excitement', 'fear', 'gratitude', 'grief', 'joy', 'love', 'nervousness', 'optimism', 'pride', 'realization', 'relief', 'remorse', 'sadness', 'surprise', 'neutral'])
VIBE_LABELS = sorted(["Aggressive", "Atmospheric", "Melancholic", "Technical"])
INTENSITY_LABELS = sorted(["High", "Low", "Medium"])
N_WAVLM_LAYERS = 13  # 12 transformer layers of wavlm-base-plus + the embedding output
# -- AUDIO ARCHITECTURE --
class ScalarMix(nn.Module):
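    """ELMo-style scalar mix: learns a softmax-normalized weight per hidden
    layer, sums the stacked hidden states with those weights, and scales the
    result by a learned scalar gamma."""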
def __init__(self, n_layers: int):
super().__init__()
self.weights = nn.Parameter(torch.zeros(n_layers))
self.gamma = nn.Parameter(torch.ones(1))
def forward(self, hidden_states_tuple) -> torch.Tensor:
w = F.softmax(self.weights, dim=0)
stack = torch.stack(hidden_states_tuple, dim=0)
mixed = (w.view(-1, 1, 1, 1) * stack).sum(dim=0)
return mixed * self.gamma
class AttentionPooling(nn.Module):
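    """Attention pooling over time: a small MLP scores each frame, padded
    frames (if a mask is given) are set to -inf, and the frames are combined
    by a softmax-weighted sum."""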
def __init__(self, hidden_size: int):
super().__init__()
self.attn = nn.Sequential(
nn.Linear(hidden_size, hidden_size // 2),
nn.Tanh(),
nn.Linear(hidden_size // 2, 1),
)
def forward(self, x: torch.Tensor, mask=None) -> torch.Tensor:
scores = self.attn(x)
if mask is not None:
scores = scores.masked_fill(~mask.unsqueeze(-1).bool(), float("-inf"))
w = F.softmax(scores, dim=1)
return (x * w).sum(dim=1)
class TaskHead(nn.Module):
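    """Two-layer task head: Linear -> LayerNorm -> GELU -> Dropout -> Linear,
    with a bottleneck at half the input width."""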
def __init__(self, in_dim: int, out_dim: int, dropout: float = 0.3):
super().__init__()
mid = in_dim // 2
self.net = nn.Sequential(
nn.Linear(in_dim, mid), nn.LayerNorm(mid), nn.GELU(),
nn.Dropout(dropout), nn.Linear(mid, out_dim),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.net(x)
class UnifiedAudioModel(nn.Module):
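    """Dual-encoder audio model: WavLM embeddings (scalar-mixed across layers,
    attention-pooled over time) are concatenated with the CLAP pooled output,
    passed through a fusion MLP, and routed to five task heads: MBTI, emotion,
    vibe, intensity, and tempo regression."""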
def __init__(self):
super().__init__()
self.wavlm = WavLMModel.from_pretrained("microsoft/wavlm-base-plus", output_hidden_states=True)
self.scalar_mix = ScalarMix(N_WAVLM_LAYERS)
self.wavlm_pool = AttentionPooling(768)
self.clap = ClapAudioModel.from_pretrained("laion/clap-htsat-fused")
self.fusion = nn.Sequential(
nn.Linear(768 + 768, 768), nn.LayerNorm(768), nn.GELU(), nn.Dropout(0.3),
nn.Linear(768, 512), nn.LayerNorm(512), nn.GELU(), nn.Dropout(0.2),
)
self.mbti_head = TaskHead(512, len(MBTI_LABELS), dropout=0.3)
self.emo_head = TaskHead(512, len(EMO_LABELS), dropout=0.3)
self.vibe_head = TaskHead(512, len(VIBE_LABELS), dropout=0.2)
self.int_head = TaskHead(512, len(INTENSITY_LABELS), dropout=0.2)
self.tempo_head = nn.Sequential(nn.Linear(512, 64), nn.GELU(), nn.Linear(64, 1))
def forward(self, wav_16k: torch.Tensor, clap_feat: torch.Tensor, is_longer: torch.Tensor) -> dict:
wavlm_out = self.wavlm(wav_16k, output_hidden_states=True)
w_feat = self.wavlm_pool(self.scalar_mix(wavlm_out.hidden_states))
c_feat = self.clap(clap_feat, is_longer=is_longer).pooler_output
fused = self.fusion(torch.cat([w_feat, c_feat], dim=-1))
return {
"mbti": self.mbti_head(fused),
"emo": self.emo_head(fused),
"vibe": self.vibe_head(fused),
"int": self.int_head(fused),
"tmp": self.tempo_head(fused).squeeze(-1),
}
# -- MODEL INITIALIZATION --
print("Fetching Model Weights...")
ckpt_path = hf_hub_download(repo_id=REPO_AUDIO, filename="model.pt")
ckpt = torch.load(ckpt_path, map_location=DEVICE, weights_only=False)
audio_model = UnifiedAudioModel().to(DEVICE)
audio_model.load_state_dict(ckpt['model_state'], strict=False)
audio_model.eval()
clap_processor = ClapProcessor.from_pretrained("laion/clap-htsat-fused")
tokenizer = XLMRobertaTokenizer.from_pretrained(REPO_TEXT_MBTI)
text_mbti_model = XLMRobertaForSequenceClassification.from_pretrained(REPO_TEXT_MBTI).to(DEVICE).eval()
text_emo_model = XLMRobertaForSequenceClassification.from_pretrained(REPO_TEXT_EMO).to(DEVICE).eval()
# -- DSP TIMBRE & PLAYING STYLE PROFILER (HPSS) --
def extract_riff_nuance(y, sr):
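    """Rule-based DSP descriptors: loudness (RMS), tempo (BPM), brightness
    (spectral centroid), a rough distortion proxy (zero-crossing rate), and
    the percussive-to-harmonic energy ratio from HPSS."""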
rms = np.mean(librosa.feature.rms(y=y))
tempo, _ = librosa.beat.beat_track(y=y, sr=sr)
bpm = tempo[0] if isinstance(tempo, np.ndarray) else tempo
    # Brightness (spectral centroid)
    centroid = np.mean(librosa.feature.spectral_centroid(y=y, sr=sr))
    # Rough distortion proxy (zero-crossing rate)
    zcr = np.mean(librosa.feature.zero_crossing_rate(y))
    # Harmonic-percussive source separation (playing-style detection):
    # separates sustained, airy strumming (harmonic) from tapping/beat
    # transients (percussive).
    y_harmonic, y_percussive = librosa.effects.hpss(y)
    rms_h = np.mean(librosa.feature.rms(y=y_harmonic))
    rms_p = np.mean(librosa.feature.rms(y=y_percussive))
    # High ratio -> tapping/plucking (upbeat/technical)
    # Low ratio  -> strumming/wall of sound (sad/melancholic)
    perc_ratio = rms_p / (rms_h + 1e-6)
return rms, bpm, centroid, zcr, perc_ratio
# -- ANALYSIS ENGINE --
def analyze_track(audio_path, lyrics_input):
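    """Full pipeline for one track: DSP profiling, chunked neural inference,
    optional late fusion with lyrics, and rule-based probability calibration.
    Returns five values for the gr.Label outputs (the tempo output is derived
    from DSP BPM rules; the neural tempo head is not used here)."""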
if not audio_path:
return [{"Error": "No audio"}] * 5
try:
wav_16k, sr_16k = librosa.load(audio_path, sr=16000)
        # Extract physical/DSP nuance features
p_rms, p_bpm, p_centroid, p_zcr, p_perc_ratio = extract_riff_nuance(wav_16k, sr_16k)
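        # Split into 15-second chunks; drop trailing segments shorter than 1 second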
chunk_len_16k = 16000 * 15
chunks_16k = [wav_16k[i:i + chunk_len_16k] for i in range(0, len(wav_16k), chunk_len_16k) if len(wav_16k[i:i+chunk_len_16k]) >= 16000]
if not chunks_16k:
return [{"Error": "Audio segment too short"}] * 5
a_mbti, a_emo, a_vibe, a_int = [], [], [], []
with torch.no_grad():
for chunk in chunks_16k:
iv = torch.tensor(chunk, dtype=torch.float32).unsqueeze(0).to(DEVICE)
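                # CLAP expects 48 kHz audio, so resample each 16 kHz chunk before processing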
chunk_48k = librosa.resample(chunk, orig_sr=16000, target_sr=48000)
proc_out = clap_processor(audio=chunk_48k, return_tensors="pt", sampling_rate=48000)
c_feat = proc_out["input_features"].to(DEVICE)
is_longer = proc_out.get("is_longer", torch.tensor([False])).to(DEVICE)
out = audio_model(iv, c_feat, is_longer)
a_mbti.append(F.softmax(out["mbti"], dim=1))
a_emo.append(F.softmax(out["emo"], dim=1))
a_vibe.append(F.softmax(out["vibe"], dim=1))
a_int.append(F.softmax(out["int"], dim=1))
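        # Aggregate over chunks: mean probabilities for MBTI/emotion/vibe,
        # element-wise max for intensity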
avg_a_mbti = torch.stack(a_mbti).mean(dim=0)
avg_a_emo = torch.stack(a_emo).mean(dim=0)
avg_a_vibe = torch.stack(a_vibe).mean(dim=0)
max_a_int, _ = torch.max(torch.stack(a_int), dim=0)
        # Late fusion with lyric text: audio dominates MBTI (0.7), text dominates emotion (0.7)
has_lyrics = lyrics_input and len(str(lyrics_input).strip()) > 15
if has_lyrics:
t_in = tokenizer(str(lyrics_input), truncation=True, padding=True, max_length=256, return_tensors="pt").to(DEVICE)
with torch.no_grad():
t_mbti_probs = F.softmax(text_mbti_model(**t_in).logits, dim=1)
t_emo_probs = F.softmax(text_emo_model(**t_in).logits, dim=1)
final_mbti = (avg_a_mbti * 0.7) + (t_mbti_probs * 0.3)
final_emo = (avg_a_emo * 0.3) + (t_emo_probs * 0.7)
else:
final_mbti = avg_a_mbti
final_emo = avg_a_emo
        # Calibrate with pure riff-nuance cues (HPSS & spectral centroid)
def apply_riff_nuance(probs, labels, task_type):
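            """Rescale the model's probabilities with the DSP playing-style
            cues, renormalize, and return the top-3 labels."""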
p = probs.cpu().squeeze().numpy()
res = {labels[i]: float(p[i]) for i in range(len(labels))}
is_distorted = p_zcr > 0.12
            # Guitar playing-style detection
is_tappy_math = p_perc_ratio > 0.35 and p_centroid > 2200 and not is_distorted
is_strummy_emo = p_perc_ratio <= 0.35 and not is_distorted
if task_type == "emotion":
                # Suppress noisy outlier labels
for outlier in ['disgust', 'annoyance', 'fear', 'confusion']:
if outlier in res: res[outlier] *= 0.10
if is_tappy_math:
                    # Upbeat tapping riff (e.g. Gracias / Playing God)
for neg in ['grief', 'sadness', 'remorse', 'anger']:
if neg in res: res[neg] *= 0.15
for pos in ['joy', 'amusement', 'excitement', 'admiration']:
if pos in res: res[pos] *= 4.00
elif is_strummy_emo:
                    # Sad strumming riff (e.g. i wanna quit / Never Meant)
                    if 'anger' in res: res['anger'] *= 0.30
                    for pos in ['joy', 'amusement', 'excitement']:
                        if pos in res: res[pos] *= 0.15
                    # 'melancholy' is not a GoEmotions label, so only grief/sadness are boosted
                    for neg in ['grief', 'sadness']:
                        if neg in res: res[neg] *= 3.00
elif is_distorted:
                    # Harsh distorted riff (e.g. GOAT / heavy Shibuya)
if 'joy' in res: res['joy'] *= 0.10
if 'anger' in res: res['anger'] *= 3.00
if task_type == "vibe":
if is_tappy_math:
if 'Melancholic' in res: res['Melancholic'] *= 0.20
if 'Technical' in res: res['Technical'] *= 3.50
if 'Atmospheric' in res: res['Atmospheric'] *= 1.50
elif is_strummy_emo:
if 'Technical' in res: res['Technical'] *= 0.70
if 'Melancholic' in res: res['Melancholic'] *= 3.50
elif is_distorted:
if 'Melancholic' in res: res['Melancholic'] *= 0.30
if 'Aggressive' in res: res['Aggressive'] *= 3.50
if task_type == "intensity":
if p_rms < 0.08:
res['Low'] *= 3.0; res['High'] *= 0.1
elif p_rms > 0.22:
res['High'] *= 1.5; res['Low'] *= 0.1
else:
res['Medium'] *= 1.5
total = sum(res.values())
if total > 0:
res = {k: v / total for k, v in res.items()}
return dict(sorted(res.items(), key=lambda x: x[1], reverse=True)[:3])
        # Rule-based tempo formatting from the pure DSP BPM estimate
tempo_dict = {"Fast": 0.0, "Medium": 0.0, "Slow": 0.0}
if p_bpm >= 135:
tempo_dict["Fast"] = 0.90; tempo_dict["Medium"] = 0.10
elif 100 <= p_bpm < 135:
tempo_dict["Fast"] = max(0.0, (p_bpm - 100) / 35.0)
tempo_dict["Medium"] = 0.80
tempo_dict["Slow"] = max(0.0, (135 - p_bpm) / 35.0)
else:
tempo_dict["Slow"] = 0.90; tempo_dict["Medium"] = 0.10
t_total = sum(tempo_dict.values())
if t_total > 0: tempo_dict = {k: float(v / t_total) for k, v in tempo_dict.items()}
return [
apply_riff_nuance(final_mbti, MBTI_LABELS, "mbti"),
apply_riff_nuance(final_emo, EMO_LABELS, "emotion"),
apply_riff_nuance(avg_a_vibe, VIBE_LABELS, "vibe"),
apply_riff_nuance(max_a_int, INTENSITY_LABELS, "intensity"),
tempo_dict
]
except Exception as e:
return [{"Error": str(e)}] * 5
# -- INTERFACE --
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
gr.Markdown("# Neural Math Rock Multimodal Analysis")
gr.Markdown("Identify personality and emotional states from music audio and lyrics.")
with gr.Row():
with gr.Column():
audio_box = gr.Audio(type="filepath", label="Audio Source")
lyrics_box = gr.Textbox(lines=8, label="Lyrics Source", placeholder="Paste lyrics here for hybrid analysis...")
run_btn = gr.Button("RUN ANALYSIS", variant="primary")
with gr.Column():
res_mbti = gr.Label(label="Personality (MBTI)")
res_emo = gr.Label(label="Emotional State")
res_vibe = gr.Label(label="Acoustic Vibe")
res_int = gr.Label(label="Intensity Level")
res_tmp = gr.Label(label="Tempo Classification")
run_btn.click(
fn=analyze_track,
inputs=[audio_box, lyrics_box],
outputs=[res_mbti, res_emo, res_vibe, res_int, res_tmp]
)
if __name__ == "__main__":
demo.launch()