Commit
·
29ead22
1
Parent(s):
f489d22
Solve tensor dimension mismatch issue
Browse files
app.py
CHANGED
|
@@ -47,8 +47,20 @@ def separate_all_stems(audio_file_path: str, model_name: str):
|
|
| 47 |
|
| 48 |
sr = signal.sample_rate
|
| 49 |
|
| 50 |
-
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
with torch.no_grad():
|
| 54 |
stems_batch = apply_model(
|
|
|
|
| 47 |
|
| 48 |
sr = signal.sample_rate
|
| 49 |
|
| 50 |
+
# Ensure audio_data is a torch.Tensor
|
| 51 |
+
audio = signal.audio_data
|
| 52 |
+
if isinstance(audio, np.ndarray):
|
| 53 |
+
audio = torch.from_numpy(audio)
|
| 54 |
+
|
| 55 |
+
audio = audio.float() # [channels, samples] or [channels, samples, ?]
|
| 56 |
+
|
| 57 |
+
# Remove extra trailing dimensions
|
| 58 |
+
if audio.ndim > 2:
|
| 59 |
+
audio = audio.squeeze()
|
| 60 |
+
|
| 61 |
+
# Final shape: [1, channels, samples]
|
| 62 |
+
waveform = audio.unsqueeze(0)
|
| 63 |
+
|
| 64 |
|
| 65 |
with torch.no_grad():
|
| 66 |
stems_batch = apply_model(
|