Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import os | |
| import torch | |
| import numpy as np | |
| import pandas as pd | |
| import plotly.express as px | |
| import scipy.io.wavfile | |
| import io | |
# Folder, relative to the working directory, that is scanned for model
# sub-directories.
BASE_PATH = "Models"

# Use the full browser width for the page and set the browser tab title.
st.set_page_config(
    page_title="Audio Source Separation Inspector",
    layout="wide",
)
def process_audio(file_path, gain_factor):
    """Apply a gain to a WAV file and return it as an in-memory 16-bit WAV.

    Args:
        file_path: Path of the WAV file to read.
        gain_factor: Multiplier applied to the normalised samples before
            they are clipped to [-1, 1].

    Returns:
        An ``io.BytesIO`` rewound to offset 0 that holds the boosted audio
        encoded as 16-bit PCM; ``None`` when the file is a Git LFS pointer;
        or the unchanged ``file_path`` when reading/processing fails
        (presumably so the caller can still try to play the raw file).
    """
    try:
        # A Git LFS pointer is a small text file that starts with a
        # "version https://git-lfs..." line; it cannot be parsed as WAV.
        with open(file_path, 'rb') as f:
            header = f.read(50)
        if header.startswith(b'version https://git-lfs'):
            st.error(f"❌ **LFS Error:** `{os.path.basename(file_path)}` is a Git LFS pointer, not a WAV file. Run `git lfs pull` in your terminal.")
            return None

        sample_rate, data = scipy.io.wavfile.read(file_path)

        # Bring integer PCM into the [-1, 1) range. Float WAVs are left
        # untouched. 8-bit WAVs are unsigned with the zero level at 128, so
        # they need re-centring as well as scaling.
        if data.dtype == np.uint8:
            data = (data.astype(np.float32) - 128.0) / 128.0
        elif data.dtype == np.int16:
            data = data.astype(np.float32) / 32768.0
        elif data.dtype == np.int32:
            data = data.astype(np.float32) / 2147483648.0

        # Boost, then hard-clip so the int16 conversion cannot wrap around.
        data = np.clip(data * gain_factor, -1.0, 1.0)
        data = (data * 32767).astype(np.int16)

        virtual_file = io.BytesIO()
        scipy.io.wavfile.write(virtual_file, sample_rate, data)
        # Rewind so any consumer that calls .read() starts at the WAV header.
        virtual_file.seek(0)
        return virtual_file
    except Exception as e:
        st.error(f"Error processing audio: {e}")
        return file_path
def get_subdirs(path):
    """Return the names of the immediate sub-directories of ``path``.

    Args:
        path: Directory to scan.

    Returns:
        A sorted list of directory names (not full paths). The list is empty
        when ``path`` does not exist or is not a directory.
    """
    # isdir (rather than exists) also rejects a regular file, which would
    # otherwise make os.listdir raise NotADirectoryError.
    if not os.path.isdir(path):
        return []
    # os.listdir returns entries in arbitrary order; sort for a stable result.
    return sorted(
        entry
        for entry in os.listdir(path)
        if os.path.isdir(os.path.join(path, entry))
    )
def load_spectrogram_interactive(pt_path, title="Spectrogram"):
    """Load a saved spectrogram tensor and build a Plotly heatmap from it.

    A 4-D tensor is reduced to its first element along dim 0; a 3-D tensor
    is averaged over dim 0; anything else is plotted as loaded. When every
    value is non-negative the data is compressed with ``log1p`` first.

    Args:
        pt_path: Path of the ``.pt`` file to load with ``torch.load``.
        title: Title shown above the heatmap.

    Returns:
        The Plotly figure, or ``None`` when loading or plotting fails (the
        error is reported through ``st.error``).
    """
    try:
        # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
        # objects, which can execute code. Only open files from a trusted
        # source. It is kept here because the files reportedly fail to load
        # under the PyTorch 2.6+ default (weights_only=True).
        spec_tensor = torch.load(pt_path, map_location='cpu', weights_only=False)

        # .numpy() refuses tensors that still require grad, and NumPy has no
        # bfloat16; detach and upcast half-precision so both cases plot.
        spec_tensor = spec_tensor.detach()
        if spec_tensor.dtype in (torch.float16, torch.bfloat16):
            spec_tensor = spec_tensor.float()

        if spec_tensor.dim() == 4:
            spec_tensor = spec_tensor[0]
        if spec_tensor.dim() == 3:
            spec_data = spec_tensor.mean(dim=0).numpy()
        else:
            spec_data = spec_tensor.numpy()

        # log1p is only defined for values > -1; apply it only to data that
        # is entirely non-negative.
        if spec_data.min() >= 0:
            spec_data = np.log1p(spec_data)

        fig = px.imshow(
            spec_data,
            origin='lower',
            aspect='auto',
            color_continuous_scale='Viridis',
            labels=dict(x="Time Frame", y="Frequency Bin", color="Log Magnitude"),
            title=title,
        )
        fig.update_layout(margin=dict(l=0, r=0, t=30, b=0), height=300)
        return fig
    except Exception as e:
        st.error(f"Error loading spectrogram: {e}")
        return None
