AIvry committed on
Commit
226ddaf
·
verified ·
1 Parent(s): 1cbe9b6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +278 -0
app.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import zipfile
3
+ import tempfile
4
+ import shutil
5
+ from pathlib import Path
6
+ import pandas as pd
7
+ import json
8
+ import os
9
+ import traceback
10
+ import gc
11
+
12
+ # Import your modules
13
+ from engine import compute_mapss_measures
14
+ from models import get_model_config, cleanup_all_models
15
+ from config import DEFAULT_ALPHA
16
+ from utils import clear_gpu_memory
17
+
18
+ def process_audio_files(zip_file, model_name, layer, alpha):
19
+ """
20
+ Process uploaded ZIP file containing audio mixtures.
21
+
22
+ Expected ZIP structure:
23
+ - references/: Contains N reference audio files
24
+ - outputs/: Contains N output audio files
25
+ """
26
+
27
+ if zip_file is None:
28
+ return None, "Please upload a ZIP file"
29
+
30
+ # Create temporary directory for processing
31
+ with tempfile.TemporaryDirectory() as temp_dir:
32
+ temp_path = Path(temp_dir)
33
+
34
+ try:
35
+ # Extract ZIP file
36
+ extract_path = temp_path / "extracted"
37
+ extract_path.mkdir(exist_ok=True)
38
+
39
+ with zipfile.ZipFile(zip_file.name, 'r') as zip_ref:
40
+ zip_ref.extractall(extract_path)
41
+
42
+ # Find references and outputs directories
43
+ refs_dir = None
44
+ outs_dir = None
45
+
46
+ # Check for standard structure
47
+ for item in extract_path.iterdir():
48
+ if item.is_dir():
49
+ if item.name.lower() in ['references', 'refs', 'reference']:
50
+ refs_dir = item
51
+ elif item.name.lower() in ['outputs', 'outs', 'output', 'separated']:
52
+ outs_dir = item
53
+
54
+ # If not found at root, check one level deeper
55
+ if refs_dir is None or outs_dir is None:
56
+ for item in extract_path.iterdir():
57
+ if item.is_dir():
58
+ for subitem in item.iterdir():
59
+ if subitem.is_dir():
60
+ if subitem.name.lower() in ['references', 'refs', 'reference']:
61
+ refs_dir = subitem
62
+ elif subitem.name.lower() in ['outputs', 'outs', 'output', 'separated']:
63
+ outs_dir = subitem
64
+
65
+ if refs_dir is None or outs_dir is None:
66
+ return None, "Could not find 'references' and 'outputs' directories in the ZIP file"
67
+
68
+ # Get audio files
69
+ ref_files = sorted([f for f in refs_dir.glob("*.wav")])
70
+ out_files = sorted([f for f in outs_dir.glob("*.wav")])
71
+
72
+ if len(ref_files) == 0:
73
+ return None, "No reference WAV files found"
74
+ if len(out_files) == 0:
75
+ return None, "No output WAV files found"
76
+
77
+ # Create manifest
78
+ manifest = [{
79
+ "mixture_id": "uploaded_mixture",
80
+ "references": [str(f) for f in ref_files],
81
+ "systems": {
82
+ "uploaded_system": [str(f) for f in out_files]
83
+ }
84
+ }]
85
+
86
+ # Validate model and layer
87
+ allowed_models = set(get_model_config(0).keys())
88
+ if model_name not in allowed_models:
89
+ return None, f"Invalid model. Allowed: {', '.join(sorted(allowed_models))}"
90
+
91
+ # Set default layer if needed
92
+ if model_name == "raw":
93
+ layer_final = 0
94
+ else:
95
+ model_defaults = {
96
+ "wavlm": 24, "wav2vec2": 24, "hubert": 24,
97
+ "wavlm_base": 12, "wav2vec2_base": 12, "hubert_base": 12,
98
+ "wav2vec2_xlsr": 24, "ast": 12
99
+ }
100
+ layer_final = layer if layer is not None else model_defaults.get(model_name, 12)
101
+
102
+ # Run experiment with compute_mapss_measures
103
+ results_dir = compute_mapss_measures(
104
+ models=[model_name],
105
+ mixtures=manifest,
106
+ layer=layer_final,
107
+ alpha=alpha,
108
+ verbose=True,
109
+ max_gpus=1, # Limit to 1 GPU for HF Space
110
+ add_ci=False # Disable CI for faster processing
111
+ )
112
+
113
+ # Create output ZIP with results
114
+ output_zip = temp_path / "results.zip"
115
+
116
+ with zipfile.ZipFile(output_zip, 'w') as zipf:
117
+ # Add all CSV files from results
118
+ results_path = Path(results_dir)
119
+ for csv_file in results_path.rglob("*.csv"):
120
+ arcname = str(csv_file.relative_to(results_path.parent))
121
+ zipf.write(csv_file, arcname)
122
+
123
+ # Add params.json
124
+ params_file = results_path / "params.json"
125
+ if params_file.exists():
126
+ zipf.write(params_file, str(params_file.relative_to(results_path.parent)))
127
+
128
+ # Add manifest
129
+ manifest_file = results_path / "manifest_canonical.json"
130
+ if manifest_file.exists():
131
+ zipf.write(manifest_file, str(manifest_file.relative_to(results_path.parent)))
132
+
133
+ # Read the ZIP file to return
134
+ with open(output_zip, 'rb') as f:
135
+ output_data = f.read()
136
+
137
+ # Create a proper file object for Gradio
138
+ output_file_path = temp_path / "download_results.zip"
139
+ with open(output_file_path, 'wb') as f:
140
+ f.write(output_data)
141
+
142
+ return str(output_file_path), "Processing completed successfully!"
143
+
144
+ except Exception as e:
145
+ error_msg = f"Error processing files: {str(e)}\n{traceback.format_exc()}"
146
+ return None, error_msg
147
+ finally:
148
+ # Ensure cleanup happens
149
+ cleanup_all_models()
150
+ clear_gpu_memory()
151
+ gc.collect()
152
+
153
# Create Gradio interface
def create_interface():
    """Build and return the Gradio Blocks UI for the MAPSS evaluator."""
    with gr.Blocks(title="MAPSS - Multi-source Audio Perceptual Separation Scores") as demo:
        # Usage instructions shown above the controls.
        gr.Markdown("""
        # MAPSS: Multi-source Audio Perceptual Separation Scores

        This tool evaluates audio source separation quality using Perceptual Similarity (PS) and Perceptual Matching (PM) metrics.

        ## How to use:
        1. **Prepare your audio files**: Create a ZIP file with the following structure:
        ```
        your_mixture.zip
        ├── references/    # Original clean sources
        │   ├── speaker1.wav
        │   ├── speaker2.wav
        │   └── ...
        └── outputs/       # Separated outputs from your algorithm
            ├── separated1.wav
            ├── separated2.wav
            └── ...
        ```
        2. **Upload the ZIP file** using the file uploader below
        3. **Select model and parameters**
        4. **Click "Process"** to run the evaluation
        5. **Download the results** as a ZIP file containing CSV files with PS/PM scores

        ## Models available:
        - **raw**: Raw waveform features (no model)
        - **wavlm**: WavLM Large model (best overall performance)
        - **wav2vec2**: Wav2Vec2 Large model
        - **hubert**: HuBERT Large model
        - **wavlm_base**: WavLM Base model (faster, good performance)
        - **wav2vec2_base**: Wav2Vec2 Base model
        - **hubert_base**: HuBERT Base model
        - **wav2vec2_xlsr**: Wav2Vec2 XLSR-53 model (multilingual)
        - **ast**: Audio Spectrogram Transformer
        """)

        with gr.Row():
            # Left column: inputs and parameters.
            with gr.Column():
                # type="filepath" means the handler receives a path string.
                file_input = gr.File(
                    label="Upload ZIP file with audio mixtures",
                    file_types=[".zip"],
                    type="filepath"
                )

                model_dropdown = gr.Dropdown(
                    choices=["raw", "wavlm", "wav2vec2", "hubert",
                             "wavlm_base", "wav2vec2_base", "hubert_base",
                             "wav2vec2_xlsr", "ast"],
                    value="wav2vec2_base",
                    label="Select embedding model"
                )

                # Max 24 covers the Large models; Base models use <= 12.
                layer_slider = gr.Slider(
                    minimum=0,
                    maximum=24,
                    step=1,
                    value=12,
                    label="Layer (leave at default for automatic selection)"
                )

                alpha_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    step=0.1,
                    value=DEFAULT_ALPHA,
                    label="Diffusion maps alpha parameter"
                )

                process_btn = gr.Button("Process Audio Files", variant="primary")

            # Right column: results download and status log.
            with gr.Column():
                output_file = gr.File(
                    label="Download Results (ZIP)",
                    type="filepath"
                )
                status_text = gr.Textbox(
                    label="Status",
                    lines=3,
                    max_lines=10
                )

        # Reference documentation shown below the controls.
        gr.Markdown("""
        ## Output format:
        The results ZIP will contain:
        - `ps_scores_{model}.csv`: Perceptual Similarity scores for each speaker/source
        - `pm_scores_{model}.csv`: Perceptual Matching scores for each speaker/source
        - `params.json`: Experiment parameters
        - `manifest_canonical.json`: Processed file manifest

        ## Score interpretation:
        - **PS (Perceptual Similarity)**: 0-1 score, higher is better. Measures how well the separated output matches the reference compared to other sources.
        - **PM (Perceptual Matching)**: 0-1 score, higher is better. Measures robustness to audio distortions.

        ## Notes:
        - Processing may take several minutes depending on the audio length and model
        - Audio files are automatically resampled to 16kHz
        - The tool automatically matches outputs to references based on correlation
        - For best results, ensure equal number of reference and output files

        ## Citation:
        If you use this tool in your research, please cite our paper (details coming soon).
        """)

        # Wire the button to the processing function.
        process_btn.click(
            fn=process_audio_files,
            inputs=[file_input, model_dropdown, layer_slider, alpha_slider],
            outputs=[output_file, status_text]
        )

        # NOTE(review): examples list is empty; some Gradio versions reject
        # an empty gr.Examples - confirm against the pinned Gradio version.
        gr.Examples(
            examples=[
                # You can add example ZIP files here if you have them
            ],
            inputs=[file_input]
        )

    return demo
274
+
275
# Script entry point: build the Gradio UI and start serving it.
if __name__ == "__main__":
    app = create_interface()
    app.launch()