wrice committed on
Commit
fc81f0f
1 Parent(s): e16ac0a

add app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -9
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import gradio as gr
2
  import numpy as np
3
  import torch
@@ -5,10 +6,22 @@ import torchaudio
5
  from denoisers import WaveUNetModel
6
  from tqdm import tqdm
7
 
8
- MODEL = WaveUNetModel.from_pretrained("wrice/waveunet-vctk-24khz")
9
 
10
 
11
- def denoise(inputs):
 
 
 
 
 
 
 
 
 
 
 
 
12
  sr, audio = inputs
13
  audio = torch.from_numpy(audio)[None]
14
  audio = audio / 32768.0
@@ -16,10 +29,10 @@ def denoise(inputs):
16
  print(f"Audio shape: {audio.shape}")
17
  print(f"Sample rate: {sr}")
18
 
19
- if sr != MODEL.config.sample_rate:
20
- audio = torchaudio.functional.resample(audio, sr, MODEL.config.sample_rate)
21
 
22
- chunk_size = MODEL.config.max_length
23
 
24
  padding = abs(audio.size(-1) % chunk_size - chunk_size)
25
  padded = torch.nn.functional.pad(audio, (0, padding))
@@ -28,7 +41,7 @@ def denoise(inputs):
28
  for i in tqdm(range(0, padded.shape[-1], chunk_size)):
29
  audio_chunk = padded[:, i : i + chunk_size]
30
  with torch.no_grad():
31
- clean_chunk = MODEL(audio_chunk[None]).logits
32
  clean.append(clean_chunk.squeeze(0))
33
 
34
  denoised = torch.concat(clean).flatten()[: audio.shape[-1]].clamp(-1.0, 1.0)
@@ -36,8 +49,8 @@ def denoise(inputs):
36
 
37
  print(f"Denoised shape: {denoised.shape}")
38
 
39
- return MODEL.config.sample_rate, denoised
40
 
41
 
42
- iface = gr.Interface(fn=denoise, inputs="audio", outputs="audio")
43
- iface.launch()
 
1
+ """Gradio demo for denoisers."""
2
  import gradio as gr
3
  import numpy as np
4
  import torch
 
6
  from denoisers import WaveUNetModel
7
  from tqdm import tqdm
8
 
9
# Hub ids of the pretrained WaveUNet denoiser checkpoints selectable in the
# UI dropdown; the first entry is used as the default selection in main().
MODELS = ["wrice/waveunet-vctk-48khz", "wrice/waveunet-vctk-24khz"]
10
 
11
 
12
def main():
    """Build and launch the Gradio demo.

    The interface takes a checkpoint dropdown plus an audio recording and
    returns the denoised audio produced by ``denoise``.
    """
    iface = gr.Interface(
        fn=denoise,
        # NOTE(review): Gradio >= 3 renamed the Dropdown kwarg ``default``
        # to ``value`` (``default=`` raises TypeError on 3.x). Using
        # ``value`` here; confirm the pinned gradio version agrees.
        inputs=[gr.Dropdown(choices=MODELS, value=MODELS[0]), "audio"],
        outputs="audio",
    )
    # launch() blocks and serves the web UI.
    iface.launch()
20
+
21
+
22
# Cache of loaded checkpoints so repeated requests do not re-download and
# re-instantiate the model on every call (from_pretrained is expensive).
_MODEL_CACHE = {}


def denoise(model_name, inputs):
    """Denoise a recording with a pretrained WaveUNet model.

    Args:
        model_name: Hub id of the checkpoint to load (one of ``MODELS``).
        inputs: ``(sample_rate, samples)`` pair as produced by a Gradio
            ``"audio"`` input. ``samples`` is a numpy array — int16 PCM is
            rescaled to [-1, 1]; float input is assumed already normalized.
            NOTE(review): stereo ``(samples, channels)`` arrays are not
            handled — confirm the demo only receives mono audio.

    Returns:
        ``(sample_rate, samples)`` at the model's sample rate, with
        ``samples`` a float numpy array clamped to [-1, 1].
    """
    model = _MODEL_CACHE.get(model_name)
    if model is None:
        model = WaveUNetModel.from_pretrained(model_name)
        _MODEL_CACHE[model_name] = model
    # Inference only: disable dropout/batchnorm training behavior.
    model.eval()

    sr, audio = inputs
    audio = torch.from_numpy(audio)[None]
    # Rescale int16 PCM to [-1, 1]; leave float input untouched (the old
    # unconditional division corrupted already-normalized float audio).
    if not torch.is_floating_point(audio):
        audio = audio / 32768.0

    print(f"Audio shape: {audio.shape}")
    print(f"Sample rate: {sr}")

    if sr != model.config.sample_rate:
        audio = torchaudio.functional.resample(audio, sr, model.config.sample_rate)

    chunk_size = model.config.max_length
    # Pad up to the next multiple of chunk_size. This is 0 when the length
    # is already aligned (the previous abs(len % chunk - chunk) formula
    # padded a full extra silent chunk in that case).
    padding = -audio.size(-1) % chunk_size
    padded = torch.nn.functional.pad(audio, (0, padding))

    clean = []
    for i in tqdm(range(0, padded.shape[-1], chunk_size)):
        audio_chunk = padded[:, i : i + chunk_size]
        with torch.no_grad():
            clean_chunk = model(audio_chunk[None]).logits
        clean.append(clean_chunk.squeeze(0))

    # Stitch chunks back together, drop the padding, and keep the signal
    # in the valid [-1, 1] range.
    denoised = torch.concat(clean).flatten()[: audio.shape[-1]].clamp(-1.0, 1.0)

    print(f"Denoised shape: {denoised.shape}")

    # Gradio audio outputs expect a numpy array, not a torch tensor.
    return model.config.sample_rate, denoised.numpy()
53
 
54
 
55
# Script entry point: build and launch the Gradio interface.
if __name__ == "__main__":
    main()