ernestchu commited on
Commit
ded337a
1 Parent(s): c55d8da

add speech enhancement

Browse files
Files changed (2) hide show
  1. app.py +21 -7
  2. requirements.txt +1 -0
app.py CHANGED
@@ -3,9 +3,11 @@ import time
3
  from tsmnet import Stretcher
4
  import gradio as gr
5
  from gradio import processing_utils
6
- # import torch
 
7
  import torchaudio
8
  import yt_dlp
 
9
 
10
  model_root = './weights'
11
  yt_dl_dir = 'yt-audio'
@@ -58,7 +60,7 @@ def prepare_audio_file(rec, audio_file, yt_url):
58
  raise gr.Error('No audio found!')
59
 
60
 
61
- def run(rec, audio_file, yt_url, speed, model, start_time, end_time):
62
  audio_file = prepare_audio_file(rec, audio_file, yt_url)
63
 
64
  x, sr = torchaudio.load(audio_file)
@@ -67,8 +69,18 @@ def run(rec, audio_file, yt_url, speed, model, start_time, end_time):
67
 
68
  x = x[:, int(start_time * sr):int(end_time * sr)]
69
 
70
- if speed != 1:
71
- x = models[model](x, speed).cpu()
 
 
 
 
 
 
 
 
 
 
72
 
73
  torchaudio.save(audio_file, x, sr)
74
  return processing_utils.audio_from_file(audio_file)
@@ -86,16 +98,17 @@ with gr.Blocks() as demo:
86
  with gr.Column():
87
  with gr.Tab('From microphone'):
88
  rec_box = gr.Audio(label='Recording', source='microphone', type='filepath')
89
- with gr.Tab('From file'):
90
- audio_file_box = gr.Audio(label='Audio sample', type='filepath')
91
  with gr.Tab('From YouTube'):
92
  yt_url_box = gr.Textbox(label='YouTube URL', placeholder='https://youtu.be/q6EoRBvdVPQ')
 
 
 
93
 
94
  rec_box.change(lambda: [None] * 2, outputs=[audio_file_box, yt_url_box])
95
  audio_file_box.change(lambda: [None] * 2, outputs=[rec_box, yt_url_box])
96
  yt_url_box.input(lambda: [None] * 2, outputs=[rec_box, audio_file_box])
97
 
98
- speed_box = gr.Slider(label='Playback speed', minimum=0, maximum=2, value=1)
99
  with gr.Accordion('Fine-grained settings', open=False):
100
  with gr.Tab('Trim audio sample (sec)'):
101
  # gr.Markdown('### Trim audio sample (sec)')
@@ -117,6 +130,7 @@ with gr.Blocks() as demo:
117
  rec_box,
118
  audio_file_box,
119
  yt_url_box,
 
120
  speed_box,
121
  model_box,
122
  start_time_box,
 
3
  from tsmnet import Stretcher
4
  import gradio as gr
5
  from gradio import processing_utils
6
+ import torch
7
+ import numpy as np
8
  import torchaudio
9
  import yt_dlp
10
+ import noisereduce as nr
11
 
12
  model_root = './weights'
13
  yt_dl_dir = 'yt-audio'
 
60
  raise gr.Error('No audio found!')
61
 
62
 
63
+ def run(rec, audio_file, yt_url, denoise, speed, model, start_time, end_time):
64
  audio_file = prepare_audio_file(rec, audio_file, yt_url)
65
 
66
  x, sr = torchaudio.load(audio_file)
 
69
 
70
  x = x[:, int(start_time * sr):int(end_time * sr)]
71
 
72
+ if speed == 1:
73
+ torchaudio.save(audio_file, x, sr)
74
+ return processing_utils.audio_from_file(audio_file)
75
+
76
+ x = models[model](x, speed).cpu()
77
+
78
+ if denoise:
79
+ if len(x.shape) == 1: # mono
80
+ x = x[None]
81
+ x = x.numpy()
82
+ # perform noise reduction
83
+ x = torch.from_numpy(np.stack([nr.reduce_noise(y=y, sr=sr) for y in x]))
84
 
85
  torchaudio.save(audio_file, x, sr)
86
  return processing_utils.audio_from_file(audio_file)
 
98
  with gr.Column():
99
  with gr.Tab('From microphone'):
100
  rec_box = gr.Audio(label='Recording', source='microphone', type='filepath')
 
 
101
  with gr.Tab('From YouTube'):
102
  yt_url_box = gr.Textbox(label='YouTube URL', placeholder='https://youtu.be/q6EoRBvdVPQ')
103
+ with gr.Tab('From file'):
104
+ audio_file_box = gr.Audio(label='Audio sample', type='filepath')
105
+ denoise_box = gr.Checkbox(label='Speech enhancement (should be off for music)', value=True)
106
 
107
  rec_box.change(lambda: [None] * 2, outputs=[audio_file_box, yt_url_box])
108
  audio_file_box.change(lambda: [None] * 2, outputs=[rec_box, yt_url_box])
109
  yt_url_box.input(lambda: [None] * 2, outputs=[rec_box, audio_file_box])
110
 
111
+ speed_box = gr.Slider(label='Playback speed', minimum=0.25, maximum=2, value=1)
112
  with gr.Accordion('Fine-grained settings', open=False):
113
  with gr.Tab('Trim audio sample (sec)'):
114
  # gr.Markdown('### Trim audio sample (sec)')
 
130
  rec_box,
131
  audio_file_box,
132
  yt_url_box,
133
+ denoise_box,
134
  speed_box,
135
  model_box,
136
  start_time_box,
requirements.txt CHANGED
@@ -3,4 +3,5 @@ torchvision
3
  torchaudio
4
  yt-dlp
5
  wget
 
6
 
 
3
  torchaudio
4
  yt-dlp
5
  wget
6
+ noisereduce
7