def load_feature_map_interactive(pt_path):
    """Load a saved feature-map tensor and plot its mean over dim 0.

    A 4-D tensor is first reduced to its first element along dim 0. The
    remaining tensor is averaged over dim 0 and shown as a heatmap whose
    title includes the tensor shape before averaging.

    Args:
        pt_path: Path of the ``.pt`` file to load with ``torch.load``.

    Returns:
        The Plotly figure, or ``None`` when loading or plotting fails (the
        error is reported through ``st.error``).
    """
    try:
        # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
        # objects, which can execute code. Only open files from a trusted
        # source. It is kept here because the files reportedly fail to load
        # under the PyTorch 2.6+ default (weights_only=True).
        feat_tensor = torch.load(pt_path, map_location='cpu', weights_only=False)

        # .numpy() refuses tensors that still require grad, and NumPy has no
        # bfloat16; detach and upcast half-precision so both cases plot.
        feat_tensor = feat_tensor.detach()
        if feat_tensor.dtype in (torch.float16, torch.bfloat16):
            feat_tensor = feat_tensor.float()

        if feat_tensor.dim() == 4:
            feat_tensor = feat_tensor[0]
        mean_activation = feat_tensor.mean(dim=0).numpy()

        fig = px.imshow(
            mean_activation,
            origin='lower',
            aspect='auto',
            color_continuous_scale='Viridis',
            labels=dict(x="Time", y="Freq/Feature", color="Activation"),
            title=f"Mean Activation (Shape: {list(feat_tensor.shape)})",
        )
        fig.update_layout(margin=dict(l=0, r=0, t=40, b=0))
        return fig
    except Exception as e:
        # Report the failure instead of silently returning None, matching
        # load_spectrogram_interactive; otherwise the tab just stays blank.
        st.error(f"Error loading feature map: {e}")
        return None
# ---- Page header ---------------------------------------------------------
st.title("🎵 Audio Source Separation Inspector")
st.markdown("""
### Model Interpretation Guide
This tool helps you evaluate how well the model separates audio sources.
* **Audio Quality:** Listen for "artifacts" (robotic sounds or clicking) in the Prediction compared to the Target.
* **Spectrogram Clarity:** In the visuals below, distinct horizontal lines represent clear tones. Vertical smear usually indicates percussion or noise.
* **Error Analysis:** If the Prediction looks "blurry" compared to the Target, the model is losing high-frequency details.
""")

# ---- Model discovery -----------------------------------------------------
# isdir (rather than exists) also catches BASE_PATH being a regular file,
# which cannot be scanned for model folders.
if not os.path.isdir(BASE_PATH):
    st.error(f"Models directory not found at {BASE_PATH}. Please ensure your data was uploaded correctly.")
    st.stop()

models = get_subdirs(BASE_PATH)
if not models:
    # With no options the select box below yields None and nothing else is
    # rendered; say why instead of leaving the page blank.
    st.warning(f"No model folders found in {BASE_PATH}.")
selected_model = st.sidebar.selectbox("Select Model", models)

# ---- Sidebar audio controls ----------------------------------------------
st.sidebar.markdown("### Audio Settings")
volume_boost = st.sidebar.slider(
    "Volume Boost (Gain)",
    min_value=1.0,
    max_value=20.0,
    value=1.0,
    step=0.5,
    help="Digitally increases the amplitude of the audio signal.",
)
def _sample_sort_key(name):
    """Return the integer after the last underscore of ``name``, else 0."""
    if '_' not in name:
        return 0
    try:
        return int(name.split('_')[-1])
    except ValueError:
        # A folder such as "sample_final" must not crash the whole page.
        return 0


def _render_audio_and_spectrogram(audio_path, spec_path, spec_title, gain):
    """Render one row: audio player in a narrow column, spectrogram beside it.

    Either half is skipped when its file does not exist or cannot be loaded.
    """
    audio_col, spec_col = st.columns([1, 3])
    with audio_col:
        if os.path.exists(audio_path):
            st.markdown("**Audio:**")
            processed = process_audio(audio_path, gain)
            if processed is not None:
                st.audio(processed, format='audio/wav')
    with spec_col:
        if os.path.exists(spec_path):
            fig = load_spectrogram_interactive(spec_path, title=spec_title)
            if fig is not None:
                st.plotly_chart(fig, width='stretch')


def _render_metrics_csv(csv_path, file_label, heading, chart_title, x_column):
    """Render a CSV as a line chart of its numeric columns plus a raw table.

    ``x_column`` is used as the x axis when present (and is then left out of
    the plotted series); otherwise the row index is used.
    """
    if not os.path.exists(csv_path):
        st.info(f"No `{file_label}` found.")
        return

    st.subheader(heading)
    try:
        frame = pd.read_csv(csv_path)
        numeric_cols = frame.select_dtypes(include=np.number).columns
        if x_column in frame.columns:
            x_axis = x_column
            # Do not plot the x axis against itself.
            y_cols = [col for col in numeric_cols if col != x_column]
        else:
            x_axis = frame.index
            y_cols = list(numeric_cols)
        if y_cols:
            fig = px.line(frame, x=x_axis, y=y_cols, title=chart_title)
            st.plotly_chart(fig, width='stretch')
        st.dataframe(frame, width='stretch')
    except Exception as e:
        st.write(f"Could not parse `{file_label}`.", e)


