Shokoufehhh commited on
Commit
a297d3b
·
verified ·
1 Parent(s): b034910

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -3
app.py CHANGED
@@ -2,6 +2,7 @@ import torch
2
  import torchaudio
3
  from sgmse.model import ScoreModel
4
  import gradio as gr
 
5
 
6
  # Load the pre-trained model
7
  model = ScoreModel.load_from_checkpoint("pretrained_checkpoints/speech_enhancement/train_vb_29nqe0uh_epoch=115.ckpt")
@@ -11,12 +12,34 @@ def enhance_speech(audio_file):
11
  noisy, sr = torchaudio.load(audio_file)
12
  noisy = noisy.unsqueeze(0) # Add fake batch dimension if needed
13
 
14
- # Run the speech enhancement model
15
- enhanced = model.predict(noisy)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  # Save the enhanced audio
18
  output_file = 'enhanced_output.wav'
19
- torchaudio.save(output_file, enhanced.cpu().squeeze(0), sr)
20
 
21
  return output_file
22
 
 
2
  import torchaudio
3
  from sgmse.model import ScoreModel
4
  import gradio as gr
5
+ from sgmse.util.other import pad_spec
6
 
7
  # Load the pre-trained model
8
  model = ScoreModel.load_from_checkpoint("pretrained_checkpoints/speech_enhancement/train_vb_29nqe0uh_epoch=115.ckpt")
 
12
  noisy, sr = torchaudio.load(audio_file)
13
  noisy = noisy.unsqueeze(0) # Add fake batch dimension if needed
14
 
15
+ if sr != target_sr:
16
+ y = torch.tensor(resample(y.numpy(), orig_sr=sr, target_sr=target_sr))
17
+
18
+ T_orig = y.size(1)
19
+
20
+ # Normalize
21
+ norm_factor = y.abs().max()
22
+ y = y / norm_factor
23
+
24
+ # Prepare DNN input
25
+ Y = torch.unsqueeze(model._forward_transform(model._stft(y.to(args.device))), 0)
26
+ Y = pad_spec(Y, mode=pad_mode)
27
+
28
+ # Reverse sampling
29
+ sampler = model.get_pc_sampler(
30
+ 'reverse_diffusion', args.corrector, Y.to(args.device), N=args.N,
31
+ corrector_steps=args.corrector_steps, snr=args.snr)
32
+ sample, _ = sampler()
33
+
34
+ # Backward transform in time domain
35
+ x_hat = model.to_audio(sample.squeeze(), T_orig)
36
+
37
+ # Renormalize
38
+ x_hat = x_hat * norm_factor
39
 
40
  # Save the enhanced audio
41
  output_file = 'enhanced_output.wav'
42
+ torchaudio.save(output_file, x_hat.cpu().numpy(), sr)
43
 
44
  return output_file
45