lllindsey0615 commited on
Commit
29ead22
·
1 Parent(s): f489d22

solving tensor dimension mismatch issue

Browse files
Files changed (1) hide show
  1. app.py +14 -2
app.py CHANGED
@@ -47,8 +47,20 @@ def separate_all_stems(audio_file_path: str, model_name: str):
47
 
48
  sr = signal.sample_rate
49
 
50
- waveform = signal.audio_data.float() # [channels, samples]
51
- waveform = waveform.unsqueeze(0) # [1, channels, samples]
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
  with torch.no_grad():
54
  stems_batch = apply_model(
 
47
 
48
  sr = signal.sample_rate
49
 
50
+ # Ensure audio_data is a torch.Tensor
51
+ audio = signal.audio_data
52
+ if isinstance(audio, np.ndarray):
53
+ audio = torch.from_numpy(audio)
54
+
55
+ audio = audio.float() # [channels, samples] or [channels, samples, ?]
56
+
57
+ # Remove extra trailing dimensions
58
+ if audio.ndim > 2:
59
+ audio = audio.squeeze()
60
+
61
+ # Final shape: [1, channels, samples]
62
+ waveform = audio.unsqueeze(0)
63
+
64
 
65
  with torch.no_grad():
66
  stems_batch = apply_model(