if selected_model:
    model_path = os.path.join(BASE_PATH, selected_model)
    artifacts_path = os.path.join(model_path, "test_artifacts")

    if not os.path.exists(artifacts_path):
        st.warning(f"No 'test_artifacts' folder found in {selected_model}")
    else:
        samples = get_subdirs(artifacts_path)
        samples.sort(key=_sample_sort_key)
        selected_sample = st.sidebar.selectbox("Select Sample ID", samples)

        if selected_sample:
            sample_path = os.path.join(artifacts_path, selected_sample)
            audio_dir = os.path.join(sample_path, "audio")
            specs_dir = os.path.join(sample_path, "specs")
            feats_dir = os.path.join(sample_path, "feats")

            # A sample without an "audio" folder simply has no classes;
            # listing a missing directory would raise and kill the page.
            all_files = os.listdir(audio_dir) if os.path.isdir(audio_dir) else []
            target_files = [
                f for f in all_files
                if f.startswith("target_") and f.endswith(".wav")
            ]
            classes = [
                f.replace("target_", "").replace(".wav", "")
                for f in target_files
            ]
            # None when `classes` is empty.
            selected_class = st.sidebar.selectbox("Focus Class", classes)

            tab1, tab2, tab3 = st.tabs(["🎧 Audio & Spectrograms", "🧠 Internal Activations", "📊 Model Metadata"])

            with tab1:
                if selected_class:
                    st.header(f"Sample {selected_sample} | Focus: {selected_class.capitalize()}")
                else:
                    st.header(f"Sample {selected_sample}")

                st.subheader("1. Mixture (Input)")
                st.markdown("The raw input containing all sound sources mixed together.")
                _render_audio_and_spectrogram(
                    os.path.join(audio_dir, "mixture.wav"),
                    os.path.join(specs_dir, "mixture.pt"),
                    "Mixture Mel-Spectrogram",
                    volume_boost,
                )

                if not selected_class:
                    st.info("No target classes found for this sample.")
                else:
                    st.divider()
                    st.subheader(f"2. Target: {selected_class}")
                    st.markdown("**Interpretation:** This is the 'Ground Truth'. Look at the spectrogram structure here—this is the ideal output.")
                    _render_audio_and_spectrogram(
                        os.path.join(audio_dir, f"target_{selected_class}.wav"),
                        os.path.join(specs_dir, f"target_{selected_class}.pt"),
                        f"Target Mel-Spectrogram ({selected_class})",
                        volume_boost,
                    )

                    st.divider()
                    st.subheader(f"3. Prediction: {selected_class}")
                    st.markdown("**Interpretation:** Compare this to the Target above. If you see 'fuzziness' in the dark areas, the model is not silencing background noise correctly.")
                    _render_audio_and_spectrogram(
                        os.path.join(audio_dir, f"pred_{selected_class}.wav"),
                        os.path.join(specs_dir, f"pred_{selected_class}.pt"),
                        f"Predicted Mel-Spectrogram ({selected_class})",
                        volume_boost,
                    )

            with tab2:
                st.header("Internal Feature Maps")
                st.markdown("These heatmaps visualize the neural network's internal state. Bright spots indicate features the model considers important for separation.")
                if os.path.exists(feats_dir):
                    feat_files = sorted(os.listdir(feats_dir))
                    if feat_files:
                        selected_layer = st.selectbox("Select Probed Layer", feat_files)
                        if selected_layer:
                            st.write(f"Layer: **{selected_layer.replace('.pt', '')}**")
                            fig = load_feature_map_interactive(os.path.join(feats_dir, selected_layer))
                            if fig is not None:
                                st.plotly_chart(fig, width='stretch')
                    else:
                        st.warning("No feature maps found for this sample.")
                else:
                    st.error("Features directory not found.")

            with tab3:
                st.header("Training and Testing Logs")
                st.markdown("Use these graphs to check for **Overfitting**. If Training Loss decreases but Test Metrics stagnate or drop, the model is memorizing data rather than learning general features.")
                metrics_col, loss_col = st.columns(2)
                with metrics_col:
                    _render_metrics_csv(
                        os.path.join(model_path, "test_results.csv"),
                        "test_results.csv",
                        "Test Metrics",
                        "Test Metrics",
                        "Batch_Index",
                    )
                with loss_col:
                    _render_metrics_csv(
                        os.path.join(model_path, "loss.csv"),
                        "loss.csv",
                        "Training Loss",
                        "Loss Curves",
                        "epoch",
                    )