import torch
import torch.nn as nn
import torch.nn.functional as F
import librosa
import numpy as np
import gradio as gr
from transformers import (
    XLMRobertaForSequenceClassification,
    XLMRobertaTokenizer,
    WavLMModel,
    ClapAudioModel,
    ClapProcessor
)
from huggingface_hub import hf_hub_download
import warnings

warnings.filterwarnings('ignore')
# -- CONFIGURATION --
REPO_AUDIO = "anggars/neural-mathrock"
REPO_TEXT_MBTI = "anggars/xlm-mbti"
REPO_TEXT_EMO = "anggars/xlm-emotion"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -- GLOBAL LABELS --
MBTI_LABELS = sorted(["INTJ", "INTP", "ENTJ", "ENTP", "INFJ", "INFP", "ENFJ", "ENFP", "ISTJ", "ISFJ", "ESTJ", "ESFJ", "ISTP", "ISFP", "ESTP", "ESFP"])
EMO_LABELS = sorted(['admiration', 'amusement', 'anger', 'annoyance', 'approval', 'caring', 'confusion', 'curiosity', 'desire', 'disappointment', 'disapproval', 'disgust', 'embarrassment', 'excitement', 'fear', 'gratitude', 'grief', 'joy', 'love', 'nervousness', 'optimism', 'pride', 'realization', 'relief', 'remorse', 'sadness', 'surprise', 'neutral'])
VIBE_LABELS = sorted(["Aggressive", "Atmospheric", "Melancholic", "Technical"])
INTENSITY_LABELS = sorted(["High", "Low", "Medium"])
N_WAVLM_LAYERS = 13  # 12 transformer layers of wavlm-base-plus plus the embedding output
# -- AUDIO ARCHITECTURE --
class ScalarMix(nn.Module):
    def __init__(self, n_layers: int):
        super().__init__()
        self.weights = nn.Parameter(torch.zeros(n_layers))
        self.gamma = nn.Parameter(torch.ones(1))

    def forward(self, hidden_states_tuple) -> torch.Tensor:
        # Softmax-normalized per-layer weights, then a weighted sum over layers
        w = F.softmax(self.weights, dim=0)
        stack = torch.stack(hidden_states_tuple, dim=0)
        mixed = (w.view(-1, 1, 1, 1) * stack).sum(dim=0)
        return mixed * self.gamma
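
# Minimal shape sketch for ScalarMix (hypothetical tensors, not part of the app flow):
# each WavLM hidden state is (batch, frames, 768); the mix collapses the layer
# axis while keeping the time axis intact.
#   mix = ScalarMix(N_WAVLM_LAYERS)
#   hs = tuple(torch.randn(2, 50, 768) for _ in range(N_WAVLM_LAYERS))
#   mix(hs).shape  # -> torch.Size([2, 50, 768])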

class AttentionPooling(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.attn = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.Tanh(),
            nn.Linear(hidden_size // 2, 1),
        )

    def forward(self, x: torch.Tensor, mask=None) -> torch.Tensor:
        scores = self.attn(x)
        if mask is not None:
            # Masked frames get -inf scores so softmax assigns them zero weight
            scores = scores.masked_fill(~mask.unsqueeze(-1).bool(), float("-inf"))
        w = F.softmax(scores, dim=1)
        return (x * w).sum(dim=1)
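
# Sketch of the pooling contract (hypothetical shapes): a small MLP scores each
# frame and the softmax-weighted sum reduces (batch, frames, 768) to (batch, 768).
#   pool = AttentionPooling(768)
#   pool(torch.randn(2, 50, 768)).shape  # -> torch.Size([2, 768])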

class TaskHead(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, dropout: float = 0.3):
        super().__init__()
        mid = in_dim // 2
        self.net = nn.Sequential(
            nn.Linear(in_dim, mid), nn.LayerNorm(mid), nn.GELU(),
            nn.Dropout(dropout), nn.Linear(mid, out_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

class UnifiedAudioModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.wavlm = WavLMModel.from_pretrained("microsoft/wavlm-base-plus", output_hidden_states=True)
        self.scalar_mix = ScalarMix(N_WAVLM_LAYERS)
        self.wavlm_pool = AttentionPooling(768)
        self.clap = ClapAudioModel.from_pretrained("laion/clap-htsat-fused")
        self.fusion = nn.Sequential(
            nn.Linear(768 + 768, 768), nn.LayerNorm(768), nn.GELU(), nn.Dropout(0.3),
            nn.Linear(768, 512), nn.LayerNorm(512), nn.GELU(), nn.Dropout(0.2),
        )
        self.mbti_head = TaskHead(512, len(MBTI_LABELS), dropout=0.3)
        self.emo_head = TaskHead(512, len(EMO_LABELS), dropout=0.3)
        self.vibe_head = TaskHead(512, len(VIBE_LABELS), dropout=0.2)
        self.int_head = TaskHead(512, len(INTENSITY_LABELS), dropout=0.2)
        self.tempo_head = nn.Sequential(nn.Linear(512, 64), nn.GELU(), nn.Linear(64, 1))

    def forward(self, wav_16k: torch.Tensor, clap_feat: torch.Tensor, is_longer: torch.Tensor) -> dict:
        wavlm_out = self.wavlm(wav_16k, output_hidden_states=True)
        w_feat = self.wavlm_pool(self.scalar_mix(wavlm_out.hidden_states))
        c_feat = self.clap(clap_feat, is_longer=is_longer).pooler_output
        fused = self.fusion(torch.cat([w_feat, c_feat], dim=-1))
        return {
            "mbti": self.mbti_head(fused),
            "emo": self.emo_head(fused),
            "vibe": self.vibe_head(fused),
            "int": self.int_head(fused),
            "tmp": self.tempo_head(fused).squeeze(-1),
        }
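
# Fusion walk-through (dimensions read off the layers above): WavLM yields a
# pooled 768-d vector, CLAP's audio tower yields a 768-d pooler_output; the
# concatenated 1536-d vector passes through 1536 -> 768 -> 512 before fanning
# out to the five task heads.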

# -- MODEL INITIALIZATION --
print("Fetching Model Weights...")
ckpt_path = hf_hub_download(repo_id=REPO_AUDIO, filename="model.pt")
ckpt = torch.load(ckpt_path, map_location=DEVICE, weights_only=False)
audio_model = UnifiedAudioModel().to(DEVICE)
# strict=False tolerates keys that differ between the checkpoint and this architecture
audio_model.load_state_dict(ckpt['model_state'], strict=False)
audio_model.eval()

clap_processor = ClapProcessor.from_pretrained("laion/clap-htsat-fused")
tokenizer = XLMRobertaTokenizer.from_pretrained(REPO_TEXT_MBTI)
text_mbti_model = XLMRobertaForSequenceClassification.from_pretrained(REPO_TEXT_MBTI).to(DEVICE).eval()
text_emo_model = XLMRobertaForSequenceClassification.from_pretrained(REPO_TEXT_EMO).to(DEVICE).eval()

# -- DSP TIMBRE & PLAYING-STYLE PROFILER (HPSS) --
def extract_riff_nuance(y, sr):
    rms = np.mean(librosa.feature.rms(y=y))
    tempo, _ = librosa.beat.beat_track(y=y, sr=sr)
    bpm = float(tempo[0]) if isinstance(tempo, np.ndarray) else float(tempo)
    # Brightness extraction (spectral centroid)
    centroid = np.mean(librosa.feature.spectral_centroid(y=y, sr=sr))
    # Rough distortion detection (zero-crossing rate)
    zcr = np.mean(librosa.feature.zero_crossing_rate(y))
    # HARMONIC-PERCUSSIVE SOURCE SEPARATION (playing-style detection)
    # Separates airy, sustained strumming (harmonic) from tapping/beat content (percussive)
    y_harmonic, y_percussive = librosa.effects.hpss(y)
    rms_h = np.mean(librosa.feature.rms(y=y_harmonic))
    rms_p = np.mean(librosa.feature.rms(y=y_percussive))
    # High ratio -> tapping/plucking (upbeat/technical)
    # Low ratio  -> strumming/wall of sound (sad/melancholic)
    perc_ratio = rms_p / (rms_h + 1e-6)
    return rms, bpm, centroid, zcr, perc_ratio
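
# Worked example (illustrative numbers only): rms_p = 0.04 over rms_h = 0.10
# gives perc_ratio = 0.40; together with a bright centroid, the calibration
# below reads that as tapping (> 0.35). rms_p = 0.02 over rms_h = 0.10 gives
# 0.20, read as strumming.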

# -- ANALYSIS ENGINE --
def analyze_track(audio_path, lyrics_input):
    if not audio_path:
        return [{"Error: No audio": 1.0}] * 5
    try:
        wav_16k, sr_16k = librosa.load(audio_path, sr=16000)
        # Extract physical nuance features
        p_rms, p_bpm, p_centroid, p_zcr, p_perc_ratio = extract_riff_nuance(wav_16k, sr_16k)
        # 15-second chunks; drop any trailing chunk shorter than 1 second
        chunk_len_16k = 16000 * 15
        chunks_16k = [wav_16k[i:i + chunk_len_16k] for i in range(0, len(wav_16k), chunk_len_16k) if len(wav_16k[i:i + chunk_len_16k]) >= 16000]
        if not chunks_16k:
            return [{"Error: Audio segment too short": 1.0}] * 5
        a_mbti, a_emo, a_vibe, a_int = [], [], [], []
        with torch.no_grad():
            for chunk in chunks_16k:
                iv = torch.tensor(chunk, dtype=torch.float32).unsqueeze(0).to(DEVICE)
                # CLAP expects 48 kHz audio
                chunk_48k = librosa.resample(chunk, orig_sr=16000, target_sr=48000)
                proc_out = clap_processor(audio=chunk_48k, return_tensors="pt", sampling_rate=48000)
                c_feat = proc_out["input_features"].to(DEVICE)
                is_longer = proc_out.get("is_longer", torch.tensor([False])).to(DEVICE)
                out = audio_model(iv, c_feat, is_longer)
                a_mbti.append(F.softmax(out["mbti"], dim=1))
                a_emo.append(F.softmax(out["emo"], dim=1))
                a_vibe.append(F.softmax(out["vibe"], dim=1))
                a_int.append(F.softmax(out["int"], dim=1))
        avg_a_mbti = torch.stack(a_mbti).mean(dim=0)
        avg_a_emo = torch.stack(a_emo).mean(dim=0)
        avg_a_vibe = torch.stack(a_vibe).mean(dim=0)
        max_a_int, _ = torch.max(torch.stack(a_int), dim=0)
        # LATE FUSION WITH LYRICS TEXT
        has_lyrics = lyrics_input and len(str(lyrics_input).strip()) > 15
        if has_lyrics:
            t_in = tokenizer(str(lyrics_input), truncation=True, padding=True, max_length=256, return_tensors="pt").to(DEVICE)
            with torch.no_grad():
                t_mbti_probs = F.softmax(text_mbti_model(**t_in).logits, dim=1)
                t_emo_probs = F.softmax(text_emo_model(**t_in).logits, dim=1)
            # Audio dominates personality (0.7/0.3); lyrics dominate emotion (0.3/0.7)
            final_mbti = (avg_a_mbti * 0.7) + (t_mbti_probs * 0.3)
            final_emo = (avg_a_emo * 0.3) + (t_emo_probs * 0.7)
        else:
            final_mbti = avg_a_mbti
            final_emo = avg_a_emo
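        # Worked example of the late fusion (illustrative numbers): if audio puts
        # 0.40 on INFP and the lyrics model puts 0.10 on it, the fused score is
        # 0.7 * 0.40 + 0.3 * 0.10 = 0.31.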
        # PURE RIFF NUANCE CALIBRATION (HPSS & centroid)
        def apply_riff_nuance(probs, labels, task_type):
            p = probs.cpu().squeeze().numpy()
            res = {labels[i]: float(p[i]) for i in range(len(labels))}
            is_distorted = p_zcr > 0.12
            # Guitar playing-style detection
            is_tappy_math = p_perc_ratio > 0.35 and p_centroid > 2200 and not is_distorted
            is_strummy_emo = p_perc_ratio <= 0.35 and not is_distorted
            if task_type == "emotion":
                # Suppress noisy outlier labels
                for outlier in ['disgust', 'annoyance', 'fear', 'confusion']:
                    if outlier in res: res[outlier] *= 0.10
                if is_tappy_math:
                    # Upbeat tapping riffs (e.g. Gracias / Playing God)
                    for neg in ['grief', 'sadness', 'remorse', 'anger']:
                        if neg in res: res[neg] *= 0.15
                    for pos in ['joy', 'amusement', 'excitement', 'admiration']:
                        if pos in res: res[pos] *= 4.00
                elif is_strummy_emo:
                    # Sad strumming riffs (e.g. i wanna quit / Never Meant)
                    if 'anger' in res: res['anger'] *= 0.30
                    for pos in ['joy', 'amusement', 'excitement']:
                        if pos in res: res[pos] *= 0.15
                    for neg in ['grief', 'sadness']:
                        if neg in res: res[neg] *= 3.00
                elif is_distorted:
                    # Harsh distorted riffs (e.g. GOAT / heavy Shibuya)
                    if 'joy' in res: res['joy'] *= 0.10
                    if 'anger' in res: res['anger'] *= 3.00
            if task_type == "vibe":
                if is_tappy_math:
                    if 'Melancholic' in res: res['Melancholic'] *= 0.20
                    if 'Technical' in res: res['Technical'] *= 3.50
                    if 'Atmospheric' in res: res['Atmospheric'] *= 1.50
                elif is_strummy_emo:
                    if 'Technical' in res: res['Technical'] *= 0.70
                    if 'Melancholic' in res: res['Melancholic'] *= 3.50
                elif is_distorted:
                    if 'Melancholic' in res: res['Melancholic'] *= 0.30
                    if 'Aggressive' in res: res['Aggressive'] *= 3.50
            if task_type == "intensity":
                if p_rms < 0.08:
                    res['Low'] *= 3.0; res['High'] *= 0.1
                elif p_rms > 0.22:
                    res['High'] *= 1.5; res['Low'] *= 0.1
                else:
                    res['Medium'] *= 1.5
            # Renormalize and keep the top 3 labels
            total = sum(res.values())
            if total > 0:
                res = {k: v / total for k, v in res.items()}
            return dict(sorted(res.items(), key=lambda x: x[1], reverse=True)[:3])
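        # Calibration arithmetic example (illustrative numbers): on a tapping clip,
        # raw joy = 0.10 and sadness = 0.20 become joy = 0.10 * 4.00 = 0.40 and
        # sadness = 0.20 * 0.15 = 0.03 before renormalization, flipping the ranking.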
        # PURE TEMPO FORMATTING (rule-based soft labels from BPM)
        tempo_dict = {"Fast": 0.0, "Medium": 0.0, "Slow": 0.0}
        if p_bpm >= 135:
            tempo_dict["Fast"] = 0.90; tempo_dict["Medium"] = 0.10
        elif 100 <= p_bpm < 135:
            tempo_dict["Fast"] = max(0.0, (p_bpm - 100) / 35.0)
            tempo_dict["Medium"] = 0.80
            tempo_dict["Slow"] = max(0.0, (135 - p_bpm) / 35.0)
        else:
            tempo_dict["Slow"] = 0.90; tempo_dict["Medium"] = 0.10
        t_total = sum(tempo_dict.values())
        if t_total > 0: tempo_dict = {k: float(v / t_total) for k, v in tempo_dict.items()}
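        # Worked example: at 120 BPM, Fast = 20/35 ≈ 0.571, Medium = 0.80,
        # Slow = 15/35 ≈ 0.429; after normalization (total 1.8) that is roughly
        # Fast 0.32, Medium 0.44, Slow 0.24.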
        return [
            apply_riff_nuance(final_mbti, MBTI_LABELS, "mbti"),
            apply_riff_nuance(final_emo, EMO_LABELS, "emotion"),
            apply_riff_nuance(avg_a_vibe, VIBE_LABELS, "vibe"),
            apply_riff_nuance(max_a_int, INTENSITY_LABELS, "intensity"),
            tempo_dict
        ]
    except Exception as e:
        return [{f"Error: {e}": 1.0}] * 5

# -- INTERFACE --
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    gr.Markdown("# Neural Math Rock Multimodal Analysis")
    gr.Markdown("Identify personality and emotional states from music audio and lyrics.")
    with gr.Row():
        with gr.Column():
            audio_box = gr.Audio(type="filepath", label="Audio Source")
            lyrics_box = gr.Textbox(lines=8, label="Lyrics Source", placeholder="Paste lyrics here for hybrid analysis...")
            run_btn = gr.Button("RUN ANALYSIS", variant="primary")
        with gr.Column():
            res_mbti = gr.Label(label="Personality (MBTI)")
            res_emo = gr.Label(label="Emotional State")
            res_vibe = gr.Label(label="Acoustic Vibe")
            res_int = gr.Label(label="Intensity Level")
            res_tmp = gr.Label(label="Tempo Classification")
    run_btn.click(
        fn=analyze_track,
        inputs=[audio_box, lyrics_box],
        outputs=[res_mbti, res_emo, res_vibe, res_int, res_tmp]
    )

if __name__ == "__main__":
    demo.launch()