Hendrik Schroeter committed on
Commit
7cbdca2
1 Parent(s): 480d487
Files changed (8)
  1. .flake8 +17 -0
  2. .gitignore +144 -0
  3. README.md +3 -4
  4. app.py +279 -0
  5. packages.txt +1 -0
  6. pyproject.toml +10 -0
  7. requirements.txt +6 -0
  8. usage.md +8 -0
.flake8 ADDED
@@ -0,0 +1,17 @@
+ [flake8]
+ ignore = E203, E266, E501, W503
+ max-line-length = 100
+ import-order-style = google
+ application-import-names = flake8
+ select = B,C,E,F,W,T4,B9
+ exclude =
+     .tox,
+     .git,
+     __pycache__,
+     docs,
+     sbatch,
+     .venv,
+     *.pyc,
+     *.egg-info,
+     .cache,
+     .eggs
.gitignore ADDED
@@ -0,0 +1,144 @@
+ # Own stuff
+ *.wav
+ *.png
+ *.pdf
+ out/
+ export/
+ DeepFilterNet/poetry.lock
+
+ ### Rust gitignore ###
+
+ # Generated by Cargo
+ # will have compiled files and executables
+ debug/
+ target/
+
+ # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
+ # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
+ Cargo.lock
+
+ # These are backup files generated by rustfmt
+ **/*.rs.bk
+
+ ### Python gitignore ###
+
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ typings
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # celery beat schedule file
+ celerybeat-schedule
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # IDE
+ .idea
README.md CHANGED
@@ -1,10 +1,9 @@
  ---
- title: DeepFilterNet2
+ title: DeepFilterNet
- emoji: 🐠
+ emoji: 💩
  colorFrom: gray
- colorTo: purple
+ colorTo: red
  sdk: gradio
- sdk_version: 2.9.4
  app_file: app.py
  pinned: false
  license: apache-2.0
app.py ADDED
@@ -0,0 +1,279 @@
+ import math
+ import tempfile
+ from typing import Optional, Tuple, Union
+
+ import gradio
+ import gradio.inputs
+ import gradio.outputs
+ import markdown
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import torch
+ from loguru import logger
+ from torch import Tensor
+ from torchaudio.backend.common import AudioMetaData
+
+ from df import config
+ from df.enhance import enhance, init_df, load_audio, save_audio
+ from df.utils import resample
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model, df, _ = init_df(config_allow_defaults=True)
+ model = model.to(device=device).eval()
+
+ NOISES = {
+     "None": None,
+     "Kitchen": "samples/dkitchen.wav",
+     "Living Room": "samples/dliving.wav",
+     "River": "samples/nriver.wav",
+     "Cafe": "samples/scafe.wav",
+ }
+
+
+ def mix_at_snr(clean, noise, snr, eps=1e-10):
+     """Mix clean and noise signal at a given SNR.
+
+     Args:
+         clean: 1D Tensor with the clean signal to mix.
+         noise: 1D Tensor with the noise signal to mix.
+         snr: Signal-to-noise ratio in dB.
+
+     Returns:
+         clean: 1D Tensor with gain changed according to the snr.
+         noise: 1D Tensor with the combined noise channels.
+         mix: 1D Tensor with added clean and noise signals.
+
+     """
+     clean = torch.as_tensor(clean).mean(0, keepdim=True)
+     noise = torch.as_tensor(noise).mean(0, keepdim=True)
+     if noise.shape[1] < clean.shape[1]:
+         noise = noise.repeat((1, int(math.ceil(clean.shape[1] / noise.shape[1]))))
+     max_start = int(noise.shape[1] - clean.shape[1])
+     start = torch.randint(0, max_start, ()).item() if max_start > 0 else 0
+     logger.debug(f"start: {start}, {clean.shape}")
+     noise = noise[:, start : start + clean.shape[1]]
+     E_speech = torch.mean(clean.pow(2)) + eps
+     E_noise = torch.mean(noise.pow(2))
+     K = torch.sqrt((E_noise / E_speech) * 10 ** (snr / 10) + eps)
+     noise = noise / K
+     mixture = clean + noise
+     logger.debug(f"mixture: {mixture.shape}")
+     assert torch.isfinite(mixture).all()
+     max_m = mixture.abs().max()
+     if max_m > 1:
+         logger.warning(f"Clipping detected during mixing. Reducing gain by {1/max_m}")
+         clean, noise, mixture = clean / max_m, noise / max_m, mixture / max_m
+     return clean, noise, mixture
+
+
+ def load_audio_gradio(
+     audio_or_file: Union[None, str, Tuple[int, np.ndarray]], sr: int
+ ) -> Optional[Tuple[Tensor, AudioMetaData]]:
+     if audio_or_file is None:
+         return None
+     if isinstance(audio_or_file, str):
+         if audio_or_file.lower() == "none":
+             return None
+         # First try default format
+         audio, meta = load_audio(audio_or_file, sr)
+     else:
+         meta = AudioMetaData(-1, -1, -1, -1, "")
+         assert isinstance(audio_or_file, (tuple, list))
+         meta.sample_rate, audio_np = audio_or_file
+         # The Gradio documentation says the shape is [samples, 2], but apparently it sometimes is not.
+         audio_np = audio_np.reshape(audio_np.shape[0], -1).T
+         if audio_np.dtype == np.int16:
+             audio_np = (audio_np / (1 << 15)).astype(np.float32)
+         elif audio_np.dtype == np.int32:
+             audio_np = (audio_np / (1 << 31)).astype(np.float32)
+         audio = resample(torch.from_numpy(audio_np), meta.sample_rate, sr)
+     return audio, meta
+
+
+ def demo_fn(
+     speech_rec: Union[str, Tuple[int, np.ndarray]], speech_upl: str, noise_type: str, snr: int
+ ):
+     sr = config("sr", 48000, int, section="df")
+     logger.info(
+         f"Got parameters speech_rec: {speech_rec}, speech_upl: {speech_upl}, noise: {noise_type}"
+     )
+     noise_fn = NOISES[noise_type]
+     meta = AudioMetaData(-1, -1, -1, -1, "")
+     if speech_rec is None and speech_upl is None:
+         sample, meta = load_audio("samples/p232_013_clean.wav", sr)
+     elif speech_upl is not None:
+         sample, meta = load_audio(speech_upl, sr)
+     else:
+         tmp = load_audio_gradio(speech_rec, sr)
+         assert tmp is not None
+         sample, meta = tmp
+     sample = sample[..., : 10 * meta.sample_rate]  # limit to 10 seconds
+     logger.info(f"Loaded sample with shape {sample.shape}")
+     if noise_fn is not None:
+         noise, _ = load_audio(noise_fn, sr)  # type: ignore
+         logger.info(f"Loaded noise with shape {noise.shape}")
+         _, _, sample = mix_at_snr(sample, noise, snr)
+     logger.info("Start denoising audio")
+     enhanced = enhance(model, df, sample)
+     logger.info("Denoising finished")
+     lim = torch.linspace(0.0, 1.0, int(sr * 0.15)).unsqueeze(0)
+     lim = torch.cat((lim, torch.ones(1, enhanced.shape[1] - lim.shape[1])), dim=1)
+     enhanced = enhanced * lim
+     # if meta.sample_rate != sr:
+     #     enhanced = resample(enhanced, sr, meta.sample_rate)
+     #     noisy = resample(noisy, sr, meta.sample_rate)
+     #     sr = meta.sample_rate
+     noisy_fn = tempfile.NamedTemporaryFile(suffix="noisy.wav", delete=False).name
+     save_audio(noisy_fn, sample, sr)
+     enhanced_fn = tempfile.NamedTemporaryFile(suffix="enhanced.wav", delete=False).name
+     save_audio(enhanced_fn, enhanced, sr)
+     logger.info(f"saved audios: {noisy_fn}, {enhanced_fn}")
+     return (
+         noisy_fn,
+         spec_figure(sample, sr=sr),
+         enhanced_fn,
+         spec_figure(enhanced, sr=sr),
+     )
+
+
+ def specshow(
+     spec,
+     ax=None,
+     title=None,
+     xlabel=None,
+     ylabel=None,
+     sr=48000,
+     n_fft=None,
+     hop=None,
+     t=None,
+     f=None,
+     vmin=-100,
+     vmax=0,
+     xlim=None,
+     ylim=None,
+     cmap="inferno",
+ ):
+     """Plots a spectrogram of shape [F, T]"""
+     spec_np = spec.cpu().numpy() if isinstance(spec, torch.Tensor) else spec
+     if ax is not None:
+         set_title = ax.set_title
+         set_xlabel = ax.set_xlabel
+         set_ylabel = ax.set_ylabel
+         set_xlim = ax.set_xlim
+         set_ylim = ax.set_ylim
+     else:
+         ax = plt
+         set_title = plt.title
+         set_xlabel = plt.xlabel
+         set_ylabel = plt.ylabel
+         set_xlim = plt.xlim
+         set_ylim = plt.ylim
+     if n_fft is None:
+         if spec.shape[0] % 2 == 0:
+             n_fft = spec.shape[0] * 2
+         else:
+             n_fft = (spec.shape[0] - 1) * 2
+     hop = hop or n_fft // 4
+     if t is None:
+         t = np.arange(0, spec_np.shape[-1]) * hop / sr
+     if f is None:
+         f = np.arange(0, spec_np.shape[0]) * sr // 2 / (n_fft // 2) / 1000
+     im = ax.pcolormesh(
+         t, f, spec_np, rasterized=True, shading="auto", vmin=vmin, vmax=vmax, cmap=cmap
+     )
+     if title is not None:
+         set_title(title)
+     if xlabel is not None:
+         set_xlabel(xlabel)
+     if ylabel is not None:
+         set_ylabel(ylabel)
+     if xlim is not None:
+         set_xlim(xlim)
+     if ylim is not None:
+         set_ylim(ylim)
+     return im
+
+
+ def spec_figure(
+     audio: torch.Tensor,
+     figsize=(15, 5),
+     colorbar=False,
+     colorbar_format=None,
+     figure=None,
+     return_im=False,
+     labels=True,
+     **kwargs,
+ ) -> plt.Figure:
+     audio = torch.as_tensor(audio)
+     if labels:
+         kwargs.setdefault("xlabel", "Time [s]")
+         kwargs.setdefault("ylabel", "Frequency [Hz]")
+     n_fft = kwargs.setdefault("n_fft", 1024)
+     hop = kwargs.setdefault("hop", 512)
+     w = torch.hann_window(n_fft, device=audio.device)
+     spec = torch.stft(audio, n_fft, hop, window=w, return_complex=False)
+     spec = spec.div_(w.pow(2).sum())
+     spec = torch.view_as_complex(spec).abs().clamp_min(1e-12).log10().mul(10)
+     kwargs.setdefault("vmax", max(0.0, spec.max().item()))
+
+     if figure is None:
+         figure = plt.figure(figsize=figsize)
+         figure.set_tight_layout(True)
+     if spec.dim() > 2:
+         spec = spec.squeeze(0)
+     im = specshow(spec, **kwargs)
+     if colorbar:
+         ckwargs = {}
+         if "ax" in kwargs:
+             if colorbar_format is None:
+                 if kwargs.get("vmin", None) is not None or kwargs.get("vmax", None) is not None:
+                     colorbar_format = "%+2.0f dB"
+             ckwargs = {"ax": kwargs["ax"]}
+         plt.colorbar(im, format=colorbar_format, **ckwargs)
+     if return_im:
+         return im
+     return figure
+
+
+ inputs = [
+     gradio.inputs.Audio(
+         label="Record your own voice",
+         source="microphone",
+         type="numpy",
+         optional=True,
+     ),
+     gradio.inputs.Audio(
+         label="Alternative: Upload audio sample",
+         source="upload",
+         type="filepath",
+         optional=True,
+     ),
+     gradio.inputs.Dropdown(
+         label="Add background noise",
+         choices=list(NOISES.keys()),
+         default="None",
+     ),
+     gradio.inputs.Dropdown(
+         label="Noise Level (SNR)",
+         choices=[-5, 0, 10, 20],
+         default=10,
+     ),
+ ]
+ outputs = [
+     gradio.outputs.Audio(label="Noisy"),
+     gradio.outputs.Image(type="plot"),
+     gradio.outputs.Audio(label="Enhanced"),
+     gradio.outputs.Image(type="plot"),
+ ]
+ description = "This demo denoises audio files using DeepFilterNet. Try it with your own voice!"
+ iface = gradio.Interface(
+     fn=demo_fn,
+     title="DeepFilterNet2 Demo",
+     inputs=inputs,
+     outputs=outputs,
+     description=description,
+     layout="horizontal",
+     allow_flagging="never",
+     article=markdown.markdown(open("usage.md").read()),
+ )
+ iface.launch(cache_examples=False, debug=True)
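
Note: the noise gain in `mix_at_snr` above follows directly from the definition SNR_dB = 10·log10(E_speech / E_noise). Dividing the noise by K = sqrt((E_noise / E_speech) · 10^(SNR/10)) leaves it with energy E_speech · 10^(-SNR/10), i.e. exactly the requested SNR. A minimal self-contained sketch of that computation (the sine and white-noise signals are illustrative placeholders, not the demo's samples):

```python
import torch

# Placeholder signals with shape [channels, samples].
clean = torch.sin(torch.linspace(0, 100.0, 48000)).unsqueeze(0)
noise = torch.randn(1, 48000)

snr_db = 10
e_speech = clean.pow(2).mean()
e_noise = noise.pow(2).mean()
# Same scaling as in mix_at_snr: divide the noise by k to hit the target SNR.
k = torch.sqrt((e_noise / e_speech) * 10 ** (snr_db / 10))
mixture = clean + noise / k

# Check: the measured SNR of the mixture components matches the request.
measured = 10 * torch.log10(clean.pow(2).mean() / (noise / k).pow(2).mean())
print(f"requested: {snr_db} dB, measured: {measured.item():.2f} dB")
```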
packages.txt ADDED
@@ -0,0 +1 @@
+ ffmpeg
pyproject.toml ADDED
@@ -0,0 +1,10 @@
+ [tool.black]
+ line-length = 100
+ target-version = ["py37", "py38", "py39", "py310"]
+ include = '\.pyi?$'
+
+ [tool.isort]
+ profile = "black"
+ line_length = 100
+ skip_gitignore = true
+ known_first_party = ["df", "libdf", "libdfdata"]
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ torch
+ torchaudio
+ deepfilternet==0.2.0
+ matplotlib
+ markdown
+ gradio
usage.md ADDED
@@ -0,0 +1,8 @@
+ **Usage:**
+
+ This demo takes an audio sample and enhances it using DeepFilterNet2.
+ You can either record a speech sample or alternatively provide one via upload.
+ Furthermore, you may optionally add background noise to the input sample.
+ If no sample is provided, a default will be used.
+
+ DeepFilterNet2 [(link)](https://github.com/Rikorose/DeepFilterNet) is used to denoise the noisy mixture.
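
For reference, the demo's denoising path can also be exercised outside Gradio with the same `df` API that `app.py` imports. A minimal sketch, assuming `deepfilternet==0.2.0` as pinned in requirements.txt; the input path `noisy.wav` is a placeholder:

```python
import torch
from df.enhance import enhance, init_df, load_audio, save_audio

# Same initialization as app.py: the model plus its DeepFilterNet state.
model, df_state, _ = init_df(config_allow_defaults=True)
model = model.to("cuda" if torch.cuda.is_available() else "cpu").eval()

# "noisy.wav" is a placeholder path; load_audio resamples to the given rate.
audio, _ = load_audio("noisy.wav", 48000)
enhanced = enhance(model, df_state, audio)
save_audio("enhanced.wav", enhanced, 48000)
```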