arcan3 committed on
Commit fc5ed00
1 Parent(s): e160858

added revision

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. .gitattributes +5 -34
  2. .gitignore +163 -0
  3. Languages/ben/G_100000.pth +3 -0
  4. Languages/ben/config.json +3 -0
  5. Languages/ben/vocab.txt +3 -0
  6. Languages/ell/G_100000.pth +3 -0
  7. Languages/ell/config.json +3 -0
  8. Languages/ell/vocab.txt +3 -0
  9. Languages/fra/G_100000.pth +3 -0
  10. Languages/fra/config.json +3 -0
  11. Languages/fra/vocab.txt +3 -0
  12. Languages/guj/G_100000.pth +3 -0
  13. Languages/guj/config.json +3 -0
  14. Languages/guj/vocab.txt +3 -0
  15. Languages/hin/G_100000.pth +3 -0
  16. Languages/hin/config.json +3 -0
  17. Languages/hin/vocab.txt +3 -0
  18. Languages/nld/G_100000.pth +3 -0
  19. Languages/nld/config.json +3 -0
  20. Languages/nld/vocab.txt +3 -0
  21. Languages/pol/G_100000.pth +3 -0
  22. Languages/pol/config.json +3 -0
  23. Languages/pol/vocab.txt +3 -0
  24. app.py +317 -0
  25. aux_files/uroman.pl +3 -0
  26. configurations/__init__.py +0 -0
  27. configurations/get_constants.py +176 -0
  28. configurations/get_hyperparameters.py +19 -0
  29. df/__init__.py +3 -0
  30. df/checkpoint.py +213 -0
  31. df/config.py +266 -0
  32. df/deepfilternet2.py +453 -0
  33. df/enhance.py +333 -0
  34. df/logger.py +212 -0
  35. df/model.py +24 -0
  36. df/modules.py +956 -0
  37. df/multiframe.py +329 -0
  38. df/utils.py +230 -0
  39. libdf/__init__.py +3 -0
  40. libdf/__init__.pyi +57 -0
  41. libdf/py.typed +0 -0
  42. model_weights/voice_enhance/checkpoints/model_96.ckpt.best +3 -0
  43. model_weights/voice_enhance/ckpt/pretrained_bak_5805000.pt +3 -0
  44. model_weights/voice_enhance/ckpt/pretrained_bak_5805000.pt.txt +3 -0
  45. model_weights/voiceover/freevc-24.json +3 -0
  46. model_weights/voiceover/freevc-24.pth +3 -0
  47. model_weights/wavlm_models/WavLM-Large.pt +3 -0
  48. model_weights/wavlm_models/WavLM-Large.pt.txt +3 -0
  49. nnet/__init__.py +0 -0
  50. nnet/attentions.py +300 -0
.gitattributes CHANGED
@@ -1,35 +1,6 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
  *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.json filter=lfs diff=lfs merge=lfs -text
+ *.txt filter=lfs diff=lfs merge=lfs -text
+ *.pt* filter=lfs diff=lfs merge=lfs -text
+ *.ckpt* filter=lfs diff=lfs merge=lfs -text
+ *.pl filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,163 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ Temp_Audios/
10
+
11
+ # Distribution / packaging
12
+ .Python
13
+ build/
14
+ develop-eggs/
15
+ dist/
16
+ downloads/
17
+ eggs/
18
+ .eggs/
19
+ lib/
20
+ lib64/
21
+ parts/
22
+ sdist/
23
+ var/
24
+ wheels/
25
+ share/python-wheels/
26
+ *.egg-info/
27
+ .installed.cfg
28
+ *.egg
29
+ MANIFEST
30
+
31
+ # PyInstaller
32
+ # Usually these files are written by a python script from a template
33
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
34
+ *.manifest
35
+ *.spec
36
+
37
+ # Installer logs
38
+ pip-log.txt
39
+ pip-delete-this-directory.txt
40
+
41
+ # Unit test / coverage reports
42
+ htmlcov/
43
+ .tox/
44
+ .nox/
45
+ .coverage
46
+ .coverage.*
47
+ .cache
48
+ nosetests.xml
49
+ coverage.xml
50
+ *.cover
51
+ *.py,cover
52
+ .hypothesis/
53
+ .pytest_cache/
54
+ cover/
55
+
56
+ # Translations
57
+ *.mo
58
+ *.pot
59
+
60
+ # Django stuff:
61
+ *.log
62
+ local_settings.py
63
+ db.sqlite3
64
+ db.sqlite3-journal
65
+
66
+ # Flask stuff:
67
+ instance/
68
+ .webassets-cache
69
+
70
+ # Scrapy stuff:
71
+ .scrapy
72
+
73
+ # Sphinx documentation
74
+ docs/_build/
75
+
76
+ # PyBuilder
77
+ .pybuilder/
78
+ target/
79
+
80
+ # Jupyter Notebook
81
+ .ipynb_checkpoints
82
+
83
+ # IPython
84
+ profile_default/
85
+ ipython_config.py
86
+
87
+ # pyenv
88
+ # For a library or package, you might want to ignore these files since the code is
89
+ # intended to run in multiple environments; otherwise, check them in:
90
+ # .python-version
91
+
92
+ # pipenv
93
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
95
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
96
+ # install all needed dependencies.
97
+ #Pipfile.lock
98
+
99
+ # poetry
100
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
101
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
102
+ # commonly ignored for libraries.
103
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
104
+ #poetry.lock
105
+
106
+ # pdm
107
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
108
+ #pdm.lock
109
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
110
+ # in version control.
111
+ # https://pdm.fming.dev/#use-with-ide
112
+ .pdm.toml
113
+
114
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115
+ __pypackages__/
116
+
117
+ # Celery stuff
118
+ celerybeat-schedule
119
+ celerybeat.pid
120
+
121
+ # SageMath parsed files
122
+ *.sage.py
123
+
124
+ # Environments
125
+ .env
126
+ .venv
127
+ env/
128
+ venv/
129
+ ENV/
130
+ env.bak/
131
+ venv.bak/
132
+ *.ini
133
+
134
+ # Spyder project settings
135
+ .spyderproject
136
+ .spyproject
137
+
138
+ # Rope project settings
139
+ .ropeproject
140
+
141
+ # mkdocs documentation
142
+ /site
143
+
144
+ # mypy
145
+ .mypy_cache/
146
+ .dmypy.json
147
+ dmypy.json
148
+
149
+ # Pyre type checker
150
+ .pyre/
151
+
152
+ # pytype static type analyzer
153
+ .pytype/
154
+
155
+ # Cython debug symbols
156
+ cython_debug/
157
+
158
+ # PyCharm
159
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
160
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
161
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
162
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
163
+ #.idea/
Languages/ben/G_100000.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a8c098eab2e5e378fc52bec57683839bbc641b2241033dab17174f6e37db29a4
+ size 145512166
Languages/ben/config.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ec453d50be1976ddb3b501c7dc09128cdf360b83e5b0cf7dec68dc5caa460f49
+ size 1887
Languages/ben/vocab.txt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7085f1a1f6040b4da0ac55bb3ff91b77229d1ed14f7d86df2b23676a1a2cb81b
+ size 268
Languages/ell/G_100000.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:75bfa237f0fe859b34c4340bc7dccd944678cf9984bce5b5a82e2c90ca268db8
+ size 145504497
Languages/ell/config.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ec453d50be1976ddb3b501c7dc09128cdf360b83e5b0cf7dec68dc5caa460f49
+ size 1887
Languages/ell/vocab.txt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4c53d89f446eba9d061b510e31900d235ef0e021e44a790978dae5a4350a4013
+ size 164
Languages/fra/G_100000.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:63725b5a9201548b2247af02bd69a059335bddf52c1b858dbe38a43a40478bd7
+ size 145489135
Languages/fra/config.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ec453d50be1976ddb3b501c7dc09128cdf360b83e5b0cf7dec68dc5caa460f49
+ size 1887
Languages/fra/vocab.txt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2b57f0f246b488fe914508d82a8607e1aea357beb0f801069b39bfeb3a4c0d47
+ size 104
Languages/guj/G_100000.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:427ac3c74f61be494b389cae7d771311d0bcf576f4e2f1b22f257539e26e323a
+ size 145501427
Languages/guj/config.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ec453d50be1976ddb3b501c7dc09128cdf360b83e5b0cf7dec68dc5caa460f49
+ size 1887
Languages/guj/vocab.txt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:611d4c5d7ba4bce727c1154277aea43df7a534e22e523877d1885a36727d63c3
+ size 232
Languages/hin/G_100000.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9f1d5e47edd7368ff40ff5673ddfc606ea713e785420d26c2da396b555458d3b
+ size 145510619
Languages/hin/config.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ec453d50be1976ddb3b501c7dc09128cdf360b83e5b0cf7dec68dc5caa460f49
+ size 1887
Languages/hin/vocab.txt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:eea03474615c78a1c42d1299c345b6865421c07485544ac3361bff472e5005ac
+ size 266
Languages/nld/G_100000.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2b09e9917b07f06dd911045c8fc8738594b4c4d65c55223c46335093a4904816
+ size 145486855
Languages/nld/config.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ec453d50be1976ddb3b501c7dc09128cdf360b83e5b0cf7dec68dc5caa460f49
+ size 1887
Languages/nld/vocab.txt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1d7e65d89b8be768ac4a1e53643aebfe42830b82910bfed87f13904e2c5292a4
+ size 94
Languages/pol/G_100000.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6d6f4a9de92a6eb15bca8cb01826d8a9938ab6fb2c04a1c13a06d1d170c88ba6
+ size 145490647
Languages/pol/config.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ec453d50be1976ddb3b501c7dc09128cdf360b83e5b0cf7dec68dc5caa460f49
+ size 1887
Languages/pol/vocab.txt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5514b17eb0cd17950849e3f2f2f22a6c7c2d18f2729f5b2fbfc2f2e5f035dc4a
+ size 103
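
Each of the Languages/* entries above is a Git LFS pointer file rather than the model weights themselves: the repository stores only the spec version line, the sha256 object id, and the byte size, while the actual .pth/.json/.txt blobs live in LFS storage. A minimal sketch of reading such a pointer follows; the helper name read_lfs_pointer is illustrative and not part of this commit.

from pathlib import Path

def read_lfs_pointer(path: str) -> dict:
    # hypothetical helper: parse the three "key value" lines of an LFS pointer file
    fields = dict(line.split(" ", 1) for line in Path(path).read_text().splitlines() if line.strip())
    return {"oid": fields["oid"].removeprefix("sha256:"), "size": int(fields["size"])}

# e.g. read_lfs_pointer("Languages/ben/G_100000.pth") -> {'oid': 'a8c098ea...', 'size': 145512166}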
app.py ADDED
@@ -0,0 +1,317 @@
1
+
2
+ # load the libraries for the application
3
+ # -------------------------------------------
4
+ import os
5
+ import re
6
+ import nltk
7
+ import torch
8
+ import librosa
9
+ import tempfile
10
+ import subprocess
11
+
12
+ import gradio as gr
13
+
14
+ from scipy.io import wavfile
15
+ from nnet import utils, commons
16
+ from transformers import pipeline
17
+ from scipy.io.wavfile import write
18
+ from faster_whisper import WhisperModel
19
+ from nnet.models import SynthesizerTrn as vitsTRN
20
+ from nnet.models_vc import SynthesizerTrn as freeTRN
21
+ from nnet.mel_processing import mel_spectrogram_torch
22
+ from configurations.get_constants import constantConfig
23
+
24
+ from speaker_encoder.voice_encoder import SpeakerEncoder
25
+
26
+ from df.enhance import enhance, init_df, load_audio, save_audio
27
+ from configurations.get_hyperparameters import hyperparameterConfig
28
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
29
+
30
+ nltk.download('punkt')
31
+ from nltk.tokenize import sent_tokenize
32
+
33
+ # building the FreeVC voice conversion model wrapper
34
+ # ---------------------------------
35
+ class FreeVCModel:
36
+ def __init__(self, config, ptfile, speaker_model, wavLM_model, device='cpu'):
37
+ self.hps = utils.get_hparams_from_file(config)
38
+
39
+ self.net_g = freeTRN(
40
+ self.hps.data.filter_length // 2 + 1,
41
+ self.hps.train.segment_size // self.hps.data.hop_length,
42
+ **self.hps.model
43
+ ).to(hyperparameters.device)
44
+ _ = self.net_g.eval()
45
+ _ = utils.load_checkpoint(ptfile, self.net_g, None, True)
46
+
47
+ self.cmodel = utils.get_cmodel(device, wavLM_model)
48
+
49
+ if self.hps.model.use_spk:
50
+ self.smodel = SpeakerEncoder(speaker_model)
51
+
52
+ def convert(self, src, tgt):
53
+ fs_src, src_audio = src
54
+ fs_tgt, tgt_audio = tgt
55
+
56
+ src = f"{constants.temp_audio_folder}/src.wav"
57
+ tgt = f"{constants.temp_audio_folder}/tgt.wav"
58
+ out = f"{constants.temp_audio_folder}/cnvr.wav"
59
+ with torch.no_grad():
60
+ wavfile.write(tgt, fs_tgt, tgt_audio)
61
+ wav_tgt, _ = librosa.load(tgt, sr=self.hps.data.sampling_rate)
62
+ wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20)
63
+ if self.hps.model.use_spk:
64
+ g_tgt = self.smodel.embed_utterance(wav_tgt)
65
+ g_tgt = torch.from_numpy(g_tgt).unsqueeze(0).to(hyperparameters.device.type)
66
+ else:
67
+ wav_tgt = torch.from_numpy(wav_tgt).unsqueeze(0).to(hyperparameters.device.type)
68
+ mel_tgt = mel_spectrogram_torch(
69
+ wav_tgt,
70
+ self.hps.data.filter_length,
71
+ self.hps.data.n_mel_channels,
72
+ self.hps.data.sampling_rate,
73
+ self.hps.data.hop_length,
74
+ self.hps.data.win_length,
75
+ self.hps.data.mel_fmin,
76
+ self.hps.data.mel_fmax,
77
+ )
78
+ wavfile.write(src, fs_src, src_audio)
79
+ wav_src, _ = librosa.load(src, sr=self.hps.data.sampling_rate)
80
+ wav_src = torch.from_numpy(wav_src).unsqueeze(0).to(hyperparameters.device.type)
81
+ c = utils.get_content(self.cmodel, wav_src)
82
+
83
+ if self.hps.model.use_spk:
84
+ audio = self.net_g.infer(c, g=g_tgt)
85
+ else:
86
+ audio = self.net_g.infer(c, mel=mel_tgt)
87
+ audio = audio[0][0].data.cpu().float().numpy()
88
+ write(out, 24000, audio)
89
+
90
+ return out
91
+
92
+ # load the system configurations
93
+ constants = constantConfig()
94
+ hyperparameters = hyperparameterConfig()
95
+
96
+ # load the models
97
+ model, df_state, _ = init_df(hyperparameters.voice_enhacing_model, config_allow_defaults=True) # voice enhancing model
98
+ stt_model = WhisperModel(hyperparameters.stt_model, device=hyperparameters.device.type, compute_type="float32") #speech to text model
99
+
100
+ trans_model = AutoModelForSeq2SeqLM.from_pretrained(constants.model_name_dict[hyperparameters.nllb_model], torch_dtype=torch.bfloat16).to(hyperparameters.device)
101
+ trans_tokenizer = AutoTokenizer.from_pretrained(constants.model_name_dict[hyperparameters.nllb_model])
102
+
103
+ modelConvertSpeech = FreeVCModel(config=hyperparameters.text2speech_config, ptfile=hyperparameters.text2speech_model,
104
+ speaker_model=hyperparameters.text2speech_encoder, wavLM_model=hyperparameters.wavlm_model,
105
+ device=hyperparameters.device.type)
106
+
107
+ # download the language model if it doesn't exist
108
+ # ----------------------------------------------------
109
+ def download(lang, lang_directory):
110
+
111
+ if not os.path.exists(f"{lang_directory}/{lang}"):
112
+ cmd = ";".join([
113
+ f"wget {constants.language_download_web}/{lang}.tar.gz -O {lang_directory}/{lang}.tar.gz",
114
+ f"tar zxvf {lang_directory}/{lang}.tar.gz -C {lang_directory}"
115
+ ])
116
+ subprocess.check_output(cmd, shell=True)
117
+ try:
118
+ os.remove(f"{lang_directory}/{lang}.tar.gz")
119
+ except:
120
+ pass
121
+ return f"{lang_directory}/{lang}"
122
+
123
+ def preprocess_char(text, lang=None):
124
+ """
125
+ Special treatment of characters in certain languages
126
+ """
127
+ if lang == 'ron':
128
+ text = text.replace("ț", "ţ")
129
+ return text
130
+
131
+ def preprocess_text(txt, text_mapper, hps, uroman_dir=None, lang=None):
132
+ txt = preprocess_char(txt, lang=lang)
133
+ is_uroman = hps.data.training_files.split('.')[-1] == 'uroman'
134
+ if is_uroman:
135
+ txt = text_mapper.uromanize(txt, f'{uroman_dir}/bin/uroman.pl')
136
+
137
+ txt = txt.lower()
138
+ txt = text_mapper.filter_oov(txt)
139
+ return txt
140
+
141
+ def detect_language(text,LID):
142
+ predictions = LID.predict(text)
143
+ detected_lang_code = predictions[0][0].replace("__label__", "")
144
+ return detected_lang_code
145
+
146
+ # text to speech
147
+ class TextMapper(object):
148
+ def __init__(self, vocab_file):
149
+ self.symbols = [x.replace("\n", "") for x in open(vocab_file, encoding="utf-8").readlines()]
150
+ self.SPACE_ID = self.symbols.index(" ")
151
+ self._symbol_to_id = {s: i for i, s in enumerate(self.symbols)}
152
+ self._id_to_symbol = {i: s for i, s in enumerate(self.symbols)}
153
+
154
+ def text_to_sequence(self, text, cleaner_names):
155
+ '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
156
+ Args:
157
+ text: string to convert to a sequence
158
+ cleaner_names: names of the cleaner functions to run the text through
159
+ Returns:
160
+ List of integers corresponding to the symbols in the text
161
+ '''
162
+ sequence = []
163
+ clean_text = text.strip()
164
+ for symbol in clean_text:
165
+ symbol_id = self._symbol_to_id[symbol]
166
+ sequence += [symbol_id]
167
+ return sequence
168
+
169
+ def uromanize(self, text, uroman_pl):
170
+ with tempfile.NamedTemporaryFile() as tf, \
171
+ tempfile.NamedTemporaryFile() as tf2:
172
+ with open(tf.name, "w") as f:
173
+ f.write("\n".join([text]))
174
+ cmd = f"perl " + uroman_pl
175
+ cmd += f" -l xxx "
176
+ cmd += f" < {tf.name} > {tf2.name}"
177
+ os.system(cmd)
178
+ outtexts = []
179
+ with open(tf2.name) as f:
180
+ for line in f:
181
+ line = re.sub(r"\s+", " ", line).strip()
182
+ outtexts.append(line)
183
+ outtext = outtexts[0]
184
+ return outtext
185
+
186
+ def get_text(self, text, hps):
187
+ text_norm = self.text_to_sequence(text, hps.data.text_cleaners)
188
+ if hps.data.add_blank:
189
+ text_norm = commons.intersperse(text_norm, 0)
190
+ text_norm = torch.LongTensor(text_norm)
191
+ return text_norm
192
+
193
+ def filter_oov(self, text):
194
+ val_chars = self._symbol_to_id
195
+ txt_filt = "".join(list(filter(lambda x: x in val_chars, text)))
196
+ return txt_filt
197
+
198
+ def speech_to_text(audio_file):
199
+ try:
200
+ fs, audio = audio_file
201
+ wavfile.write(constants.input_speech_file, fs, audio)
202
+ audio0, _ = load_audio(constants.input_speech_file, sr=df_state.sr())
203
+
204
+ # Enhance the SNR of the audio
205
+ enhanced = enhance(model, df_state, audio0)
206
+ save_audio(constants.enhanced_speech_file, enhanced, df_state.sr())
207
+
208
+ segments, info = stt_model.transcribe(constants.enhanced_speech_file)
209
+
210
+ speech_text = ''
211
+ for segment in segments:
212
+ speech_text = f'{speech_text}{segment.text}'
213
+ try:
214
+ source_lang_nllb = [k for k, v in constants.flores_codes_to_tts_codes.items() if v[:2] == info.language][0]
215
+ except:
216
+ source_lang_nllb = 'language cannot be determined, select manually'
217
+
218
+ # text translation
219
+ return speech_text, gr.Dropdown.update(value=source_lang_nllb)
220
+ except:
221
+ return '', gr.Dropdown.update(value='English')
222
+
223
+ # Text to speech
224
+ def text_to_speech(text, target_lang):
225
+ txt = text
226
+
227
+ # LANG = get_target_tts_lang(target_lang)
228
+ LANG = constants.flores_codes_to_tts_codes[target_lang]
229
+ ckpt_dir = download(LANG, lang_directory=constants.language_directory)
230
+
231
+ vocab_file = f"{ckpt_dir}/{constants.language_vocab_text}"
232
+ config_file = f"{ckpt_dir}/{constants.language_vocab_configuration}"
233
+ hps = utils.get_hparams_from_file(config_file)
234
+ text_mapper = TextMapper(vocab_file)
235
+ net_g = vitsTRN(
236
+ len(text_mapper.symbols),
237
+ hps.data.filter_length // 2 + 1,
238
+ hps.train.segment_size // hps.data.hop_length,
239
+ **hps.model)
240
+ net_g.to(hyperparameters.device)
241
+ _ = net_g.eval()
242
+
243
+ g_pth = f"{ckpt_dir}/{constants.language_vocab_model}"
244
+
245
+ _ = utils.load_checkpoint(g_pth, net_g, None)
246
+
247
+ txt = preprocess_text(txt, text_mapper, hps, lang=LANG, uroman_dir=constants.uroman_directory)
248
+ stn_tst = text_mapper.get_text(txt, hps)
249
+ with torch.no_grad():
250
+ x_tst = stn_tst.unsqueeze(0).to(hyperparameters.device)
251
+ x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(hyperparameters.device)
252
+ hyp = net_g.infer(
253
+ x_tst, x_tst_lengths, noise_scale=.667,
254
+ noise_scale_w=0.8, length_scale=1.0
255
+ )[0][0,0].cpu().float().numpy()
256
+
257
+ return hps.data.sampling_rate, hyp
258
+
259
+ def translation(audio, text, source_lang_nllb, target_code_nllb, output_type, sentence_mode):
260
+ target_code = constants.flores_codes[target_code_nllb]
261
+ translator = pipeline('translation', model=trans_model, tokenizer=trans_tokenizer, src_lang=source_lang_nllb, tgt_lang=target_code, device=hyperparameters.device)
262
+
263
+ # output = translator(text, max_length=400)[0]['translation_text']
264
+ if sentence_mode == "Sentence-wise":
265
+ sentences = sent_tokenize(text)
266
+ translated_sentences = []
267
+ for sentence in sentences:
268
+ translated_sentence = translator(sentence, max_length=400)[0]['translation_text']
269
+ translated_sentences.append(translated_sentence)
270
+ output = ' '.join(translated_sentences)
271
+ else:
272
+ output = translator(text, max_length=1024)[0]['translation_text']
273
+
274
+ # get the text to speech
275
+ fs_out, audio_out = text_to_speech(output, target_code_nllb)
276
+
277
+ if output_type == 'own voice':
278
+ out_file = modelConvertSpeech.convert((fs_out, audio_out), audio)
279
+ return output, out_file
280
+
281
+ wavfile.write(constants.text2speech_wavfile, fs_out, audio_out)
282
+ return output, constants.text2speech_wavfile
283
+
284
+ with gr.Blocks(title = "Octopus Translation App") as octopus_translator:
285
+ with gr.Row():
286
+ audio_file = gr.Audio(source="microphone")
287
+
288
+ with gr.Row():
289
+ input_text = gr.Textbox(label="Input text")
290
+ source_language = gr.Dropdown(list(constants.flores_codes.keys()), value='English', label='Source (Autoselected)', interactive=True)
291
+
292
+ with gr.Row():
293
+ output_text = gr.Textbox(label='Translated text')
294
+ target_language = gr.Dropdown(list(constants.flores_codes.keys()), value='German', label='Target', interactive=True)
295
+
296
+
297
+ with gr.Row():
298
+ output_speech = gr.Audio(label='Translated speech')
299
+ translate_button = gr.Button('Translate')
300
+
301
+
302
+ with gr.Row():
303
+ enhance_audio = gr.Radio(['yes', 'no'], value='yes', label='Enhance input voice', interactive=True)
304
+ input_type = gr.Radio(['Whole text', 'Sentence-wise'],value='Sentence-wise', label="Translation Mode", interactive=True)
305
+ output_audio_type = gr.Radio(['standard speaker', 'voice transfer'], value='voice transfer', label='Enhance output voice', interactive=True)
306
+
307
+ audio_file.change(speech_to_text,
308
+ inputs=[audio_file],
309
+ outputs=[input_text, source_language])
310
+
311
+ translate_button.click(translation,
312
+ inputs=[audio_file, input_text,
313
+ source_language, target_language,
314
+ output_audio_type, input_type],
315
+ outputs=[output_text, output_speech])
316
+
317
+ octopus_translator.launch(share=False)
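
For reference, the character-to-ID mapping that app.py's TextMapper performs is a per-character lookup into the symbol list read from vocab.txt. A minimal sketch with an in-memory symbol list is shown below; the symbols are made up for illustration, while the real vocab.txt files ship in each Languages/<lang> folder.

symbols = ["_", " ", "a", "b", "c"]                  # stand-in for vocab.txt, one symbol per line
symbol_to_id = {s: i for i, s in enumerate(symbols)}

def to_sequence(text: str) -> list:
    # mirrors TextMapper.text_to_sequence: one ID per character; out-of-vocabulary
    # characters must be dropped first (TextMapper.filter_oov does that in the app)
    return [symbol_to_id[ch] for ch in text.strip()]

print(to_sequence("abc ba"))                          # [2, 3, 4, 1, 3, 2]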
aux_files/uroman.pl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4ceece2c05343e8bc3b1a7cdc8cecd530af94a7928013c0e4224fd5c729fb29a
+ size 5347
configurations/__init__.py ADDED
File without changes
configurations/get_constants.py ADDED
@@ -0,0 +1,176 @@
1
+ import os
2
+
3
+ class constantConfig:
4
+ def __init__(self):
5
+ self.flores_codes={'Acehnese (Arabic script)': 'ace_Arab',
6
+ 'Acehnese (Latin script)': 'ace_Latn',
7
+ 'Mesopotamian Arabic': 'acm_Arab',
8
+ 'Ta’izzi-Adeni Arabic': 'acq_Arab',
9
+ 'Tunisian Arabic': 'aeb_Arab',
10
+ 'Afrikaans': 'afr_Latn',
11
+ 'South Levantine Arabic': 'ajp_Arab',
12
+ 'Akan': 'aka_Latn',
13
+ 'Amharic': 'amh_Ethi',
14
+ 'North Levantine Arabic': 'apc_Arab',
15
+ 'Modern Standard Arabic': 'arb_Arab',
16
+ 'Modern Standard Arabic (Romanized)': 'arb_Latn',
17
+ 'Najdi Arabic': 'ars_Arab',
18
+ 'Moroccan Arabic': 'ary_Arab',
19
+ 'Egyptian Arabic': 'arz_Arab',
20
+ 'Assamese': 'asm_Beng',
21
+ 'Asturian': 'ast_Latn',
22
+ 'Awadhi': 'awa_Deva',
23
+ 'Central Aymara': 'ayr_Latn',
24
+ 'South Azerbaijani': 'azb_Arab',
25
+ 'North Azerbaijani': 'azj_Latn',
26
+ 'Bashkir': 'bak_Cyrl',
27
+ 'Bambara': 'bam_Latn',
28
+ 'Balinese': 'ban_Latn',
29
+ 'Belarusian': 'bel_Cyrl',
30
+ 'Bemba': 'bem_Latn',
31
+ 'Bengali': 'ben_Beng',
32
+ 'Bhojpuri': 'bho_Deva',
33
+ 'Banjar (Arabic script)': 'bjn_Arab',
34
+ 'Banjar (Latin script)': 'bjn_Latn',
35
+ 'Standard Tibetan': 'bod_Tibt',
36
+ 'Bosnian': 'bos_Latn',
37
+ 'Buginese': 'bug_Latn',
38
+ 'Bulgarian': 'bul_Cyrl',
39
+ 'Catalan': 'cat_Latn',
40
+ 'Cebuano': 'ceb_Latn',
41
+ 'Czech': 'ces_Latn',
42
+ 'Chokwe': 'cjk_Latn',
43
+ 'Central Kurdish': 'ckb_Arab',
44
+ 'Crimean Tatar': 'crh_Latn',
45
+ 'Welsh': 'cym_Latn',
46
+ 'Danish': 'dan_Latn',
47
+ 'German': 'deu_Latn',
48
+ 'Southwestern Dinka': 'dik_Latn',
49
+ 'Dyula': 'dyu_Latn',
50
+ 'Dzongkha': 'dzo_Tibt',
51
+ 'Greek': 'ell_Grek',
52
+ 'English': 'eng_Latn',
53
+ 'Esperanto': 'epo_Latn',
54
+ 'Estonian': 'est_Latn',
55
+ 'Basque': 'eus_Latn',
56
+ 'Ewe': 'ewe_Latn',
57
+ 'Faroese': 'fao_Latn',
58
+ 'Fijian': 'fij_Latn',
59
+ 'Finnish': 'fin_Latn',
60
+ 'Fon': 'fon_Latn',
61
+ 'French': 'fra_Latn',
62
+ 'Friulian': 'fur_Latn',
63
+ 'Nigerian Fulfulde': 'fuv_Latn',
64
+ 'Scottish Gaelic': 'gla_Latn',
65
+ 'Irish': 'gle_Latn',
66
+ 'Galician': 'glg_Latn',
67
+ 'Guarani': 'grn_Latn',
68
+ 'Gujarati': 'guj_Gujr',
69
+ 'Haitian Creole': 'hat_Latn',
70
+ 'Hausa': 'hau_Latn',
71
+ 'Hebrew': 'heb_Hebr',
72
+ 'Hindi': 'hin_Deva',
73
+ 'Chhattisgarhi': 'hne_Deva',
74
+ 'Croatian': 'hrv_Latn',
75
+ 'Hungarian': 'hun_Latn',
76
+ 'Armenian': 'hye_Armn',
77
+ 'Igbo': 'ibo_Latn',
78
+ 'Ilocano': 'ilo_Latn',
79
+ 'Indonesian': 'ind_Latn',
80
+ 'Icelandic': 'isl_Latn',
81
+ 'Italian': 'ita_Latn',
82
+ 'Javanese': 'jav_Latn',
83
+ 'Japanese': 'jpn_Jpan',
84
+ 'Kabyle': 'kab_Latn',
85
+ 'Jingpho': 'kac_Latn',
86
+ 'Kamba': 'kam_Latn',
87
+ 'Kannada': 'kan_Knda',
88
+ 'Kashmiri (Arabic script)': 'kas_Arab',
89
+ 'Kashmiri (Devanagari script)': 'kas_Deva',
90
+ 'Georgian': 'kat_Geor',
91
+ 'Central Kanuri (Arabic script)': 'knc_Arab',
92
+ 'Central Kanuri (Latin script)': 'knc_Latn',
93
+ 'Kazakh': 'kaz_Cyrl',
94
+ 'Kabiyè': 'kbp_Latn',
95
+ 'Kabuverdianu': 'kea_Latn',
96
+ 'Khmer': 'khm_Khmr',
97
+ 'Kikuyu': 'kik_Latn',
98
+ 'Kinyarwanda': 'kin_Latn', 'Kyrgyz': 'kir_Cyrl', 'Kimbundu': 'kmb_Latn',
99
+ 'Northern Kurdish': 'kmr_Latn', 'Kikongo': 'kon_Latn',
100
+ 'Korean': 'kor_Hang', 'Lao': 'lao_Laoo', 'Ligurian': 'lij_Latn',
101
+ 'Limburgish': 'lim_Latn', 'Lingala': 'lin_Latn', 'Lithuanian': 'lit_Latn', 'Lombard': 'lmo_Latn',
102
+ 'Latgalian': 'ltg_Latn', 'Luxembourgish': 'ltz_Latn', 'Luba-Kasai': 'lua_Latn', 'Ganda': 'lug_Latn',
103
+ 'Luo': 'luo_Latn', 'Mizo': 'lus_Latn', 'Standard Latvian': 'lvs_Latn', 'Magahi': 'mag_Deva',
104
+ 'Maithili': 'mai_Deva', 'Malayalam': 'mal_Mlym', 'Marathi': 'mar_Deva',
105
+ 'Minangkabau (Arabic script)': 'min_Arab', 'Minangkabau (Latin script)': 'min_Latn',
106
+ 'Macedonian': 'mkd_Cyrl', 'Plateau Malagasy': 'plt_Latn', 'Maltese': 'mlt_Latn',
107
+ 'Meitei (Bengali script)': 'mni_Beng', 'Halh Mongolian': 'khk_Cyrl', 'Mossi': 'mos_Latn',
108
+ 'Maori': 'mri_Latn', 'Burmese': 'mya_Mymr', 'Dutch': 'nld_Latn', 'Norwegian Nynorsk': 'nno_Latn',
109
+ 'Norwegian Bokmål': 'nob_Latn', 'Nepali': 'npi_Deva', 'Northern Sotho': 'nso_Latn',
110
+ 'Nuer': 'nus_Latn',
111
+ 'Nyanja': 'nya_Latn', 'Occitan': 'oci_Latn', 'West Central Oromo': 'gaz_Latn', 'Odia': 'ory_Orya',
112
+ 'Pangasinan': 'pag_Latn', 'Eastern Panjabi': 'pan_Guru', 'Papiamento': 'pap_Latn',
113
+ 'Western Persian': 'pes_Arab',
114
+ 'Polish': 'pol_Latn', 'Portuguese': 'por_Latn', 'Dari': 'prs_Arab', 'Southern Pashto': 'pbt_Arab',
115
+ 'Ayacucho Quechua': 'quy_Latn', 'Romanian': 'ron_Latn', 'Rundi': 'run_Latn', 'Russian': 'rus_Cyrl',
116
+ 'Sango': 'sag_Latn', 'Sanskrit': 'san_Deva', 'Santali': 'sat_Olck', 'Sicilian': 'scn_Latn',
117
+ 'Shan': 'shn_Mymr',
118
+ 'Sinhala': 'sin_Sinh', 'Slovak': 'slk_Latn', 'Slovenian': 'slv_Latn', 'Samoan': 'smo_Latn',
119
+ 'Shona': 'sna_Latn',
120
+ 'Sindhi': 'snd_Arab', 'Somali': 'som_Latn', 'Southern Sotho': 'sot_Latn', 'Spanish': 'spa_Latn',
121
+ 'Tosk Albanian': 'als_Latn', 'Sardinian': 'srd_Latn', 'Serbian': 'srp_Cyrl', 'Swati': 'ssw_Latn',
122
+ 'Sundanese': 'sun_Latn', 'Swedish': 'swe_Latn', 'Swahili': 'swh_Latn', 'Silesian': 'szl_Latn',
123
+ 'Tamil': 'tam_Taml', 'Tatar': 'tat_Cyrl', 'Telugu': 'tel_Telu', 'Tajik': 'tgk_Cyrl',
124
+ 'Tagalog': 'tgl_Latn',
125
+ 'Thai': 'tha_Thai', 'Tigrinya': 'tir_Ethi', 'Tamasheq (Latin script)': 'taq_Latn',
126
+ 'Tamasheq (Tifinagh script)': 'taq_Tfng',
127
+ 'Tok Pisin': 'tpi_Latn', 'Tswana': 'tsn_Latn', 'Tsonga': 'tso_Latn', 'Turkmen': 'tuk_Latn', 'Tumbuka': 'tum_Latn',
128
+ 'Turkish': 'tur_Latn', 'Twi': 'twi_Latn', 'Central Atlas Tamazight': 'tzm_Tfng',
129
+ 'Uyghur': 'uig_Arab',
130
+ 'Ukrainian': 'ukr_Cyrl', 'Umbundu': 'umb_Latn', 'Urdu': 'urd_Arab', 'Northern Uzbek': 'uzn_Latn',
131
+ 'Venetian': 'vec_Latn',
132
+ 'Vietnamese': 'vie_Latn', 'Waray': 'war_Latn', 'Wolof': 'wol_Latn', 'Xhosa': 'xho_Latn',
133
+ 'Eastern Yiddish': 'ydd_Hebr',
134
+ 'Yoruba': 'yor_Latn', 'Yue Chinese': 'yue_Hant', 'Chinese (Simplified)': 'zho_Hans',
135
+ 'Chinese (Traditional)': 'zho_Hant',
136
+ 'Standard Malay': 'zsm_Latn', 'Zulu': 'zul_Latn'}
137
+
138
+ self.model_name_dict = {'0.6B': 'facebook/nllb-200-distilled-600M',
139
+ '1.3B': 'facebook/nllb-200-distilled-1.3B',
140
+ '3.3B': 'facebook/nllb-200-3.3B',
141
+ }
142
+
143
+ self.whisper_codes_to_flores_codes = {"de" : self.flores_codes['German'],
144
+ "en" : self.flores_codes['English'],
145
+ "pl" : self.flores_codes['Polish'],
146
+ "hi" : self.flores_codes['Hindi']
147
+ }
148
+
149
+ self.flores_codes_to_tts_codes = {'Acehnese': 'ace', 'Mesopotamian Arabic': 'acm', 'Ta’izzi-Adeni Arabic': 'acq', 'Tunisian Arabic': 'aeb', 'Afrikaans': 'afr', 'South Levantine Arabic': 'ajp', 'Akan': 'aka', 'Amharic': 'amh', 'North Levantine Arabic': 'apc', 'Modern Standard Arabic': 'arb', 'Najdi Arabic': 'ars', 'Moroccan Arabic': 'ary', 'Egyptian Arabic': 'arz', 'Assamese': 'asm', 'Asturian': 'ast', 'Awadhi': 'awa', 'Central Aymara': 'ayr', 'South Azerbaijani': 'azb', 'North Azerbaijani': 'azj', 'Bashkir': 'bak', 'Bambara': 'bam', 'Balinese': 'ban', 'Belarusian': 'bel', 'Bemba': 'bem', 'Bengali': 'ben', 'Bhojpuri': 'bho', 'Banjar': 'bjn', 'Standard Tibetan': 'bod', 'Bosnian': 'bos', 'Buginese': 'bug', 'Bulgarian': 'bul', 'Catalan': 'cat', 'Cebuano': 'ceb', 'Czech': 'ces', 'Chokwe': 'cjk', 'Central Kurdish': 'ckb', 'Crimean Tatar': 'crh', 'Welsh': 'cym', 'Danish': 'dan', 'German': 'deu', 'Southwestern Dinka': 'dik', 'Dyula': 'dyu', 'Dzongkha': 'dzo', 'Greek': 'ell', 'English': 'eng', 'Esperanto': 'epo', 'Estonian': 'est', 'Basque': 'eus', 'Ewe': 'ewe', 'Faroese': 'fao', 'Fijian': 'fij', 'Finnish': 'fin', 'Fon': 'fon', 'French': 'fra', 'Friulian': 'fur', 'Nigerian Fulfulde': 'fuv', 'Scottish Gaelic': 'gla', 'Irish': 'gle', 'Galician': 'glg', 'Guarani': 'grn', 'Gujarati': 'guj', 'Haitian Creole': 'hat', 'Hausa': 'hau', 'Hebrew': 'heb', 'Hindi': 'hin', 'Chhattisgarhi': 'hne', 'Croatian': 'hrv', 'Hungarian': 'hun', 'Armenian': 'hye', 'Igbo': 'ibo', 'Ilocano': 'ilo', 'Indonesian': 'ind', 'Icelandic': 'isl', 'Italian': 'ita', 'Javanese': 'jav', 'Japanese': 'jpn', 'Kabyle': 'kab', 'Jingpho': 'kac', 'Kamba': 'kam', 'Kannada': 'kan', 'Kashmiri': 'kas', 'Georgian': 'kat', 'Central Kanuri': 'knc', 'Kazakh': 'kaz', 'Kabiyè': 'kbp', 'Kabuverdianu': 'kea', 'Khmer': 'khm', 'Kikuyu': 'kik', 'Kinyarwanda': 'kin', 'Kyrgyz': 'kir', 'Kimbundu': 'kmb', 'Northern Kurdish': 'kmr', 'Kikongo': 'kon', 'Korean': 'kor', 'Lao': 'lao', 'Ligurian': 'lij', 'Limburgish': 'lim', 'Lingala': 'lin', 'Lithuanian': 'lit', 'Lombard': 'lmo', 'Latgalian': 'ltg', 'Luxembourgish': 'ltz', 'Luba-Kasai': 'lua', 'Ganda': 'lug', 'Luo': 'luo', 'Mizo': 'lus', 'Standard Latvian': 'lvs', 'Magahi': 'mag', 'Maithili': 'mai', 'Malayalam': 'mal', 'Marathi': 'mar', 'Minangkabau': 'min', 'Macedonian': 'mkd', 'Plateau Malagasy': 'plt', 'Maltese': 'mlt', 'Meitei': 'mni', 'Halh Mongolian': 'khk', 'Mossi': 'mos', 'Maori': 'mri', 'Burmese': 'mya', 'Dutch': 'nld', 'Norwegian Nynorsk': 'nno', 'Norwegian Bokmål': 'nob', 'Nepali': 'npi', 'Northern Sotho': 'nso', 'Nuer': 'nus', 'Nyanja': 'nya', 'Occitan': 'oci', 'West Central Oromo': 'gaz', 'Odia': 'ory', 'Pangasinan': 'pag', 'Eastern Panjabi': 'pan', 'Papiamento': 'pap', 'Western Persian': 'pes', 'Polish': 'pol', 'Portuguese': 'por', 'Dari': 'prs', 'Southern Pashto': 'pbt', 'Ayacucho Quechua': 'quy', 'Romanian': 'ron', 'Rundi': 'run', 'Russian': 'rus', 'Sango': 'sag', 'Sanskrit': 'san', 'Santali': 'sat', 'Sicilian': 'scn', 'Shan': 'shn', 'Sinhala': 'sin', 'Slovak': 'slk', 'Slovenian': 'slv', 'Samoan': 'smo', 'Shona': 'sna', 'Sindhi': 'snd', 'Somali': 'som', 'Southern Sotho': 'sot', 'Spanish': 'spa', 'Tosk Albanian': 'als', 'Sardinian': 'srd', 'Serbian': 'srp', 'Swati': 'ssw', 'Sundanese': 'sun', 'Swedish': 'swe', 'Swahili': 'swh', 'Silesian': 'szl', 'Tamil': 'tam', 'Tatar': 'tat', 'Telugu': 'tel', 'Tajik': 'tgk', 'Tagalog': 'tgl', 'Thai': 'tha', 'Tigrinya': 'tir', 'Tamasheq': 'taq', 'Tok Pisin': 'tpi', 'Tswana': 'tsn', 'Tsonga': 'tso', 'Turkmen': 'tuk', 'Tumbuka': 'tum', 'Turkish': 'tur', 'Twi': 'twi', 
'Central Atlas Tamazight': 'tzm', 'Uyghur': 'uig', 'Ukrainian': 'ukr', 'Umbundu': 'umb', 'Urdu': 'urd', 'Northern Uzbek': 'uzn', 'Venetian': 'vec', 'Vietnamese': 'vie', 'Waray': 'war', 'Wolof': 'wol', 'Xhosa': 'xho', 'Eastern Yiddish': 'ydd', 'Yoruba': 'yor', 'Yue Chinese': 'yue', 'Chinese': 'zho', 'Standard Malay': 'zsm', 'Zulu': 'zul'}
150
+
151
+ self.language_directory = 'Languages'
152
+ self.uroman_directory = 'aux_files'
153
+
154
+ self.language_download_web = 'https://dl.fbaipublicfiles.com/mms/tts'
155
+ self.language_vocab_text = "vocab.txt"
156
+ self.language_vocab_configuration = "config.json"
157
+ self.language_vocab_model = "G_100000.pth"
158
+
159
+ # creating the temporary audio files
160
+ # ---------------------------------------
161
+ self.temp_audio_folder = 'Temp_Audios'
162
+
163
+ self.text2speech_wavfile = f'{self.temp_audio_folder}/text2speech.wav'
164
+ self.enhanced_speech_file = f"{self.temp_audio_folder}/enhanced.mp3"
165
+ self.input_speech_file = f'{self.temp_audio_folder}/output.wav'
166
+
167
+
168
+ try:
169
+ os.makedirs(self.language_directory)
170
+ except:
171
+ pass
172
+
173
+ try:
174
+ os.makedirs(self.temp_audio_folder)
175
+ except:
176
+ pass
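
The dictionaries above are the glue between the UI language names, the NLLB (FLORES-200) codes used for translation, and the MMS TTS checkpoint folders downloaded into Languages/. A small usage sketch, with values taken from the tables in this file:

from configurations.get_constants import constantConfig

constants = constantConfig()                              # also creates Languages/ and Temp_Audios/
print(constants.flores_codes['German'])                   # 'deu_Latn'  (NLLB / FLORES-200 code)
print(constants.flores_codes_to_tts_codes['German'])      # 'deu'       (MMS TTS checkpoint folder)
print(constants.model_name_dict['1.3B'])                  # 'facebook/nllb-200-distilled-1.3B'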
configurations/get_hyperparameters.py ADDED
@@ -0,0 +1,19 @@
+ import torch
+
+ class hyperparameterConfig:
+     def __init__(self):
+         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+         self.stt_model = "large-v2"
+         self.nllb_model = '1.3B'
+
+         # text to speech model
+         self.text2speech_model = 'model_weights/voiceover/freevc-24.pth'
+         self.text2speech_config = 'model_weights/voiceover/freevc-24.json'
+         self.text2speech_encoder = 'model_weights/voice_enhance/ckpt/pretrained_bak_5805000.pt'
+
+         # voice enhancing model
+         self.voice_enhacing_model = 'model_weights/voice_enhance'
+
+         # loading the wavlm model
+         self.wavlm_model = 'model_weights/wavlm_models/WavLM-Large.pt'
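
hyperparameterConfig only resolves the torch device and records checkpoint paths; nothing is loaded at construction time, so it is cheap to instantiate. For illustration:

from configurations.get_hyperparameters import hyperparameterConfig

hyperparameters = hyperparameterConfig()
print(hyperparameters.device)               # cuda or cpu, depending on availability
print(hyperparameters.text2speech_model)    # 'model_weights/voiceover/freevc-24.pth'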
df/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .config import config
+
+ __all__ = ["config"]
df/checkpoint.py ADDED
@@ -0,0 +1,213 @@
1
+ import glob
2
+ import os
3
+ import re
4
+ from typing import List, Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ from loguru import logger
9
+ from torch import nn
10
+
11
+ from df.config import Csv, config
12
+ from df.model import init_model
13
+ from df.utils import check_finite_module
14
+ from libdf import DF
15
+
16
+
17
+ def get_epoch(cp) -> int:
18
+ return int(os.path.basename(cp).split(".")[0].split("_")[-1])
19
+
20
+
21
+ def load_model(
22
+ cp_dir: Optional[str],
23
+ df_state: DF,
24
+ jit: bool = False,
25
+ mask_only: bool = False,
26
+ train_df_only: bool = False,
27
+ extension: str = "ckpt",
28
+ epoch: Union[str, int, None] = "latest",
29
+ ) -> Tuple[nn.Module, int]:
30
+ if mask_only and train_df_only:
31
+ raise ValueError("Only one of `mask_only` `train_df_only` can be enabled")
32
+ model = init_model(df_state, run_df=mask_only is False, train_mask=train_df_only is False)
33
+ if jit:
34
+ model = torch.jit.script(model)
35
+ blacklist: List[str] = config("CP_BLACKLIST", [], Csv(), save=False, section="train") # type: ignore
36
+ if cp_dir is not None:
37
+ epoch = read_cp(
38
+ model, "model", cp_dir, blacklist=blacklist, extension=extension, epoch=epoch
39
+ )
40
+ epoch = 0 if epoch is None else epoch
41
+ else:
42
+ epoch = 0
43
+ return model, epoch
44
+
45
+
46
+ def read_cp(
47
+ obj: Union[torch.optim.Optimizer, nn.Module],
48
+ name: str,
49
+ dirname: str,
50
+ epoch: Union[str, int, None] = "latest",
51
+ extension="ckpt",
52
+ blacklist=[],
53
+ log: bool = True,
54
+ ):
55
+ checkpoints = []
56
+ if isinstance(epoch, str):
57
+ assert epoch in ("best", "latest")
58
+ if epoch == "best":
59
+ checkpoints = glob.glob(os.path.join(dirname, f"{name}*.{extension}.best"))
60
+ if len(checkpoints) == 0:
61
+ logger.warning("Could not find `best` checkpoint. Checking for default...")
62
+ if len(checkpoints) == 0:
63
+ checkpoints = glob.glob(os.path.join(dirname, f"{name}*.{extension}"))
64
+ checkpoints += glob.glob(os.path.join(dirname, f"{name}*.{extension}.best"))
65
+ if len(checkpoints) == 0:
66
+ return None
67
+ if isinstance(epoch, int):
68
+ latest = next((x for x in checkpoints if get_epoch(x) == epoch), None)
69
+ if latest is None:
70
+ logger.error(f"Could not find checkpoint of epoch {epoch}")
71
+ exit(1)
72
+ else:
73
+ latest = max(checkpoints, key=get_epoch)
74
+ epoch = get_epoch(latest)
75
+ if log:
76
+ logger.info("Found checkpoint {} with epoch {}".format(latest, epoch))
77
+ latest = torch.load(latest, map_location="cpu")
78
+ latest = {k.replace("clc", "df"): v for k, v in latest.items()}
79
+ if blacklist:
80
+ reg = re.compile("".join(f"({b})|" for b in blacklist)[:-1])
81
+ len_before = len(latest)
82
+ latest = {k: v for k, v in latest.items() if reg.search(k) is None}
83
+ if len(latest) < len_before:
84
+ logger.info("Filtered checkpoint modules: {}".format(blacklist))
85
+ if isinstance(obj, nn.Module):
86
+ while True:
87
+ try:
88
+ missing, unexpected = obj.load_state_dict(latest, strict=False)
89
+ except RuntimeError as e:
90
+ e_str = str(e)
91
+ logger.warning(e_str)
92
+ if "size mismatch" in e_str:
93
+ latest = {k: v for k, v in latest.items() if k not in e_str}
94
+ continue
95
+ raise e
96
+ break
97
+ for key in missing:
98
+ logger.warning(f"Missing key: '{key}'")
99
+ for key in unexpected:
100
+ if key.endswith(".h0"):
101
+ continue
102
+ logger.warning(f"Unexpected key: {key}")
103
+ return epoch
104
+ obj.load_state_dict(latest)
105
+
106
+
107
+ def write_cp(
108
+ obj: Union[torch.optim.Optimizer, nn.Module],
109
+ name: str,
110
+ dirname: str,
111
+ epoch: int,
112
+ extension="ckpt",
113
+ metric: Optional[float] = None,
114
+ cmp="min",
115
+ ):
116
+ check_finite_module(obj)
117
+ n_keep = config("n_checkpoint_history", default=3, cast=int, section="train")
118
+ n_keep_best = config("n_best_checkpoint_history", default=5, cast=int, section="train")
119
+ if metric is not None:
120
+ assert cmp in ("min", "max")
121
+ metric = float(metric) # Make sure it is not an integer
122
+ # Each line contains a previous best with entries: (epoch, metric)
123
+ with open(os.path.join(dirname, ".best"), "a+") as prev_best_f:
124
+ prev_best_f.seek(0) # "a+" creates a file in read/write mode without truncating
125
+ lines = prev_best_f.readlines()
126
+ if len(lines) == 0:
127
+ prev_best = float("inf" if cmp == "min" else "-inf")
128
+ else:
129
+ prev_best = float(lines[-1].strip().split(" ")[1])
130
+ cmp = "__lt__" if cmp == "min" else "__gt__"
131
+ if getattr(metric, cmp)(prev_best):
132
+ logger.info(f"Saving new best checkpoint at epoch {epoch} with metric: {metric}")
133
+ prev_best_f.seek(0, os.SEEK_END)
134
+ np.savetxt(prev_best_f, np.array([[float(epoch), metric]]))
135
+ cp_name = os.path.join(dirname, f"{name}_{epoch}.{extension}.best")
136
+ torch.save(obj.state_dict(), cp_name)
137
+ cleanup(name, dirname, extension + ".best", nkeep=n_keep_best)
138
+ cp_name = os.path.join(dirname, f"{name}_{epoch}.{extension}")
139
+ logger.info(f"Writing checkpoint {cp_name} with epoch {epoch}")
140
+ torch.save(obj.state_dict(), cp_name)
141
+ cleanup(name, dirname, extension, nkeep=n_keep)
142
+
143
+
144
+ def cleanup(name: str, dirname: str, extension: str, nkeep=5):
145
+ if nkeep < 0:
146
+ return
147
+ checkpoints = glob.glob(os.path.join(dirname, f"{name}*.{extension}"))
148
+ if len(checkpoints) == 0:
149
+ return
150
+ checkpoints = sorted(checkpoints, key=get_epoch, reverse=True)
151
+ for cp in checkpoints[nkeep:]:
152
+ logger.debug("Removing old checkpoint: {}".format(cp))
153
+ os.remove(cp)
154
+
155
+
156
+ def check_patience(
157
+ dirname: str, max_patience: int, new_metric: float, cmp: str = "min", raise_: bool = True
158
+ ):
159
+ cmp = "__lt__" if cmp == "min" else "__gt__"
160
+ new_metric = float(new_metric) # Make sure it is not an integer
161
+ prev_patience, prev_metric = read_patience(dirname)
162
+ if prev_patience is None or getattr(new_metric, cmp)(prev_metric):
163
+ # We have a better new_metric, reset patience
164
+ write_patience(dirname, 0, new_metric)
165
+ else:
166
+ # We don't have a better metric, decrement patience
167
+ new_patience = prev_patience + 1
168
+ write_patience(dirname, new_patience, prev_metric)
169
+ if new_patience >= max_patience:
170
+ if raise_:
171
+ raise ValueError(
172
+ f"No improvements on validation metric ({new_metric}) for {max_patience} epochs. "
173
+ "Stopping."
174
+ )
175
+ else:
176
+ return False
177
+ return True
178
+
179
+
180
+ def read_patience(dirname: str) -> Tuple[Optional[int], float]:
181
+ fn = os.path.join(dirname, ".patience")
182
+ if not os.path.isfile(fn):
183
+ return None, 0.0
184
+ patience, metric = np.loadtxt(fn)
185
+ return int(patience), float(metric)
186
+
187
+
188
+ def write_patience(dirname: str, new_patience: int, metric: float):
189
+ return np.savetxt(os.path.join(dirname, ".patience"), [new_patience, metric])
190
+
191
+
192
+ def test_check_patience():
193
+ import tempfile
194
+
195
+ with tempfile.TemporaryDirectory() as d:
196
+ check_patience(d, 3, 1.0)
197
+ check_patience(d, 3, 1.0)
198
+ check_patience(d, 3, 1.0)
199
+ assert check_patience(d, 3, 1.0, raise_=False) is False
200
+
201
+ with tempfile.TemporaryDirectory() as d:
202
+ check_patience(d, 3, 1.0)
203
+ check_patience(d, 3, 0.9)
204
+ check_patience(d, 3, 1.0)
205
+ check_patience(d, 3, 1.0)
206
+ assert check_patience(d, 3, 1.0, raise_=False) is False
207
+
208
+ with tempfile.TemporaryDirectory() as d:
209
+ check_patience(d, 3, 1.0, cmp="max")
210
+ check_patience(d, 3, 1.9, cmp="max")
211
+ check_patience(d, 3, 1.0, cmp="max")
212
+ check_patience(d, 3, 1.0, cmp="max")
213
+ assert check_patience(d, 3, 1.0, cmp="max", raise_=False) is False
df/config.py ADDED
@@ -0,0 +1,266 @@
1
+ import os
2
+ import string
3
+ from configparser import ConfigParser
4
+ from shlex import shlex
5
+ from typing import Any, List, Optional, Tuple, Type, TypeVar, Union
6
+
7
+ from loguru import logger
8
+
9
+ T = TypeVar("T")
10
+
11
+
12
+ class DfParams:
13
+ def __init__(self):
14
+ # Sampling rate used for training
15
+ self.sr: int = config("SR", cast=int, default=48_000, section="DF")
16
+ # FFT size in samples
17
+ self.fft_size: int = config("FFT_SIZE", cast=int, default=960, section="DF")
18
+ # STFT Hop size in samples
19
+ self.hop_size: int = config("HOP_SIZE", cast=int, default=480, section="DF")
20
+ # Number of ERB bands
21
+ self.nb_erb: int = config("NB_ERB", cast=int, default=32, section="DF")
22
+ # Number of deep filtering bins; DF is applied from 0th to nb_df-th frequency bins
23
+ self.nb_df: int = config("NB_DF", cast=int, default=96, section="DF")
24
+ # Normalization decay factor; used for complex and erb features
25
+ self.norm_tau: float = config("NORM_TAU", 1, float, section="DF")
26
+ # Local SNR maximum value, ground truth will be truncated
27
+ self.lsnr_max: int = config("LSNR_MAX", 35, int, section="DF")
28
+ # Local SNR minimum value, ground truth will be truncated
29
+ self.lsnr_min: int = config("LSNR_MIN", -15, int, section="DF")
30
+ # Minimum number of frequency bins per ERB band
31
+ self.min_nb_freqs = config("MIN_NB_ERB_FREQS", 2, int, section="DF")
32
+ # Deep Filtering order
33
+ self.df_order: int = config("DF_ORDER", cast=int, default=5, section="DF")
34
+ # Deep Filtering look-ahead
35
+ self.df_lookahead: int = config("DF_LOOKAHEAD", cast=int, default=0, section="DF")
36
+ # Pad mode. By default, padding will be handled on the input side:
37
+ # - `input`, which pads the input features passed to the model
38
+ # - `output`, which pads the output spectrogram corresponding to `df_lookahead`
39
+ self.pad_mode: str = config("PAD_MODE", default="input_specf", section="DF")
40
+
41
+
42
+ class Config:
43
+ """Adopted from python-decouple"""
44
+
45
+ DEFAULT_SECTION = "settings"
46
+
47
+ def __init__(self):
48
+ self.parser: ConfigParser = None # type: ignore
49
+ self.path = ""
50
+ self.modified = False
51
+ self.allow_defaults = True
52
+
53
+ def load(
54
+ self, path: Optional[str], config_must_exist=False, allow_defaults=True, allow_reload=False
55
+ ):
56
+ self.allow_defaults = allow_defaults
57
+ if self.parser is not None and not allow_reload:
58
+ raise ValueError("Config already loaded")
59
+ self.parser = ConfigParser()
60
+ self.path = path
61
+ if path is not None and os.path.isfile(path):
62
+ with open(path) as f:
63
+ self.parser.read_file(f)
64
+ else:
65
+ if config_must_exist:
66
+ raise ValueError(f"No config file found at '{path}'.")
67
+ if not self.parser.has_section(self.DEFAULT_SECTION):
68
+ self.parser.add_section(self.DEFAULT_SECTION)
69
+ self._fix_clc()
70
+ self._fix_df()
71
+
72
+ def use_defaults(self):
73
+ self.load(path=None, config_must_exist=False)
74
+
75
+ def save(self, path: str):
76
+ if not self.modified:
77
+ logger.debug("Config not modified. No need to overwrite on disk.")
78
+ return
79
+ if self.parser is None:
80
+ self.parser = ConfigParser()
81
+ for section in self.parser.sections():
82
+ if len(self.parser[section]) == 0:
83
+ self.parser.remove_section(section)
84
+ with open(path, mode="w") as f:
85
+ self.parser.write(f)
86
+
87
+ def tostr(self, value, cast):
88
+ if isinstance(cast, Csv) and isinstance(value, (tuple, list)):
89
+ return "".join(str(v) + cast.delimiter for v in value)[:-1]
90
+ return str(value)
91
+
92
+ def set(self, option: str, value: T, cast: Type[T], section: Optional[str] = None) -> T:
93
+ section = self.DEFAULT_SECTION if section is None else section
94
+ section = section.lower()
95
+ if not self.parser.has_section(section):
96
+ self.parser.add_section(section)
97
+ if self.parser.has_option(section, option):
98
+ if value == self.cast(self.parser.get(section, option), cast):
99
+ return value
100
+ self.modified = True
101
+ self.parser.set(section, option, self.tostr(value, cast))
102
+ return value
103
+
104
+ def __call__(
105
+ self,
106
+ option: str,
107
+ default: Any = None,
108
+ cast: Type[T] = str,
109
+ save: bool = True,
110
+ section: Optional[str] = None,
111
+ ) -> T:
112
+ # Get value either from an ENV or from the .ini file
113
+ section = self.DEFAULT_SECTION if section is None else section
114
+ value = None
115
+ if self.parser is None:
116
+ raise ValueError("No configuration loaded")
117
+ if not self.parser.has_section(section.lower()):
118
+ self.parser.add_section(section.lower())
119
+ if option in os.environ:
120
+ value = os.environ[option]
121
+ if save:
122
+ self.parser.set(section, option, self.tostr(value, cast))
123
+ elif self.parser.has_option(section, option):
124
+ value = self.read_from_section(section, option, default, cast=cast, save=save)
125
+ elif self.parser.has_option(section.lower(), option):
126
+ value = self.read_from_section(section.lower(), option, default, cast=cast, save=save)
127
+ elif self.parser.has_option(self.DEFAULT_SECTION, option):
128
+ logger.warning(
129
+ f"Couldn't find option {option} in section {section}. "
130
+ "Falling back to default settings section."
131
+ )
132
+ value = self.read_from_section(self.DEFAULT_SECTION, option, cast=cast, save=save)
133
+ elif default is None:
134
+ raise ValueError("Value {} not found.".format(option))
135
+ elif not self.allow_defaults and save:
136
+ raise ValueError(f"Value '{option}' not found in config (defaults not allowed).")
137
+ else:
138
+ value = default
139
+ if save:
140
+ self.set(option, value, cast, section)
141
+ return self.cast(value, cast)
142
+
143
+ def cast(self, value, cast):
144
+ # Do the casting to get the correct type
145
+ if cast is bool:
146
+ value = str(value).lower()
147
+ if value in {"true", "yes", "y", "on", "1"}:
148
+ return True # type: ignore
149
+ elif value in {"false", "no", "n", "off", "0"}:
150
+ return False # type: ignore
151
+ raise ValueError("Parse error")
152
+ return cast(value)
153
+
154
+ def get(self, option: str, cast: Type[T] = str, section: Optional[str] = None) -> T:
155
+ section = self.DEFAULT_SECTION if section is None else section
156
+ if not self.parser.has_section(section):
157
+ raise KeyError(section)
158
+ if not self.parser.has_option(section, option):
159
+ raise KeyError(option)
160
+ return self.cast(self.parser.get(section, option), cast)
161
+
162
+ def read_from_section(
163
+ self, section: str, option: str, default: Any = None, cast: Type = str, save: bool = True
164
+ ) -> str:
165
+ value = self.parser.get(section, option)
166
+ if not save:
167
+ # Set to default or remove to not read it at training start again
168
+ if default is None:
169
+ self.parser.remove_option(section, option)
170
+ elif not self.allow_defaults:
171
+ raise ValueError(f"Value '{option}' not found in config (defaults not allowed).")
172
+ else:
173
+ self.parser.set(section, option, self.tostr(default, cast))
174
+ elif section.lower() != section:
175
+ self.parser.set(section.lower(), option, self.tostr(value, cast))
176
+ self.parser.remove_option(section, option)
177
+ self.modified = True
178
+ return value
179
+
180
+ def overwrite(self, section: str, option: str, value: Any):
181
+ if not self.parser.has_section(section):
182
+ return ValueError(f"Section not found: '{section}'")
183
+ if not self.parser.has_option(section, option):
184
+ return ValueError(f"Option not found '{option}' in section '{section}'")
185
+ self.modified = True
186
+ cast = type(value)
187
+ return self.parser.set(section, option, self.tostr(value, cast))
188
+
189
+ def _fix_df(self):
190
+ """Renaming of some groups/options for compatibility with old models."""
191
+ if self.parser.has_section("deepfilternet") and self.parser.has_section("df"):
192
+ sec_deepfilternet = self.parser["deepfilternet"]
193
+ sec_df = self.parser["df"]
194
+ if "df_order" in sec_deepfilternet:
195
+ sec_df["df_order"] = sec_deepfilternet["df_order"]
196
+ del sec_deepfilternet["df_order"]
197
+ if "df_lookahead" in sec_deepfilternet:
198
+ sec_df["df_lookahead"] = sec_deepfilternet["df_lookahead"]
199
+ del sec_deepfilternet["df_lookahead"]
200
+
201
+ def _fix_clc(self):
202
+ """Renaming of some groups/options for compatibility with old models."""
203
+ if (
204
+ not self.parser.has_section("deepfilternet")
205
+ and self.parser.has_section("train")
206
+ and self.parser.get("train", "model") == "convgru5"
207
+ ):
208
+ self.overwrite("train", "model", "deepfilternet")
209
+ self.parser.add_section("deepfilternet")
210
+ self.parser["deepfilternet"] = self.parser["convgru"]
211
+ del self.parser["convgru"]
212
+ if not self.parser.has_section("df") and self.parser.has_section("clc"):
213
+ self.parser["df"] = self.parser["clc"]
214
+ del self.parser["clc"]
215
+ for section in self.parser.sections():
216
+ for k, v in self.parser[section].items():
217
+ if "clc" in k.lower():
218
+ self.parser.set(section, k.lower().replace("clc", "df"), v)
219
+ del self.parser[section][k]
220
+
221
+ def __repr__(self):
222
+ msg = ""
223
+ for section in self.parser.sections():
224
+ msg += f"{section}:\n"
225
+ for k, v in self.parser[section].items():
226
+ msg += f" {k}: {v}\n"
227
+ return msg
228
+
229
+
230
+ config = Config()
231
+
232
+
233
+ class Csv(object):
234
+ """
235
+ Produces a csv parser that return a list of transformed elements. From python-decouple.
236
+ """
237
+
238
+ def __init__(
239
+ self, cast: Type[T] = str, delimiter=",", strip=string.whitespace, post_process=list
240
+ ):
241
+ """
242
+ Parameters:
243
+ cast -- callable that transforms the item just before it's added to the list.
244
+ delimiter -- string of delimiters chars passed to shlex.
245
+ strip -- string of non-relevant characters to be passed to str.strip after the split.
246
+ post_process -- callable to post process all casted values. Default is `list`.
247
+ """
248
+ self.cast: Type[T] = cast
249
+ self.delimiter = delimiter
250
+ self.strip = strip
251
+ self.post_process = post_process
252
+
253
+ def __call__(self, value: Union[str, Tuple[T], List[T]]) -> List[T]:
254
+ """The actual transformation"""
255
+ if isinstance(value, (tuple, list)):
256
+ # if default value is a list
257
+ value = "".join(str(v) + self.delimiter for v in value)[:-1]
258
+
259
+ def transform(s):
260
+ return self.cast(s.strip(self.strip))
261
+
262
+ splitter = shlex(value, posix=True)
263
+ splitter.whitespace = self.delimiter
264
+ splitter.whitespace_split = True
265
+
266
+ return self.post_process(transform(s) for s in splitter)
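The `Csv` caster above is used as a `cast` argument for `config` (e.g. for `CONV_KERNEL` further down in `deepfilternet2.py`). A minimal sketch of its behaviour on its own, assuming `df.config` is importable:

```python
from df.config import Csv

# Splits on the delimiter via shlex, strips whitespace, casts every item.
assert Csv(int)("1, 3") == [1, 3]
assert Csv(str, delimiter=";")("a; b; c") == ["a", "b", "c"]
# A tuple/list default is first joined with the delimiter and then re-parsed.
assert Csv(int)((1, 3)) == [1, 3]
```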
df/deepfilternet2.py ADDED
@@ -0,0 +1,453 @@
1
+ from functools import partial
2
+ from typing import Final, List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ from loguru import logger
6
+ from torch import Tensor, nn
7
+
8
+ from df.config import Csv, DfParams, config
9
+ from df.modules import (
10
+ Conv2dNormAct,
11
+ ConvTranspose2dNormAct,
12
+ DfOp,
13
+ GroupedGRU,
14
+ GroupedLinear,
15
+ GroupedLinearEinsum,
16
+ Mask,
17
+ SqueezedGRU,
18
+ erb_fb,
19
+ get_device,
20
+ )
21
+ from df.multiframe import MF_METHODS, MultiFrameModule
22
+ from libdf import DF
23
+
24
+
25
+ class ModelParams(DfParams):
26
+ section = "deepfilternet"
27
+
28
+ def __init__(self):
29
+ super().__init__()
30
+ self.conv_lookahead: int = config(
31
+ "CONV_LOOKAHEAD", cast=int, default=0, section=self.section
32
+ )
33
+ self.conv_ch: int = config("CONV_CH", cast=int, default=16, section=self.section)
34
+ self.conv_depthwise: bool = config(
35
+ "CONV_DEPTHWISE", cast=bool, default=True, section=self.section
36
+ )
37
+ self.convt_depthwise: bool = config(
38
+ "CONVT_DEPTHWISE", cast=bool, default=True, section=self.section
39
+ )
40
+ self.conv_kernel: List[int] = config(
41
+ "CONV_KERNEL", cast=Csv(int), default=(1, 3), section=self.section # type: ignore
42
+ )
43
+ self.conv_kernel_inp: List[int] = config(
44
+ "CONV_KERNEL_INP", cast=Csv(int), default=(3, 3), section=self.section # type: ignore
45
+ )
46
+ self.emb_hidden_dim: int = config(
47
+ "EMB_HIDDEN_DIM", cast=int, default=256, section=self.section
48
+ )
49
+ self.emb_num_layers: int = config(
50
+ "EMB_NUM_LAYERS", cast=int, default=2, section=self.section
51
+ )
52
+ self.df_hidden_dim: int = config(
53
+ "DF_HIDDEN_DIM", cast=int, default=256, section=self.section
54
+ )
55
+ self.df_gru_skip: str = config("DF_GRU_SKIP", default="none", section=self.section)
56
+ self.df_output_layer: str = config(
57
+ "DF_OUTPUT_LAYER", default="linear", section=self.section
58
+ )
59
+ self.df_pathway_kernel_size_t: int = config(
60
+ "DF_PATHWAY_KERNEL_SIZE_T", cast=int, default=1, section=self.section
61
+ )
62
+ self.enc_concat: bool = config("ENC_CONCAT", cast=bool, default=False, section=self.section)
63
+ self.df_num_layers: int = config("DF_NUM_LAYERS", cast=int, default=3, section=self.section)
64
+ self.df_n_iter: int = config("DF_N_ITER", cast=int, default=2, section=self.section)
65
+ self.gru_type: str = config("GRU_TYPE", default="grouped", section=self.section)
66
+ self.gru_groups: int = config("GRU_GROUPS", cast=int, default=1, section=self.section)
67
+ self.lin_groups: int = config("LINEAR_GROUPS", cast=int, default=1, section=self.section)
68
+ self.group_shuffle: bool = config(
69
+ "GROUP_SHUFFLE", cast=bool, default=True, section=self.section
70
+ )
71
+ self.dfop_method: str = config("DFOP_METHOD", cast=str, default="df", section=self.section)
72
+ self.mask_pf: bool = config("MASK_PF", cast=bool, default=False, section=self.section)
73
+
74
+
75
+ def init_model(df_state: Optional[DF] = None, run_df: bool = True, train_mask: bool = True):
76
+ p = ModelParams()
77
+ if df_state is None:
78
+ df_state = DF(sr=p.sr, fft_size=p.fft_size, hop_size=p.hop_size, nb_bands=p.nb_erb)
79
+ erb = erb_fb(df_state.erb_widths(), p.sr, inverse=False)
80
+ erb_inverse = erb_fb(df_state.erb_widths(), p.sr, inverse=True)
81
+ model = DfNet(erb, erb_inverse, run_df, train_mask)
82
+ return model.to(device=get_device())
83
+
84
+
85
+ class Add(nn.Module):
86
+ def forward(self, a, b):
87
+ return a + b
88
+
89
+
90
+ class Concat(nn.Module):
91
+ def forward(self, a, b):
92
+ return torch.cat((a, b), dim=-1)
93
+
94
+
95
+ class Encoder(nn.Module):
96
+ def __init__(self):
97
+ super().__init__()
98
+ p = ModelParams()
99
+ assert p.nb_erb % 4 == 0, "erb_bins should be divisible by 4"
100
+
101
+ self.erb_conv0 = Conv2dNormAct(
102
+ 1, p.conv_ch, kernel_size=p.conv_kernel_inp, bias=False, separable=True
103
+ )
104
+ conv_layer = partial(
105
+ Conv2dNormAct,
106
+ in_ch=p.conv_ch,
107
+ out_ch=p.conv_ch,
108
+ kernel_size=p.conv_kernel,
109
+ bias=False,
110
+ separable=True,
111
+ )
112
+ self.erb_conv1 = conv_layer(fstride=2)
113
+ self.erb_conv2 = conv_layer(fstride=2)
114
+ self.erb_conv3 = conv_layer(fstride=1)
115
+ self.df_conv0 = Conv2dNormAct(
116
+ 2, p.conv_ch, kernel_size=p.conv_kernel_inp, bias=False, separable=True
117
+ )
118
+ self.df_conv1 = conv_layer(fstride=2)
119
+ self.erb_bins = p.nb_erb
120
+ self.emb_in_dim = p.conv_ch * p.nb_erb // 4
121
+ self.emb_out_dim = p.emb_hidden_dim
122
+ if p.gru_type == "grouped":
123
+ self.df_fc_emb = GroupedLinear(
124
+ p.conv_ch * p.nb_df // 2, self.emb_in_dim, groups=p.lin_groups
125
+ )
126
+ else:
127
+ df_fc_emb = GroupedLinearEinsum(
128
+ p.conv_ch * p.nb_df // 2, self.emb_in_dim, groups=p.lin_groups
129
+ )
130
+ self.df_fc_emb = nn.Sequential(df_fc_emb, nn.ReLU(inplace=True))
131
+ if p.enc_concat:
132
+ self.emb_in_dim *= 2
133
+ self.combine = Concat()
134
+ else:
135
+ self.combine = Add()
136
+ self.emb_out_dim = p.emb_hidden_dim
137
+ self.emb_n_layers = p.emb_num_layers
138
+ assert p.gru_type in ("grouped", "squeeze"), f"gru_type must be 'grouped' or 'squeeze', but got {p.gru_type}"
139
+ if p.gru_type == "grouped":
140
+ self.emb_gru = GroupedGRU(
141
+ self.emb_in_dim,
142
+ self.emb_out_dim,
143
+ num_layers=1,
144
+ batch_first=True,
145
+ groups=p.gru_groups,
146
+ shuffle=p.group_shuffle,
147
+ add_outputs=True,
148
+ )
149
+ else:
150
+ self.emb_gru = SqueezedGRU(
151
+ self.emb_in_dim,
152
+ self.emb_out_dim,
153
+ num_layers=1,
154
+ batch_first=True,
155
+ linear_groups=p.lin_groups,
156
+ linear_act_layer=partial(nn.ReLU, inplace=True),
157
+ )
158
+ self.lsnr_fc = nn.Sequential(nn.Linear(self.emb_out_dim, 1), nn.Sigmoid())
159
+ self.lsnr_scale = p.lsnr_max - p.lsnr_min
160
+ self.lsnr_offset = p.lsnr_min
161
+
162
+ def forward(
163
+ self, feat_erb: Tensor, feat_spec: Tensor
164
+ ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
165
+ # Encodes erb; erb should be in dB scale + normalized; Fe are number of erb bands.
166
+ # erb: [B, 1, T, Fe]
167
+ # spec: [B, 2, T, Fc]
168
+ # b, _, t, _ = feat_erb.shape
169
+ e0 = self.erb_conv0(feat_erb) # [B, C, T, F]
170
+ e1 = self.erb_conv1(e0) # [B, C*2, T, F/2]
171
+ e2 = self.erb_conv2(e1) # [B, C*4, T, F/4]
172
+ e3 = self.erb_conv3(e2) # [B, C*4, T, F/4]
173
+ c0 = self.df_conv0(feat_spec) # [B, C, T, Fc]
174
+ c1 = self.df_conv1(c0) # [B, C*2, T, Fc]
175
+ cemb = c1.permute(0, 2, 3, 1).flatten(2) # [B, T, -1]
176
+ cemb = self.df_fc_emb(cemb)  # [B, T, C * F/4]
177
+ emb = e3.permute(0, 2, 3, 1).flatten(2) # [B, T, C * F/4]
178
+ emb = self.combine(emb, cemb)
179
+ emb, _ = self.emb_gru(emb) # [B, T, -1]
180
+ lsnr = self.lsnr_fc(emb) * self.lsnr_scale + self.lsnr_offset
181
+ return e0, e1, e2, e3, emb, c0, lsnr
182
+
183
+
184
+ class ErbDecoder(nn.Module):
185
+ def __init__(self):
186
+ super().__init__()
187
+ p = ModelParams()
188
+ assert p.nb_erb % 8 == 0, "erb_bins should be divisible by 8"
189
+
190
+ self.emb_out_dim = p.emb_hidden_dim
191
+
192
+ if p.gru_type == "grouped":
193
+ self.emb_gru = GroupedGRU(
194
+ p.conv_ch * p.nb_erb // 4, # For compat
195
+ self.emb_out_dim,
196
+ num_layers=p.emb_num_layers - 1,
197
+ batch_first=True,
198
+ groups=p.gru_groups,
199
+ shuffle=p.group_shuffle,
200
+ add_outputs=True,
201
+ )
202
+ # SqueezedGRU uses GroupedLinearEinsum, so let's use it here as well
203
+ fc_emb = GroupedLinear(
204
+ p.emb_hidden_dim,
205
+ p.conv_ch * p.nb_erb // 4,
206
+ groups=p.lin_groups,
207
+ shuffle=p.group_shuffle,
208
+ )
209
+ self.fc_emb = nn.Sequential(fc_emb, nn.ReLU(inplace=True))
210
+ else:
211
+ self.emb_gru = SqueezedGRU(
212
+ self.emb_out_dim,
213
+ self.emb_out_dim,
214
+ output_size=p.conv_ch * p.nb_erb // 4,
215
+ num_layers=p.emb_num_layers - 1,
216
+ batch_first=True,
217
+ gru_skip_op=nn.Identity,
218
+ linear_groups=p.lin_groups,
219
+ linear_act_layer=partial(nn.ReLU, inplace=True),
220
+ )
221
+ self.fc_emb = nn.Identity()
222
+ tconv_layer = partial(
223
+ ConvTranspose2dNormAct,
224
+ kernel_size=p.conv_kernel,
225
+ bias=False,
226
+ separable=True,
227
+ )
228
+ conv_layer = partial(
229
+ Conv2dNormAct,
230
+ bias=False,
231
+ separable=True,
232
+ )
233
+ # convt: TransposedConvolution, convp: Pathway (encoder to decoder) convolutions
234
+ self.conv3p = conv_layer(p.conv_ch, p.conv_ch, kernel_size=1)
235
+ self.convt3 = conv_layer(p.conv_ch, p.conv_ch, kernel_size=p.conv_kernel)
236
+ self.conv2p = conv_layer(p.conv_ch, p.conv_ch, kernel_size=1)
237
+ self.convt2 = tconv_layer(p.conv_ch, p.conv_ch, fstride=2)
238
+ self.conv1p = conv_layer(p.conv_ch, p.conv_ch, kernel_size=1)
239
+ self.convt1 = tconv_layer(p.conv_ch, p.conv_ch, fstride=2)
240
+ self.conv0p = conv_layer(p.conv_ch, p.conv_ch, kernel_size=1)
241
+ self.conv0_out = conv_layer(
242
+ p.conv_ch, 1, kernel_size=p.conv_kernel, activation_layer=nn.Sigmoid
243
+ )
244
+
245
+ def forward(self, emb, e3, e2, e1, e0) -> Tensor:
246
+ # Estimates erb mask
247
+ b, _, t, f8 = e3.shape
248
+ emb, _ = self.emb_gru(emb)
249
+ emb = self.fc_emb(emb)
250
+ emb = emb.view(b, t, f8, -1).permute(0, 3, 1, 2) # [B, C*8, T, F/8]
251
+ e3 = self.convt3(self.conv3p(e3) + emb) # [B, C*4, T, F/4]
252
+ e2 = self.convt2(self.conv2p(e2) + e3) # [B, C*2, T, F/2]
253
+ e1 = self.convt1(self.conv1p(e1) + e2) # [B, C, T, F]
254
+ m = self.conv0_out(self.conv0p(e0) + e1) # [B, 1, T, F]
255
+ return m
256
+
257
+
258
+ class DfOutputReshapeMF(nn.Module):
259
+ """Coefficients output reshape for multiframe/MultiFrameModule
260
+
261
+ Requires input of shape B, C, T, F, 2.
262
+ """
263
+
264
+ def __init__(self, df_order: int, df_bins: int):
265
+ super().__init__()
266
+ self.df_order = df_order
267
+ self.df_bins = df_bins
268
+
269
+ def forward(self, coefs: Tensor) -> Tensor:
270
+ # [B, T, F, O*2] -> [B, O, T, F, 2]
271
+ coefs = coefs.view(*coefs.shape[:-1], -1, 2)
272
+ coefs = coefs.permute(0, 3, 1, 2, 4)
273
+ return coefs
274
+
275
+
276
+ class DfDecoder(nn.Module):
277
+ def __init__(self, out_channels: int = -1):
278
+ super().__init__()
279
+ p = ModelParams()
280
+ layer_width = p.conv_ch
281
+ self.emb_dim = p.emb_hidden_dim
282
+
283
+ self.df_n_hidden = p.df_hidden_dim
284
+ self.df_n_layers = p.df_num_layers
285
+ self.df_order = p.df_order
286
+ self.df_bins = p.nb_df
287
+ self.gru_groups = p.gru_groups
288
+ self.df_out_ch = out_channels if out_channels > 0 else p.df_order * 2
289
+
290
+ conv_layer = partial(Conv2dNormAct, separable=True, bias=False)
291
+ kt = p.df_pathway_kernel_size_t
292
+ self.df_convp = conv_layer(layer_width, self.df_out_ch, fstride=1, kernel_size=(kt, 1))
293
+ if p.gru_type == "grouped":
294
+ self.df_gru = GroupedGRU(
295
+ p.emb_hidden_dim,
296
+ p.df_hidden_dim,
297
+ num_layers=self.df_n_layers,
298
+ batch_first=True,
299
+ groups=p.gru_groups,
300
+ shuffle=p.group_shuffle,
301
+ add_outputs=True,
302
+ )
303
+ else:
304
+ self.df_gru = SqueezedGRU(
305
+ p.emb_hidden_dim,
306
+ p.df_hidden_dim,
307
+ num_layers=self.df_n_layers,
308
+ batch_first=True,
309
+ gru_skip_op=nn.Identity,
310
+ linear_act_layer=partial(nn.ReLU, inplace=True),
311
+ )
312
+ p.df_gru_skip = p.df_gru_skip.lower()
313
+ assert p.df_gru_skip in ("none", "identity", "groupedlinear")
314
+ self.df_skip: Optional[nn.Module]
315
+ if p.df_gru_skip == "none":
316
+ self.df_skip = None
317
+ elif p.df_gru_skip == "identity":
318
+ assert p.emb_hidden_dim == p.df_hidden_dim, "Dimensions do not match"
319
+ self.df_skip = nn.Identity()
320
+ elif p.df_gru_skip == "groupedlinear":
321
+ self.df_skip = GroupedLinearEinsum(
322
+ p.emb_hidden_dim, p.df_hidden_dim, groups=p.lin_groups
323
+ )
324
+ else:
325
+ raise NotImplementedError()
326
+ assert p.df_output_layer in ("linear", "groupedlinear")
327
+ self.df_out: nn.Module
328
+ out_dim = self.df_bins * self.df_out_ch
329
+ if p.df_output_layer == "linear":
330
+ df_out = nn.Linear(self.df_n_hidden, out_dim)
331
+ elif p.df_output_layer == "groupedlinear":
332
+ df_out = GroupedLinearEinsum(self.df_n_hidden, out_dim, groups=p.lin_groups)
333
+ else:
334
+ raise NotImplementedError
335
+ self.df_out = nn.Sequential(df_out, nn.Tanh())
336
+ self.df_fc_a = nn.Sequential(nn.Linear(self.df_n_hidden, 1), nn.Sigmoid())
337
+ self.out_transform = DfOutputReshapeMF(self.df_order, self.df_bins)
338
+
339
+ def forward(self, emb: Tensor, c0: Tensor) -> Tuple[Tensor, Tensor]:
340
+ b, t, _ = emb.shape
341
+ c, _ = self.df_gru(emb) # [B, T, H], H: df_n_hidden
342
+ if self.df_skip is not None:
343
+ c += self.df_skip(emb)
344
+ c0 = self.df_convp(c0).permute(0, 2, 3, 1) # [B, T, F, O*2], channels_last
345
+ alpha = self.df_fc_a(c) # [B, T, 1]
346
+ c = self.df_out(c) # [B, T, F*O*2], O: df_order
347
+ c = c.view(b, t, self.df_bins, self.df_out_ch) + c0 # [B, T, F, O*2]
348
+ c = self.out_transform(c)
349
+ return c, alpha
350
+
351
+
352
+ class DfNet(nn.Module):
353
+ run_df: Final[bool]
354
+ pad_specf: Final[bool]
355
+
356
+ def __init__(
357
+ self,
358
+ erb_fb: Tensor,
359
+ erb_inv_fb: Tensor,
360
+ run_df: bool = True,
361
+ train_mask: bool = True,
362
+ ):
363
+ super().__init__()
364
+ p = ModelParams()
365
+ layer_width = p.conv_ch
366
+ assert p.nb_erb % 8 == 0, "erb_bins should be divisible by 8"
367
+ self.df_lookahead = p.df_lookahead if p.pad_mode == "model" else 0
368
+ self.nb_df = p.nb_df
369
+ self.freq_bins: int = p.fft_size // 2 + 1
370
+ self.emb_dim: int = layer_width * p.nb_erb
371
+ self.erb_bins: int = p.nb_erb
372
+ if p.conv_lookahead > 0 and p.pad_mode.startswith("input"):
373
+ self.pad_feat = nn.ConstantPad2d((0, 0, -p.conv_lookahead, p.conv_lookahead), 0.0)
374
+ else:
375
+ self.pad_feat = nn.Identity()
376
+ self.pad_specf = p.pad_mode.endswith("specf")
377
+ if p.df_lookahead > 0 and self.pad_specf:
378
+ self.pad_spec = nn.ConstantPad3d((0, 0, 0, 0, -p.df_lookahead, p.df_lookahead), 0.0)
379
+ else:
380
+ self.pad_spec = nn.Identity()
381
+ if (p.conv_lookahead > 0 or p.df_lookahead > 0) and p.pad_mode.startswith("output"):
382
+ assert p.conv_lookahead == p.df_lookahead
383
+ pad = (0, 0, 0, 0, -p.conv_lookahead, p.conv_lookahead)
384
+ self.pad_out = nn.ConstantPad3d(pad, 0.0)
385
+ else:
386
+ self.pad_out = nn.Identity()
387
+ self.register_buffer("erb_fb", erb_fb)
388
+ self.enc = Encoder()
389
+ self.erb_dec = ErbDecoder()
390
+ self.mask = Mask(erb_inv_fb, post_filter=p.mask_pf)
391
+
392
+ self.df_order = p.df_order
393
+ self.df_bins = p.nb_df
394
+ self.df_op: Union[DfOp, MultiFrameModule]
395
+ if p.dfop_method == "real_unfold":
396
+ raise ValueError("The RealUnfold DF op is no longer supported.")
397
+ assert p.df_output_layer != "linear", "Must be used with `groupedlinear`"
398
+ self.df_op = MF_METHODS[p.dfop_method](
399
+ num_freqs=p.nb_df, frame_size=p.df_order, lookahead=self.df_lookahead
400
+ )
401
+ n_ch_out = self.df_op.num_channels()
402
+ self.df_dec = DfDecoder(out_channels=n_ch_out)
403
+
404
+ self.run_df = run_df
405
+ if not run_df:
406
+ logger.warning("Running without DF")
407
+ self.train_mask = train_mask
408
+ assert p.df_n_iter == 1
409
+
410
+ def forward(
411
+ self,
412
+ spec: Tensor,
413
+ feat_erb: Tensor,
414
+ feat_spec: Tensor, # Not used, take spec modified by mask instead
415
+ ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
416
+ """Forward method of DeepFilterNet2.
417
+
418
+ Args:
419
+ spec (Tensor): Spectrum of shape [B, 1, T, F, 2]
420
+ feat_erb (Tensor): ERB features of shape [B, 1, T, E]
421
+ feat_spec (Tensor): Complex spectrogram features of shape [B, 1, T, F']
422
+
423
+ Returns:
424
+ spec (Tensor): Enhanced spectrum of shape [B, 1, T, F, 2]
425
+ m (Tensor): ERB mask estimate of shape [B, 1, T, E]
426
+ lsnr (Tensor): Local SNR estimate of shape [B, T, 1]
+ df_alpha (Tensor): DF mixing factor of shape [B, T, 1] (a scalar zero tensor if `run_df` is disabled)
427
+ """
428
+ feat_spec = feat_spec.squeeze(1).permute(0, 3, 1, 2)
429
+
430
+ feat_erb = self.pad_feat(feat_erb)
431
+ feat_spec = self.pad_feat(feat_spec)
432
+ e0, e1, e2, e3, emb, c0, lsnr = self.enc(feat_erb, feat_spec)
433
+ m = self.erb_dec(emb, e3, e2, e1, e0)
434
+
435
+ m = self.pad_out(m.unsqueeze(-1)).squeeze(-1)
436
+ spec = self.mask(spec, m)
437
+
438
+ if self.run_df:
439
+ df_coefs, df_alpha = self.df_dec(emb, c0)
440
+ df_coefs = self.pad_out(df_coefs)
441
+
442
+ if self.pad_specf:
443
+ # Only pad the lower part of the spectrum.
444
+ spec_f = self.pad_spec(spec)
445
+ spec_f = self.df_op(spec_f, df_coefs)
446
+ spec[..., : self.nb_df, :] = spec_f[..., : self.nb_df, :]
447
+ else:
448
+ spec = self.pad_spec(spec)
449
+ spec = self.df_op(spec, df_coefs)
450
+ else:
451
+ df_alpha = torch.zeros(())
452
+
453
+ return spec, m, lsnr, df_alpha
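A rough sketch of how the model defined above is instantiated and which tensor shapes its forward pass expects. It assumes a valid config has already been loaded (e.g. via `config.load`) so that `ModelParams` can resolve `sr`, `fft_size`, `hop_size`, `nb_erb` and `nb_df`; batch size and duration are illustrative only:

```python
import torch
from df.deepfilternet2 import init_model
from df.model import ModelParams
from df.modules import get_device

model = init_model()                      # builds the DF state and ERB filterbanks from the config
p = ModelParams()
b, t = 1, p.sr // p.hop_size              # roughly one second of frames
dev = get_device()
spec = torch.randn(b, 1, t, p.fft_size // 2 + 1, 2, device=dev)  # [B, 1, T, F, 2]
feat_erb = torch.randn(b, 1, t, p.nb_erb, device=dev)            # [B, 1, T, E]
feat_spec = torch.randn(b, 1, t, p.nb_df, 2, device=dev)         # [B, 1, T, F', 2]
enhanced, mask, lsnr, df_alpha = model(spec, feat_erb, feat_spec)
```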
df/enhance.py ADDED
@@ -0,0 +1,333 @@
1
+ import argparse
2
+ import os
3
+ import time
4
+ import warnings
5
+ from typing import Optional, Tuple, Union
6
+
7
+ import torch
8
+ import torchaudio as ta
9
+ from loguru import logger
10
+ from numpy import ndarray
11
+ from torch import Tensor, nn
12
+ from torch.nn import functional as F
13
+ from torchaudio.backend.common import AudioMetaData
14
+
15
+ import df
16
+ from df import config
17
+ from df.checkpoint import load_model as load_model_cp
18
+ from df.logger import init_logger, warn_once
19
+ from df.model import ModelParams
20
+ from df.modules import get_device
21
+ from df.utils import as_complex, as_real, get_norm_alpha, resample
22
+ from libdf import DF, erb, erb_norm, unit_norm
23
+
24
+
25
+ def main(args):
26
+ model, df_state, suffix = init_df(
27
+ args.model_base_dir,
28
+ post_filter=args.pf,
29
+ log_level=args.log_level,
30
+ config_allow_defaults=True,
31
+ epoch=args.epoch,
32
+ )
33
+ if args.output_dir is None:
34
+ args.output_dir = "."
35
+ elif not os.path.isdir(args.output_dir):
36
+ os.mkdir(args.output_dir)
37
+ df_sr = ModelParams().sr
38
+ n_samples = len(args.noisy_audio_files)
39
+ for i, file in enumerate(args.noisy_audio_files):
40
+ progress = (i + 1) / n_samples * 100
41
+ audio, meta = load_audio(file, df_sr)
42
+ t0 = time.time()
43
+ audio = enhance(
44
+ model, df_state, audio, pad=args.compensate_delay, atten_lim_db=args.atten_lim
45
+ )
46
+ t1 = time.time()
47
+ t_audio = audio.shape[-1] / df_sr
48
+ t = t1 - t0
49
+ rtf = t / t_audio
50
+ fn = os.path.basename(file)
51
+ p_str = f"{progress:2.0f}% | " if n_samples > 1 else ""
52
+ logger.info(f"{p_str}Enhanced noisy audio file '{fn}' in {t:.1f}s (RT factor: {rtf:.3f})")
53
+ audio = resample(audio, df_sr, meta.sample_rate)
54
+ save_audio(
55
+ file, audio, sr=meta.sample_rate, output_dir=args.output_dir, suffix=suffix, log=False
56
+ )
57
+
58
+
59
+ def init_df(
60
+ model_base_dir: Optional[str] = None,
61
+ post_filter: bool = False,
62
+ log_level: str = "INFO",
63
+ log_file: Optional[str] = "enhance.log",
64
+ config_allow_defaults: bool = False,
65
+ epoch: Union[str, int, None] = "best",
66
+ default_model: str = "DeepFilterNet2",
67
+ ) -> Tuple[nn.Module, DF, str]:
68
+ """Initializes and loads config, model and deep filtering state.
69
+
70
+ Args:
71
+ model_base_dir (str): Path to the model directory containing checkpoint and config. If None,
72
+ load the pretrained DeepFilterNet2 model.
73
+ post_filter (bool): Enable post filter for some minor, extra noise reduction.
74
+ log_level (str): Control amount of logging. Defaults to `INFO`.
75
+ log_file (str): Optional log file name. None disables it. Defaults to `enhance.log`.
76
+ config_allow_defaults (bool): Whether to allow initializing new config values with defaults.
77
+ epoch (str): Checkpoint epoch to load. Options are `best`, `latest`, `<int>`, and `none`.
78
+ `none` disables checkpoint loading. Defaults to `best`.
79
+
80
+ Returns:
81
+ model (nn.Module): Initialized model, moved to GPU if available.
82
+ df_state (DF): Deep filtering state for stft/istft/erb
83
+ suffix (str): Suffix based on the model name. This can be used for saving the enhanced
84
+ audio.
85
+ """
86
+ try:
87
+ from icecream import ic, install
88
+
89
+ ic.configureOutput(includeContext=True)
90
+ install()
91
+ except ImportError:
92
+ pass
93
+ use_default_model = False
94
+ if model_base_dir == "DeepFilterNet":
95
+ default_model = "DeepFilterNet"
96
+ use_default_model = True
97
+ elif model_base_dir == "DeepFilterNet2":
98
+ use_default_model = True
99
+ if model_base_dir is None or use_default_model:
100
+ use_default_model = True
101
+ model_base_dir = os.path.relpath(
102
+ os.path.join(
103
+ os.path.dirname(df.__file__), os.pardir, "pretrained_models", default_model
104
+ )
105
+ )
106
+ if not os.path.isdir(model_base_dir):
107
+ raise NotADirectoryError("Base directory not found at {}".format(model_base_dir))
108
+ log_file = os.path.join(model_base_dir, log_file) if log_file is not None else None
109
+ init_logger(file=log_file, level=log_level, model=model_base_dir)
110
+ if use_default_model:
111
+ logger.info(f"Using {default_model} model at {model_base_dir}")
112
+ config.load(
113
+ os.path.join(model_base_dir, "config.ini"),
114
+ config_must_exist=True,
115
+ allow_defaults=config_allow_defaults,
116
+ allow_reload=True,
117
+ )
118
+ if post_filter:
119
+ config.set("mask_pf", True, bool, ModelParams().section)
120
+ logger.info("Running with post-filter")
121
+ p = ModelParams()
122
+ df_state = DF(
123
+ sr=p.sr,
124
+ fft_size=p.fft_size,
125
+ hop_size=p.hop_size,
126
+ nb_bands=p.nb_erb,
127
+ min_nb_erb_freqs=p.min_nb_freqs,
128
+ )
129
+ checkpoint_dir = os.path.join(model_base_dir, "checkpoints")
130
+ load_cp = epoch is not None and not (isinstance(epoch, str) and epoch.lower() == "none")
131
+ if not load_cp:
132
+ checkpoint_dir = None
133
+ try:
134
+ mask_only = config.get("mask_only", cast=bool, section="train")
135
+ except KeyError:
136
+ mask_only = False
137
+ model, epoch = load_model_cp(checkpoint_dir, df_state, epoch=epoch, mask_only=mask_only)
138
+ if (epoch is None or epoch == 0) and load_cp:
139
+ logger.error("Could not find a checkpoint")
140
+ exit(1)
141
+ logger.debug(f"Loaded checkpoint from epoch {epoch}")
142
+ model = model.to(get_device())
143
+ # Set suffix to model name
144
+ suffix = os.path.basename(os.path.abspath(model_base_dir))
145
+ if post_filter:
146
+ suffix += "_pf"
147
+ logger.info("Model loaded")
148
+ return model, df_state, suffix
149
+
150
+
151
+ def df_features(audio: Tensor, df: DF, nb_df: int, device=None) -> Tuple[Tensor, Tensor, Tensor]:
152
+ spec = df.analysis(audio.numpy()) # [C, Tf] -> [C, Tf, F]
153
+ a = get_norm_alpha(False)
154
+ erb_fb = df.erb_widths()
155
+ with warnings.catch_warnings():
156
+ warnings.simplefilter("ignore", UserWarning)
157
+ erb_feat = torch.as_tensor(erb_norm(erb(spec, erb_fb), a)).unsqueeze(1)
158
+ spec_feat = as_real(torch.as_tensor(unit_norm(spec[..., :nb_df], a)).unsqueeze(1))
159
+ spec = as_real(torch.as_tensor(spec).unsqueeze(1))
160
+ if device is not None:
161
+ spec = spec.to(device)
162
+ erb_feat = erb_feat.to(device)
163
+ spec_feat = spec_feat.to(device)
164
+ return spec, erb_feat, spec_feat
165
+
166
+
167
+ def load_audio(
168
+ file: str, sr: Optional[int], verbose=True, **kwargs
169
+ ) -> Tuple[Tensor, AudioMetaData]:
170
+ """Loads an audio file using torchaudio.
171
+
172
+ Args:
173
+ file (str): Path to an audio file.
174
+ sr (int): Optionally resample audio to specified target sampling rate.
175
+ **kwargs: Passed to torchaudio.load(). Depends on the backend. The resample method
176
+ may be set via `method` which is passed to `resample()`.
177
+
178
+ Returns:
179
+ audio (Tensor): Audio tensor of shape [C, T], if channels_first=True (default).
180
+ info (AudioMetaData): Meta data of the original audio file. Contains the original sr.
181
+ """
182
+ ikwargs = {}
183
+ if "format" in kwargs:
184
+ ikwargs["format"] = kwargs["format"]
185
+ rkwargs = {}
186
+ if "method" in kwargs:
187
+ rkwargs["method"] = kwargs.pop("method")
188
+ info: AudioMetaData = ta.info(file, **ikwargs)
189
+ audio, orig_sr = ta.load(file, **kwargs)
190
+ if sr is not None and orig_sr != sr:
191
+ if verbose:
192
+ warn_once(
193
+ f"Audio sampling rate does not match model sampling rate ({orig_sr}, {sr}). "
194
+ "Resampling..."
195
+ )
196
+ audio = resample(audio, orig_sr, sr, **rkwargs)
197
+ return audio, info
198
+
199
+
200
+ def save_audio(
201
+ file: str,
202
+ audio: Union[Tensor, ndarray],
203
+ sr: int,
204
+ output_dir: Optional[str] = None,
205
+ suffix: Optional[str] = None,
206
+ log: bool = False,
207
+ dtype=torch.int16,
208
+ ):
209
+ outpath = file
210
+ if suffix is not None:
211
+ file, ext = os.path.splitext(file)
212
+ outpath = file + f"_{suffix}" + ext
213
+ if output_dir is not None:
214
+ outpath = os.path.join(output_dir, os.path.basename(outpath))
215
+ if log:
216
+ logger.info(f"Saving audio file '{outpath}'")
217
+ audio = torch.as_tensor(audio)
218
+ if audio.ndim == 1:
219
+ audio.unsqueeze_(0)
220
+ if dtype == torch.int16 and audio.dtype != torch.int16:
221
+ audio = (audio * (1 << 15)).to(torch.int16)
222
+ if dtype == torch.float32 and audio.dtype != torch.float32:
223
+ audio = audio.to(torch.float32) / (1 << 15)
224
+ ta.save(outpath, audio, sr)
225
+
226
+
227
+ @torch.no_grad()
228
+ def enhance(
229
+ model: nn.Module, df_state: DF, audio: Tensor, pad=False, atten_lim_db: Optional[float] = None
230
+ ):
231
+ model.eval()
232
+ bs = audio.shape[0]
233
+ if hasattr(model, "reset_h0"):
234
+ model.reset_h0(batch_size=bs, device=get_device())
235
+ orig_len = audio.shape[-1]
236
+ n_fft, hop = 0, 0
237
+ if pad:
238
+ n_fft, hop = df_state.fft_size(), df_state.hop_size()
239
+ # Pad audio to compensate for the delay due to the real-time STFT implementation
240
+ audio = F.pad(audio, (0, n_fft))
241
+ nb_df = getattr(model, "nb_df", getattr(model, "df_bins", ModelParams().nb_df))
242
+ spec, erb_feat, spec_feat = df_features(audio, df_state, nb_df, device=get_device())
243
+ enhanced = model(spec, erb_feat, spec_feat)[0].cpu()
244
+ enhanced = as_complex(enhanced.squeeze(1))
245
+ if atten_lim_db is not None and abs(atten_lim_db) > 0:
246
+ lim = 10 ** (-abs(atten_lim_db) / 20)
247
+ enhanced = as_complex(spec.squeeze(1)) * lim + enhanced * (1 - lim)
248
+ audio = torch.as_tensor(df_state.synthesis(enhanced.numpy()))
249
+ if pad:
250
+ # The frame size is equal to p.hop_size. Given a new frame, the STFT loop requires e.g.
251
+ # ceil((n_fft-hop)/hop). I.e. for 50% overlap, then hop=n_fft//2
252
+ # requires 1 additional frame lookahead; 75% requires 3 additional frames lookahead.
253
+ # Thus, the STFT/ISTFT loop introduces an algorithmic delay of n_fft - hop.
254
+ assert n_fft % hop == 0 # This is only tested for 50% and 75% overlap
255
+ d = n_fft - hop
256
+ audio = audio[:, d : orig_len + d]
257
+ return audio
258
+
259
+
260
+ def parse_epoch_type(value: str) -> Union[int, str]:
261
+ try:
262
+ return int(value)
263
+ except ValueError:
264
+ assert value in ("best", "latest")
265
+ return value
266
+
267
+
268
+ def setup_df_argument_parser(default_log_level: str = "INFO") -> argparse.ArgumentParser:
269
+ parser = argparse.ArgumentParser()
270
+ parser.add_argument(
271
+ "--model-base-dir",
272
+ "-m",
273
+ type=str,
274
+ default=None,
275
+ help="Model directory containing checkpoints and config. "
276
+ "To load a pretrained model, you may just provide the model name, e.g. `DeepFilterNet`. "
277
+ "By default, the pretrained DeepFilterNet2 model is loaded.",
278
+ )
279
+ parser.add_argument(
280
+ "--pf",
281
+ help="Post-filter that slightly over-attenuates very noisy sections.",
282
+ action="store_true",
283
+ )
284
+ parser.add_argument(
285
+ "--output-dir",
286
+ "-o",
287
+ type=str,
288
+ default=None,
289
+ help="Directory in which the enhanced audio files will be stored.",
290
+ )
291
+ parser.add_argument(
292
+ "--log-level",
293
+ type=str,
294
+ default=default_log_level,
295
+ help="Logger verbosity. Can be one of (debug, info, error, none)",
296
+ )
297
+ parser.add_argument("--debug", "-d", action="store_const", const="DEBUG", dest="log_level")
298
+ parser.add_argument(
299
+ "--epoch",
300
+ "-e",
301
+ default="best",
302
+ type=parse_epoch_type,
303
+ help="Epoch for checkpoint loading. Can be one of ['best', 'latest', <int>].",
304
+ )
305
+ return parser
306
+
307
+
308
+ def run():
309
+ parser = setup_df_argument_parser()
310
+ parser.add_argument(
311
+ "--compensate-delay",
312
+ "-D",
313
+ action="store_true",
314
+ help="Add some padding to compensate for the delay introduced by the real-time STFT/ISTFT implementation.",
315
+ )
316
+ parser.add_argument(
317
+ "--atten-lim",
318
+ "-a",
319
+ type=int,
320
+ default=None,
321
+ help="Attenuation limit in dB by mixing the enhanced signal with the noisy signal.",
322
+ )
323
+ parser.add_argument(
324
+ "noisy_audio_files",
325
+ type=str,
326
+ nargs="+",
327
+ help="List of noisy audio files to enhance.",
328
+ )
329
+ main(parser.parse_args())
330
+
331
+
332
+ if __name__ == "__main__":
333
+ run()
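Besides the CLI entry point above, the same functions can be used directly from Python. A minimal sketch, assuming the pretrained DeepFilterNet2 model directory is available and `noisy.wav` is a placeholder input path:

```python
from df.enhance import enhance, init_df, load_audio, save_audio
from df.model import ModelParams

model, df_state, suffix = init_df()          # loads config, checkpoint and DF state
sr = ModelParams().sr                        # model sampling rate from the loaded config
audio, meta = load_audio("noisy.wav", sr)    # resamples to the model sampling rate if needed
enhanced = enhance(model, df_state, audio, pad=True)
save_audio("noisy_enhanced.wav", enhanced, sr)
```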
df/logger.py ADDED
@@ -0,0 +1,212 @@
1
+ import os
2
+ import sys
3
+ import warnings
4
+ from collections import defaultdict
5
+ from copy import deepcopy
6
+ from typing import Dict, Optional, Tuple
7
+
8
+ import numpy as np
9
+ import torch
10
+ from loguru import logger
11
+ from torch.types import Number
12
+
13
+ from df.modules import GroupedLinearEinsum
14
+ from df.utils import get_branch_name, get_commit_hash, get_device, get_host
15
+
16
+ _logger_initialized = False
17
+ WARN_ONCE_NO = logger.level("WARNING").no + 1
18
+ DEPRECATED_NO = logger.level("WARNING").no + 2
19
+
20
+
21
+ def init_logger(file: Optional[str] = None, level: str = "INFO", model: Optional[str] = None):
22
+ global _logger_initialized, _duplicate_filter
23
+ if _logger_initialized:
24
+ logger.debug("Logger already initialized.")
25
+ else:
26
+ logger.remove()
27
+ level = level.upper()
28
+ if level.lower() != "none":
29
+ log_format = Formatter(debug=logger.level(level).no <= logger.level("DEBUG").no).format
30
+ logger.add(
31
+ sys.stdout,
32
+ level=level,
33
+ format=log_format,
34
+ filter=lambda r: r["level"].no not in {WARN_ONCE_NO, DEPRECATED_NO},
35
+ )
36
+ if file is not None:
37
+ logger.add(
38
+ file,
39
+ level=level,
40
+ format=log_format,
41
+ filter=lambda r: r["level"].no != WARN_ONCE_NO,
42
+ )
43
+
44
+ logger.info(f"Running on torch {torch.__version__}")
45
+ logger.info(f"Running on host {get_host()}")
46
+ commit = get_commit_hash()
47
+ if commit is not None:
48
+ logger.info(f"Git commit: {commit}, branch: {get_branch_name()}")
49
+ if (jobid := os.getenv("SLURM_JOB_ID")) is not None:
50
+ logger.info(f"Slurm jobid: {jobid}")
51
+ logger.level("WARNONCE", no=WARN_ONCE_NO, color="<yellow><bold>")
52
+ logger.add(
53
+ sys.stderr,
54
+ level=max(logger.level(level).no, WARN_ONCE_NO),
55
+ format=log_format,
56
+ filter=lambda r: r["level"].no == WARN_ONCE_NO and _duplicate_filter(r),
57
+ )
58
+ logger.level("DEPRECATED", no=DEPRECATED_NO, color="<yellow><bold>")
59
+ logger.add(
60
+ sys.stderr,
61
+ level=max(logger.level(level).no, DEPRECATED_NO),
62
+ format=log_format,
63
+ filter=lambda r: r["level"].no == DEPRECATED_NO and _duplicate_filter(r),
64
+ )
65
+ if model is not None:
66
+ logger.info("Loading model settings of {}", os.path.basename(model.rstrip("/")))
67
+ _logger_initialized = True
68
+
69
+
70
+ def warn_once(message, *args, **kwargs):
71
+ logger.log("WARNONCE", message, *args, **kwargs)
72
+
73
+
74
+ def log_deprecated(message, *args, **kwargs):
75
+ logger.log("DEPRECATED", message, *args, **kwargs)
76
+
77
+
78
+ class Formatter:
79
+ def __init__(self, debug=False):
80
+ if debug:
81
+ self.fmt = (
82
+ "<green>{time:YYYY-MM-DD HH:mm:ss}</green>"
83
+ " | <level>{level: <8}</level>"
84
+ " | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan>"
85
+ " | <level>{message}</level>"
86
+ )
87
+ else:
88
+ self.fmt = (
89
+ "<green>{time:YYYY-MM-DD HH:mm:ss}</green>"
90
+ " | <level>{level: <8}</level>"
91
+ " | <cyan>DF</cyan>"
92
+ " | <level>{message}</level>"
93
+ )
94
+ self.fmt += "\n{exception}"
95
+
96
+ def format(self, record):
97
+ if record["level"].no == WARN_ONCE_NO:
98
+ return self.fmt.replace("{level: <8}", "WARNING ")
99
+ return self.fmt
100
+
101
+
102
+ def _metrics_key(k_: Tuple[str, float]):
103
+ k0 = k_[0]
104
+ ks = k0.split("_")
105
+ if len(ks) > 2:
106
+ try:
107
+ return int(ks[-1])
108
+ except ValueError:
109
+ return 1000
110
+ elif k0 == "loss":
111
+ return -999
112
+ elif "loss" in k0.lower():
113
+ return -998
114
+ elif k0 == "lr":
115
+ return 998
116
+ elif k0 == "wd":
117
+ return 999
118
+ else:
119
+ return -101
120
+
121
+
122
+ def log_metrics(prefix: str, metrics: Dict[str, Number], level="INFO"):
123
+ msg = ""
124
+ stages = defaultdict(str)
125
+ loss_msg = ""
126
+ for n, v in sorted(metrics.items(), key=_metrics_key):
127
+ if abs(v) > 1e-3:
128
+ m = f" | {n}: {v:.5f}"
129
+ else:
130
+ m = f" | {n}: {v:.3E}"
131
+ if "stage" in n:
132
+ s = n.split("stage_")[1].split("_snr")[0]
133
+ stages[s] += m.replace(f"stage_{s}_", "")
134
+ elif ("valid" in prefix or "test" in prefix) and "loss" in n.lower():
135
+ loss_msg += m
136
+ else:
137
+ msg += m
138
+ for s, msg_s in stages.items():
139
+ logger.log(level, f"{prefix} | stage {s}" + msg_s)
140
+ if len(stages) == 0:
141
+ logger.log(level, prefix + msg)
142
+ if len(loss_msg) > 0:
143
+ logger.log(level, prefix + loss_msg)
144
+
145
+
146
+ class DuplicateFilter:
147
+ """
148
+ Filters away duplicate log messages.
149
+ Modified version of: https://stackoverflow.com/a/60462619
150
+ """
151
+
152
+ def __init__(self):
153
+ self.msgs = set()
154
+
155
+ def __call__(self, record) -> bool:
156
+ k = f"{record['level']}{record['message']}"
157
+ if k in self.msgs:
158
+ return False
159
+ else:
160
+ self.msgs.add(k)
161
+ return True
162
+
163
+
164
+ _duplicate_filter = DuplicateFilter()
165
+
166
+
167
+ def log_model_summary(model: torch.nn.Module, verbose=False):
168
+ try:
169
+ import ptflops
170
+ except ImportError:
171
+ logger.debug("Failed to import ptflops. Cannot print model summary.")
172
+ return
173
+
174
+ from df.model import ModelParams
175
+
176
+ # Generate input of 1 second audio
177
+ # Necessary inputs are:
178
+ # spec: [B, 1, T, F, 2], F: freq bin
179
+ # feat_erb: [B, 1, T, E], E: ERB bands
180
+ # feat_spec: [B, 2, T, C*2], C: Complex features
181
+ p = ModelParams()
182
+ b = 1
183
+ t = p.sr // p.hop_size
184
+ device = get_device()
185
+ spec = torch.randn([b, 1, t, p.fft_size // 2 + 1, 2]).to(device)
186
+ feat_erb = torch.randn([b, 1, t, p.nb_erb]).to(device)
187
+ feat_spec = torch.randn([b, 1, t, p.nb_df, 2]).to(device)
188
+
189
+ warnings.filterwarnings("ignore", "RNN module weights", category=UserWarning, module="torch")
190
+ macs, params = ptflops.get_model_complexity_info(
191
+ deepcopy(model),
192
+ (t,),
193
+ input_constructor=lambda _: {"spec": spec, "feat_erb": feat_erb, "feat_spec": feat_spec},
194
+ as_strings=False,
195
+ print_per_layer_stat=verbose,
196
+ verbose=verbose,
197
+ custom_modules_hooks={
198
+ GroupedLinearEinsum: grouped_linear_flops_counter_hook,
199
+ },
200
+ )
201
+ logger.info(f"Model complexity: {params/1e6:.3f}M #Params, {macs/1e6:.1f}M MACS")
202
+
203
+
204
+ def grouped_linear_flops_counter_hook(module: GroupedLinearEinsum, input, output):
205
+ # input: ([B, T, I],)
206
+ # output: [B, T, H]
207
+ input = input[0] # [B, T, I]
208
+ output_last_dim = module.weight.shape[-1]
209
+ input = input.unflatten(-1, (module.groups, module.ws)) # [B, T, G, I/G]
210
+ # GroupedLinear calculates "...gi,...gih->...gh"
211
+ weight_flops = np.prod(input.shape) * output_last_dim
212
+ module.__flops__ += int(weight_flops) # type: ignore
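A short sketch of the logging helpers above: `init_logger` installs the loguru sinks and the custom `WARNONCE`/`DEPRECATED` levels, and `warn_once` is routed through `DuplicateFilter`, so identical messages are only emitted once:

```python
from df.logger import init_logger, warn_once

init_logger(file=None, level="INFO")
for _ in range(3):
    # Printed a single time; subsequent identical WARNONCE records are filtered out.
    warn_once("Audio sampling rate does not match model sampling rate")
```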
df/model.py ADDED
@@ -0,0 +1,24 @@
1
+ from importlib import import_module
2
+
3
+ import torch
4
+ from loguru import logger
5
+
6
+ from df.config import DfParams, config
7
+
8
+
9
+ class ModelParams(DfParams):
10
+ def __init__(self):
11
+ self.__model = config("MODEL", default="deepfilternet", section="train")
12
+ self.__params = getattr(import_module("df." + self.__model), "ModelParams")()
13
+
14
+ def __getattr__(self, attr: str):
15
+ return getattr(self.__params, attr)
16
+
17
+
18
+ def init_model(*args, **kwargs):
19
+ """Initialize the model specified in the config."""
20
+ model = config("MODEL", default="deepfilternet", section="train")
21
+ logger.info(f"Initializing model `{model}`")
22
+ model = getattr(import_module("df." + model), "init_model")(*args, **kwargs)
23
+ model.to(memory_format=torch.channels_last)
24
+ return model
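`ModelParams` here is only a thin proxy: `__getattr__` forwards attribute access to the `ModelParams` class of the model selected by the `train.model` config option (`deepfilternet` by default). A sketch, assuming a config has been loaded:

```python
from df.model import ModelParams

p = ModelParams()
# Resolved via df.deepfilternet2.ModelParams (or whichever model the config selects).
print(p.sr, p.fft_size, p.hop_size, p.nb_erb, p.nb_df, p.conv_ch)
```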
df/modules.py ADDED
@@ -0,0 +1,956 @@
1
+ import math
2
+ from collections import OrderedDict
3
+ from typing import Callable, Iterable, List, Optional, Tuple, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from torch import Tensor, nn
8
+ from torch.nn import functional as F
9
+ from torch.nn import init
10
+ from torch.nn.parameter import Parameter
11
+ from typing_extensions import Final
12
+
13
+ from df.model import ModelParams
14
+ from df.utils import as_complex, as_real, get_device, get_norm_alpha
15
+ from libdf import unit_norm_init
16
+
17
+
18
+ class Conv2dNormAct(nn.Sequential):
19
+ def __init__(
20
+ self,
21
+ in_ch: int,
22
+ out_ch: int,
23
+ kernel_size: Union[int, Iterable[int]],
24
+ fstride: int = 1,
25
+ dilation: int = 1,
26
+ fpad: bool = True,
27
+ bias: bool = True,
28
+ separable: bool = False,
29
+ norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
30
+ activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
31
+ ):
32
+ """Causal Conv2d by delaying the signal for any lookahead.
33
+
34
+ Expected input format: [B, C, T, F]
35
+ """
36
+ lookahead = 0 # This needs to be handled on the input feature side
37
+ # Padding on time axis
38
+ kernel_size = (
39
+ (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
40
+ )
41
+ if fpad:
42
+ fpad_ = kernel_size[1] // 2 + dilation - 1
43
+ else:
44
+ fpad_ = 0
45
+ pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead)
46
+ layers = []
47
+ if any(x > 0 for x in pad):
48
+ layers.append(nn.ConstantPad2d(pad, 0.0))
49
+ groups = math.gcd(in_ch, out_ch) if separable else 1
50
+ if groups == 1:
51
+ separable = False
52
+ if max(kernel_size) == 1:
53
+ separable = False
54
+ layers.append(
55
+ nn.Conv2d(
56
+ in_ch,
57
+ out_ch,
58
+ kernel_size=kernel_size,
59
+ padding=(0, fpad_),
60
+ stride=(1, fstride), # Stride over time is always 1
61
+ dilation=(1, dilation), # Same for dilation
62
+ groups=groups,
63
+ bias=bias,
64
+ )
65
+ )
66
+ if separable:
67
+ layers.append(nn.Conv2d(out_ch, out_ch, kernel_size=1, bias=False))
68
+ if norm_layer is not None:
69
+ layers.append(norm_layer(out_ch))
70
+ if activation_layer is not None:
71
+ layers.append(activation_layer())
72
+ super().__init__(*layers)
73
+
74
+
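A quick sketch of the causality convention of `Conv2dNormAct` above: the time axis is padded with `kernel_size[0] - 1` past frames only, so the number of output frames equals the number of input frames, while the frequency axis is strided; the numbers below are illustrative:

```python
import torch
from df.modules import Conv2dNormAct

conv = Conv2dNormAct(in_ch=1, out_ch=16, kernel_size=(2, 3), fstride=2)
x = torch.randn(1, 1, 100, 32)   # [B, C, T, F]
y = conv(x)
print(y.shape)                   # torch.Size([1, 16, 100, 16]); T unchanged, F halved by fstride=2
```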
75
+ class ConvTranspose2dNormAct(nn.Sequential):
76
+ def __init__(
77
+ self,
78
+ in_ch: int,
79
+ out_ch: int,
80
+ kernel_size: Union[int, Tuple[int, int]],
81
+ fstride: int = 1,
82
+ dilation: int = 1,
83
+ fpad: bool = True,
84
+ bias: bool = True,
85
+ separable: bool = False,
86
+ norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
87
+ activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
88
+ ):
89
+ """Causal ConvTranspose2d.
90
+
91
+ Expected input format: [B, C, T, F]
92
+ """
93
+ # Padding on time axis, with lookahead = 0
94
+ lookahead = 0 # This needs to be handled on the input feature side
95
+ kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
96
+ if fpad:
97
+ fpad_ = kernel_size[1] // 2
98
+ else:
99
+ fpad_ = 0
100
+ pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead)
101
+ layers = []
102
+ if any(x > 0 for x in pad):
103
+ layers.append(nn.ConstantPad2d(pad, 0.0))
104
+ groups = math.gcd(in_ch, out_ch) if separable else 1
105
+ if groups == 1:
106
+ separable = False
107
+ layers.append(
108
+ nn.ConvTranspose2d(
109
+ in_ch,
110
+ out_ch,
111
+ kernel_size=kernel_size,
112
+ padding=(kernel_size[0] - 1, fpad_ + dilation - 1),
113
+ output_padding=(0, fpad_),
114
+ stride=(1, fstride), # Stride over time is always 1
115
+ dilation=(1, dilation),
116
+ groups=groups,
117
+ bias=bias,
118
+ )
119
+ )
120
+ if separable:
121
+ layers.append(nn.Conv2d(out_ch, out_ch, kernel_size=1, bias=False))
122
+ if norm_layer is not None:
123
+ layers.append(norm_layer(out_ch))
124
+ if activation_layer is not None:
125
+ layers.append(activation_layer())
126
+ super().__init__(*layers)
127
+
128
+
129
+ def convkxf(
130
+ in_ch: int,
131
+ out_ch: Optional[int] = None,
132
+ k: int = 1,
133
+ f: int = 3,
134
+ fstride: int = 2,
135
+ lookahead: int = 0,
136
+ batch_norm: bool = False,
137
+ act: nn.Module = nn.ReLU(inplace=True),
138
+ mode="normal", # must be "normal", "transposed" or "upsample"
139
+ depthwise: bool = True,
140
+ complex_in: bool = False,
141
+ ):
142
+ bias = batch_norm is False
143
+ assert f % 2 == 1
144
+ stride = 1 if f == 1 else (1, fstride)
145
+ if out_ch is None:
146
+ out_ch = in_ch * 2 if mode == "normal" else in_ch // 2
147
+ fpad = (f - 1) // 2
148
+ convpad = (0, fpad)
149
+ modules = []
150
+ # Manually pad for time axis kernel to not introduce delay
151
+ pad = (0, 0, k - 1 - lookahead, lookahead)
152
+ if any(p > 0 for p in pad):
153
+ modules.append(("pad", nn.ConstantPad2d(pad, 0.0)))
154
+ if depthwise:
155
+ groups = min(in_ch, out_ch)
156
+ else:
157
+ groups = 1
158
+ if in_ch % groups != 0 or out_ch % groups != 0:
159
+ groups = 1
160
+ if complex_in and groups % 2 == 0:
161
+ groups //= 2
162
+ convkwargs = {
163
+ "in_channels": in_ch,
164
+ "out_channels": out_ch,
165
+ "kernel_size": (k, f),
166
+ "stride": stride,
167
+ "groups": groups,
168
+ "bias": bias,
169
+ }
170
+ if mode == "normal":
171
+ modules.append(("sconv", nn.Conv2d(padding=convpad, **convkwargs)))
172
+ elif mode == "transposed":
173
+ # Since pytorch's transposed conv padding does not correspond to the actual padding but
174
+ # rather the padding that was used in the encoder conv, we need to set time axis padding
175
+ # according to k. E.g., this disables the padding for k=2:
176
+ # dilation - (k - 1) - padding
177
+ # = 1 - (2 - 1) - 1 = 0; => padding = fpad (=1 for k=2)
178
+ padding = (k - 1, fpad)
179
+ modules.append(
180
+ ("sconvt", nn.ConvTranspose2d(padding=padding, output_padding=convpad, **convkwargs))
181
+ )
182
+ elif mode == "upsample":
183
+ modules.append(("upsample", FreqUpsample(fstride)))
184
+ convkwargs["stride"] = 1
185
+ modules.append(("sconv", nn.Conv2d(padding=convpad, **convkwargs)))
186
+ else:
187
+ raise NotImplementedError()
188
+ if groups > 1:
189
+ modules.append(("1x1conv", nn.Conv2d(out_ch, out_ch, 1, bias=False)))
190
+ if batch_norm:
191
+ modules.append(("norm", nn.BatchNorm2d(out_ch)))
192
+ modules.append(("act", act))
193
+ return nn.Sequential(OrderedDict(modules))
194
+
195
+
196
+ class FreqUpsample(nn.Module):
197
+ def __init__(self, factor: int, mode="nearest"):
198
+ super().__init__()
199
+ self.f = float(factor)
200
+ self.mode = mode
201
+
202
+ def forward(self, x: Tensor) -> Tensor:
203
+ return F.interpolate(x, scale_factor=[1.0, self.f], mode=self.mode)
204
+
205
+
206
+ def erb_fb(widths: np.ndarray, sr: int, normalized: bool = True, inverse: bool = False) -> Tensor:
207
+ n_freqs = int(np.sum(widths))
208
+ all_freqs = torch.linspace(0, sr // 2, n_freqs + 1)[:-1]
209
+
210
+ b_pts = np.cumsum([0] + widths.tolist()).astype(int)[:-1]
211
+
212
+ fb = torch.zeros((all_freqs.shape[0], b_pts.shape[0]))
213
+ for i, (b, w) in enumerate(zip(b_pts.tolist(), widths.tolist())):
214
+ fb[b : b + w, i] = 1
215
+ # Normalize to constant energy per resulting band
216
+ if inverse:
217
+ fb = fb.t()
218
+ if not normalized:
219
+ fb /= fb.sum(dim=1, keepdim=True)
220
+ else:
221
+ if normalized:
222
+ fb /= fb.sum(dim=0)
223
+ return fb.to(device=get_device())
224
+
225
+
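`erb_fb` above turns a list of band widths (as produced by `DF.erb_widths()`) into a rectangular ERB filterbank matrix, or its inverse. A sketch with synthetic widths:

```python
import numpy as np
from df.modules import erb_fb

widths = np.array([2, 2, 4, 8, 16])              # synthetic ERB band widths, 32 frequency bins in total
fb = erb_fb(widths, sr=48000)                    # [32, 5]: maps frequency bins to ERB bands
fb_inv = erb_fb(widths, sr=48000, inverse=True)  # [5, 32]: maps ERB bands back to frequency bins
```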
226
+ class Mask(nn.Module):
227
+ def __init__(self, erb_inv_fb: Tensor, post_filter: bool = False, eps: float = 1e-12):
228
+ super().__init__()
229
+ self.erb_inv_fb: Tensor
230
+ self.register_buffer("erb_inv_fb", erb_inv_fb)
231
+ self.clamp_tensor = torch.__version__ > "1.9.0" or torch.__version__ == "1.9.0"
232
+ self.post_filter = post_filter
233
+ self.eps = eps
234
+
235
+ def pf(self, mask: Tensor, beta: float = 0.02) -> Tensor:
236
+ """Post-Filter proposed by Valin et al. [1].
237
+
238
+ Args:
239
+ mask (Tensor): Real valued mask, typically of shape [B, C, T, F].
240
+ beta: Global gain factor.
241
+ Refs:
242
+ [1]: Valin et al.: A Perceptually-Motivated Approach for Low-Complexity, Real-Time Enhancement of Fullband Speech.
243
+ """
244
+ mask_sin = mask * torch.sin(np.pi * mask / 2)
245
+ mask_pf = (1 + beta) * mask / (1 + beta * mask.div(mask_sin.clamp_min(self.eps)).pow(2))
246
+ return mask_pf
247
+
248
+ def forward(self, spec: Tensor, mask: Tensor, atten_lim: Optional[Tensor] = None) -> Tensor:
249
+ # spec (real) [B, 1, T, F, 2], F: freq_bins
250
+ # mask (real): [B, 1, T, Fe], Fe: erb_bins
251
+ # atten_lim: [B]
252
+ if not self.training and self.post_filter:
253
+ mask = self.pf(mask)
254
+ if atten_lim is not None:
255
+ # dB to amplitude
256
+ atten_lim = 10 ** (-atten_lim / 20)
257
+ # Greater equal (__ge__) not implemented for TorchVersion.
258
+ if self.clamp_tensor:
259
+ # Supported by torch >= 1.9
260
+ mask = mask.clamp(min=atten_lim.view(-1, 1, 1, 1))
261
+ else:
262
+ m_out = []
263
+ for i in range(atten_lim.shape[0]):
264
+ m_out.append(mask[i].clamp_min(atten_lim[i].item()))
265
+ mask = torch.stack(m_out, dim=0)
266
+ mask = mask.matmul(self.erb_inv_fb) # [B, 1, T, F]
267
+ return spec * mask.unsqueeze(4)
268
+
269
+
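Written out, the post-filter in `Mask.pf` above corresponds to the following gain adjustment (mask m, global gain factor beta, default 0.02):

```latex
m_{\sin} = m \cdot \sin\!\left(\frac{\pi m}{2}\right), \qquad
m_{\mathrm{pf}} = \frac{(1 + \beta)\, m}{1 + \beta \left( m / m_{\sin} \right)^{2}}
```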
270
+ class ExponentialUnitNorm(nn.Module):
271
+ """Unit norm for a complex spectrogram.
272
+
273
+ This should match the rust code:
274
+ ```rust
275
+ for (x, s) in xs.iter_mut().zip(state.iter_mut()) {
276
+ *s = x.norm() * (1. - alpha) + *s * alpha;
277
+ *x /= s.sqrt();
278
+ }
279
+ ```
280
+ """
281
+
282
+ alpha: Final[float]
283
+ eps: Final[float]
284
+
285
+ def __init__(self, alpha: float, num_freq_bins: int, eps: float = 1e-14):
286
+ super().__init__()
287
+ self.alpha = alpha
288
+ self.eps = eps
289
+ self.init_state: Tensor
290
+ s = torch.from_numpy(unit_norm_init(num_freq_bins)).view(1, 1, num_freq_bins, 1)
291
+ self.register_buffer("init_state", s)
292
+
293
+ def forward(self, x: Tensor) -> Tensor:
294
+ # x: [B, C, T, F, 2]
295
+ b, c, t, f, _ = x.shape
296
+ x_abs = x.square().sum(dim=-1, keepdim=True).clamp_min(self.eps).sqrt()
297
+ state = self.init_state.clone().expand(b, c, f, 1)
298
+ out_states: List[Tensor] = []
299
+ for t in range(t):
300
+ state = x_abs[:, :, t] * (1 - self.alpha) + state * self.alpha
301
+ out_states.append(state)
302
+ return x / torch.stack(out_states, 2).sqrt()
303
+
304
+
305
+ class DfOp(nn.Module):
306
+ df_order: Final[int]
307
+ df_bins: Final[int]
308
+ df_lookahead: Final[int]
309
+ freq_bins: Final[int]
310
+
311
+ def __init__(
312
+ self,
313
+ df_bins: int,
314
+ df_order: int = 5,
315
+ df_lookahead: int = 0,
316
+ method: str = "complex_strided",
317
+ freq_bins: int = 0,
318
+ ):
319
+ super().__init__()
320
+ self.df_order = df_order
321
+ self.df_bins = df_bins
322
+ self.df_lookahead = df_lookahead
323
+ self.freq_bins = freq_bins
324
+ self.set_forward(method)
325
+
326
+ def set_forward(self, method: str):
327
+ # All forward methods should be mathematically similar.
328
+ # DeepFilterNet results are obtained with 'real_unfold'.
329
+ forward_methods = {
330
+ "real_loop": self.forward_real_loop,
331
+ "real_strided": self.forward_real_strided,
332
+ "real_unfold": self.forward_real_unfold,
333
+ "complex_strided": self.forward_complex_strided,
334
+ "real_one_step": self.forward_real_no_pad_one_step,
335
+ "real_hidden_state_loop": self.forward_real_hidden_state_loop,
336
+ }
337
+ if method not in forward_methods.keys():
338
+ raise NotImplementedError(f"`method` must be one of {forward_methods.keys()}")
339
+ if method == "real_hidden_state_loop":
340
+ assert self.freq_bins >= self.df_bins
341
+ self.spec_buf: Tensor
342
+ # Currently only designed for batch size of 1
343
+ self.register_buffer(
344
+ "spec_buf", torch.zeros(1, 1, self.df_order, self.freq_bins, 2), persistent=False
345
+ )
346
+ self.forward = forward_methods[method]
347
+
348
+ def forward_real_loop(
349
+ self, spec: Tensor, coefs: Tensor, alpha: Optional[Tensor] = None
350
+ ) -> Tensor:
351
+ # Version 0: Manual loop over df_order, maybe best for onnx export?
352
+ b, _, t, _, _ = spec.shape
353
+ f = self.df_bins
354
+ padded = spec_pad(
355
+ spec[..., : self.df_bins, :].squeeze(1), self.df_order, self.df_lookahead, dim=-3
356
+ )
357
+
358
+ spec_f = torch.zeros((b, t, f, 2), device=spec.device)
359
+ for i in range(self.df_order):
360
+ spec_f[..., 0] += padded[:, i : i + t, ..., 0] * coefs[:, :, i, :, 0]
361
+ spec_f[..., 0] -= padded[:, i : i + t, ..., 1] * coefs[:, :, i, :, 1]
362
+ spec_f[..., 1] += padded[:, i : i + t, ..., 1] * coefs[:, :, i, :, 0]
363
+ spec_f[..., 1] += padded[:, i : i + t, ..., 0] * coefs[:, :, i, :, 1]
364
+ return assign_df(spec, spec_f.unsqueeze(1), self.df_bins, alpha)
365
+
366
+ def forward_real_strided(
367
+ self, spec: Tensor, coefs: Tensor, alpha: Optional[Tensor] = None
368
+ ) -> Tensor:
369
+ # Version1: Use as_strided instead of unfold
370
+ # spec (real) [B, 1, T, F, 2], O: df_order
371
+ # coefs (real) [B, T, O, F, 2]
372
+ # alpha (real) [B, T, 1]
373
+ padded = as_strided(
374
+ spec[..., : self.df_bins, :].squeeze(1), self.df_order, self.df_lookahead, dim=-3
375
+ )
376
+ # Complex numbers are not supported by onnx
377
+ re = padded[..., 0] * coefs[..., 0]
378
+ re -= padded[..., 1] * coefs[..., 1]
379
+ im = padded[..., 1] * coefs[..., 0]
380
+ im += padded[..., 0] * coefs[..., 1]
381
+ spec_f = torch.stack((re, im), -1).sum(2)
382
+ return assign_df(spec, spec_f.unsqueeze(1), self.df_bins, alpha)
383
+
384
+ def forward_real_unfold(
385
+ self, spec: Tensor, coefs: Tensor, alpha: Optional[Tensor] = None
386
+ ) -> Tensor:
387
+ # Version2: Unfold
388
+ # spec (real) [B, 1, T, F, 2], O: df_order
389
+ # coefs (real) [B, T, O, F, 2]
390
+ # alpha (real) [B, T, 1]
391
+ padded = spec_pad(
392
+ spec[..., : self.df_bins, :].squeeze(1), self.df_order, self.df_lookahead, dim=-3
393
+ )
394
+ padded = padded.unfold(dimension=1, size=self.df_order, step=1) # [B, T, F, 2, O]
395
+ padded = padded.permute(0, 1, 4, 2, 3)
396
+ spec_f = torch.empty_like(padded)
397
+ spec_f[..., 0] = padded[..., 0] * coefs[..., 0] # re1
398
+ spec_f[..., 0] -= padded[..., 1] * coefs[..., 1] # re2
399
+ spec_f[..., 1] = padded[..., 1] * coefs[..., 0] # im1
400
+ spec_f[..., 1] += padded[..., 0] * coefs[..., 1] # im2
401
+ spec_f = spec_f.sum(dim=2)
402
+ return assign_df(spec, spec_f.unsqueeze(1), self.df_bins, alpha)
403
+
404
+ def forward_complex_strided(
405
+ self, spec: Tensor, coefs: Tensor, alpha: Optional[Tensor] = None
406
+ ) -> Tensor:
407
+ # Version3: Complex strided; definitely the nicest (no permute, no indexing), but complex gradient
408
+ # spec (real) [B, 1, T, F, 2], O: df_order
409
+ # coefs (real) [B, T, O, F, 2]
410
+ # alpha (real) [B, T, 1]
411
+ padded = as_strided(
412
+ spec[..., : self.df_bins, :].squeeze(1), self.df_order, self.df_lookahead, dim=-3
413
+ )
414
+ spec_f = torch.sum(torch.view_as_complex(padded) * torch.view_as_complex(coefs), dim=2)
415
+ spec_f = torch.view_as_real(spec_f)
416
+ return assign_df(spec, spec_f.unsqueeze(1), self.df_bins, alpha)
417
+
418
+ def forward_real_no_pad_one_step(
419
+ self, spec: Tensor, coefs: Tensor, alpha: Optional[Tensor] = None
420
+ ) -> Tensor:
421
+ # Version4: Only viable for onnx handling. `spec` needs external (ring-)buffer handling.
422
+ # Thus, time steps `t` must be equal to `df_order`.
423
+
424
+ # spec (real) [B, 1, O, F', 2]
425
+ # coefs (real) [B, 1, O, F, 2]
426
+ assert (
427
+ spec.shape[2] == self.df_order
428
+ ), "This forward method needs a spectrogram buffer with `df_order` time steps as input"
429
+ assert coefs.shape[1] == 1, "This forward method is only valid for 1 time step"
430
+ sre, sim = spec[..., : self.df_bins, :].split(1, -1)
431
+ cre, cim = coefs.split(1, -1)
432
+ outr = torch.sum(sre * cre - sim * cim, dim=2).squeeze(-1)
433
+ outi = torch.sum(sre * cim + sim * cre, dim=2).squeeze(-1)
434
+ spec_f = torch.stack((outr, outi), dim=-1)
435
+ return assign_df(
436
+ spec[:, :, self.df_order - self.df_lookahead - 1],
437
+ spec_f.unsqueeze(1),
438
+ self.df_bins,
439
+ alpha,
440
+ )
441
+
442
+ def forward_real_hidden_state_loop(self, spec: Tensor, coefs: Tensor, alpha: Tensor) -> Tensor:
443
+ # Version5: Designed for onnx export. `spec` buffer handling is done via a torch buffer.
444
+
445
+ # spec (real) [B, 1, T, F', 2]
446
+ # coefs (real) [B, T, O, F, 2]
447
+ b, _, t, _, _ = spec.shape
448
+ spec_out = torch.empty((b, 1, t, self.freq_bins, 2), device=spec.device)
449
+ for t in range(spec.shape[2]):
450
+ self.spec_buf = self.spec_buf.roll(-1, dims=2)
451
+ self.spec_buf[:, :, -1] = spec[:, :, t]
452
+ sre, sim = self.spec_buf[..., : self.df_bins, :].split(1, -1)
453
+ cre, cim = coefs[:, t : t + 1].split(1, -1)
454
+ outr = torch.sum(sre * cre - sim * cim, dim=2).squeeze(-1)
455
+ outi = torch.sum(sre * cim + sim * cre, dim=2).squeeze(-1)
456
+ spec_f = torch.stack((outr, outi), dim=-1)
457
+ spec_out[:, :, t] = assign_df(
458
+ self.spec_buf[:, :, self.df_order - self.df_lookahead - 1].unsqueeze(2),
459
+ spec_f.unsqueeze(1),
460
+ self.df_bins,
461
+ alpha[:, t],
462
+ ).squeeze(2)
463
+ return spec_out
464
+
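+ # A minimal sketch (hypothetical helper, not part of the original module) showing that the
+ # real/imaginary expansion used in the real-valued variants above is exactly a complex
+ # multiply written with real tensors, which is what keeps those branches ONNX-exportable.
+ def _sketch_real_complex_mul():
+     a, b = torch.randn(4, 2), torch.randn(4, 2)  # trailing dim holds (re, im)
+     re = a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1]
+     im = a[..., 1] * b[..., 0] + a[..., 0] * b[..., 1]
+     ref = torch.view_as_complex(a) * torch.view_as_complex(b)
+     assert torch.allclose(torch.stack((re, im), -1), torch.view_as_real(ref))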
465
+
466
+ def assign_df(spec: Tensor, spec_f: Tensor, df_bins: int, alpha: Optional[Tensor]):
467
+ spec_out = spec.clone()
468
+ if alpha is not None:
469
+ b = spec.shape[0]
470
+ alpha = alpha.view(b, 1, -1, 1, 1)
471
+ spec_out[..., :df_bins, :] = spec_f * alpha + spec[..., :df_bins, :] * (1 - alpha)
472
+ else:
473
+ spec_out[..., :df_bins, :] = spec_f
474
+ return spec_out
475
+
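+ # A worked example (hypothetical numbers) of what `assign_df` does per bin: with alpha it
+ # cross-fades the deep-filtered value with the original one, without alpha it overwrites the
+ # first `df_bins` bins. E.g. alpha = 0.25, spec = 1.0, spec_f = 3.0:
+ #   out = 0.25 * 3.0 + (1 - 0.25) * 1.0 = 1.5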
476
+
477
+ def spec_pad(x: Tensor, window_size: int, lookahead: int, dim: int = 0) -> Tensor:
478
+ pad = [0] * x.dim() * 2
479
+ if dim >= 0:
480
+ pad[(x.dim() - dim - 1) * 2] = window_size - lookahead - 1
481
+ pad[(x.dim() - dim - 1) * 2 + 1] = lookahead
482
+ else:
483
+ pad[(-dim - 1) * 2] = window_size - lookahead - 1
484
+ pad[(-dim - 1) * 2 + 1] = lookahead
485
+ return F.pad(x, pad)
486
+
487
+
488
+ def as_strided(x: Tensor, window_size: int, lookahead: int, step: int = 1, dim: int = 0) -> Tensor:
489
+ shape = list(x.shape)
490
+ shape.insert(dim + 1, window_size)
491
+ x = spec_pad(x, window_size, lookahead, dim=dim)
492
+ # torch.fx workaround
493
+ step = 1
494
+ stride = [x.stride(0), x.stride(1), x.stride(2), x.stride(3)]
495
+ stride.insert(dim, stride[dim] * step)
496
+ return torch.as_strided(x, shape, stride)
497
+
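+ # A minimal sketch (hypothetical helper, not part of the original module) of the padding and
+ # striding above: `spec_pad` adds (window_size - lookahead - 1) past frames and `lookahead`
+ # future frames, so that T overlapping windows of length `window_size` can be exposed
+ # without copying.
+ def _sketch_windowing(T: int = 6, window_size: int = 3, lookahead: int = 1):
+     x = torch.arange(T, dtype=torch.float32)
+     padded = F.pad(x, (window_size - lookahead - 1, lookahead))
+     windows = padded.unfold(0, window_size, 1)  # [T, window_size], stride 1 over time
+     assert windows.shape == (T, window_size)
+     return windows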
498
+
499
+ class GroupedGRULayer(nn.Module):
500
+ input_size: Final[int]
501
+ hidden_size: Final[int]
502
+ out_size: Final[int]
503
+ bidirectional: Final[bool]
504
+ num_directions: Final[int]
505
+ groups: Final[int]
506
+ batch_first: Final[bool]
507
+
508
+ def __init__(
509
+ self,
510
+ input_size: int,
511
+ hidden_size: int,
512
+ groups: int,
513
+ batch_first: bool = True,
514
+ bias: bool = True,
515
+ dropout: float = 0,
516
+ bidirectional: bool = False,
517
+ ):
518
+ super().__init__()
519
+ assert input_size % groups == 0
520
+ assert hidden_size % groups == 0
521
+ kwargs = {
522
+ "bias": bias,
523
+ "batch_first": batch_first,
524
+ "dropout": dropout,
525
+ "bidirectional": bidirectional,
526
+ }
527
+ self.input_size = input_size // groups
528
+ self.hidden_size = hidden_size // groups
529
+ self.out_size = hidden_size
530
+ self.bidirectional = bidirectional
531
+ self.num_directions = 2 if bidirectional else 1
532
+ self.groups = groups
533
+ self.batch_first = batch_first
534
+ assert (self.hidden_size % groups) == 0, "Hidden size must be divisible by groups"
535
+ self.layers = nn.ModuleList(
536
+ (nn.GRU(self.input_size, self.hidden_size, **kwargs) for _ in range(groups))
537
+ )
538
+
539
+ def flatten_parameters(self):
540
+ for layer in self.layers:
541
+ layer.flatten_parameters()
542
+
543
+ def get_h0(self, batch_size: int = 1, device: torch.device = torch.device("cpu")):
544
+ return torch.zeros(
545
+ self.groups * self.num_directions,
546
+ batch_size,
547
+ self.hidden_size,
548
+ device=device,
549
+ )
550
+
551
+ def forward(self, input: Tensor, h0: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
552
+ # input shape: [B, T, I] if batch_first else [T, B, I], B: batch_size, I: input_size
553
+ # state shape: [G*D, B, H], where G: groups, D: num_directions, H: hidden_size
554
+
555
+ if h0 is None:
556
+ dim0, dim1 = input.shape[:2]
557
+ bs = dim0 if self.batch_first else dim1
558
+ h0 = self.get_h0(bs, device=input.device)
559
+ outputs: List[Tensor] = []
560
+ outstates: List[Tensor] = []
561
+ for i, layer in enumerate(self.layers):
562
+ o, s = layer(
563
+ input[..., i * self.input_size : (i + 1) * self.input_size],
564
+ h0[i * self.num_directions : (i + 1) * self.num_directions].detach(),
565
+ )
566
+ outputs.append(o)
567
+ outstates.append(s)
568
+ output = torch.cat(outputs, dim=-1)
569
+ h = torch.cat(outstates, dim=0)
570
+ return output, h
571
+
572
+
573
+ class GroupedGRU(nn.Module):
574
+ groups: Final[int]
575
+ num_layers: Final[int]
576
+ batch_first: Final[bool]
577
+ hidden_size: Final[int]
578
+ bidirectional: Final[bool]
579
+ num_directions: Final[int]
580
+ shuffle: Final[bool]
581
+ add_outputs: Final[bool]
582
+
583
+ def __init__(
584
+ self,
585
+ input_size: int,
586
+ hidden_size: int,
587
+ num_layers: int = 1,
588
+ groups: int = 4,
589
+ bias: bool = True,
590
+ batch_first: bool = True,
591
+ dropout: float = 0,
592
+ bidirectional: bool = False,
593
+ shuffle: bool = True,
594
+ add_outputs: bool = False,
595
+ ):
596
+ super().__init__()
597
+ kwargs = {
598
+ "groups": groups,
599
+ "bias": bias,
600
+ "batch_first": batch_first,
601
+ "dropout": dropout,
602
+ "bidirectional": bidirectional,
603
+ }
604
+ assert input_size % groups == 0
605
+ assert hidden_size % groups == 0
606
+ assert num_layers > 0
607
+ self.input_size = input_size
608
+ self.groups = groups
609
+ self.num_layers = num_layers
610
+ self.batch_first = batch_first
611
+ self.hidden_size = hidden_size // groups
612
+ self.bidirectional = bidirectional
613
+ self.num_directions = 2 if bidirectional else 1
614
+ if groups == 1:
615
+ shuffle = False # Fully connected, no need to shuffle
616
+ self.shuffle = shuffle
617
+ self.add_outputs = add_outputs
618
+ self.grus: List[GroupedGRULayer] = nn.ModuleList() # type: ignore
619
+ self.grus.append(GroupedGRULayer(input_size, hidden_size, **kwargs))
620
+ for _ in range(1, num_layers):
621
+ self.grus.append(GroupedGRULayer(hidden_size, hidden_size, **kwargs))
622
+ self.flatten_parameters()
623
+
624
+ def flatten_parameters(self):
625
+ for gru in self.grus:
626
+ gru.flatten_parameters()
627
+
628
+ def get_h0(self, batch_size: int, device: torch.device = torch.device("cpu")) -> Tensor:
629
+ return torch.zeros(
630
+ (self.num_layers * self.groups * self.num_directions, batch_size, self.hidden_size),
631
+ device=device,
632
+ )
633
+
634
+ def forward(self, input: Tensor, state: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
635
+ dim0, dim1, _ = input.shape
636
+ b = dim0 if self.batch_first else dim1
637
+ if state is None:
638
+ state = self.get_h0(b, input.device)
639
+ output = torch.zeros(
640
+ dim0, dim1, self.hidden_size * self.num_directions * self.groups, device=input.device
641
+ )
642
+ outstates = []
643
+ h = self.groups * self.num_directions
644
+ for i, gru in enumerate(self.grus):
645
+ input, s = gru(input, state[i * h : (i + 1) * h])
646
+ outstates.append(s)
647
+ if self.shuffle and i < self.num_layers - 1:
648
+ input = (
649
+ input.view(dim0, dim1, -1, self.groups).transpose(2, 3).reshape(dim0, dim1, -1)
650
+ )
651
+ if self.add_outputs:
652
+ output += input
653
+ else:
654
+ output = input
655
+ outstate = torch.cat(outstates, dim=0)
656
+ return output, outstate
657
+
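+ # A minimal sketch (hypothetical helper, not part of the original module) of the channel
+ # shuffle between grouped layers above: viewing the feature dim as [hidden_per_group, groups],
+ # transposing and flattening interleaves the groups, so each group of the next layer sees
+ # features coming from every group of the previous layer.
+ def _sketch_group_shuffle():
+     x = torch.tensor([[[0.0, 1.0, 2.0, 3.0]]])  # [B=1, T=1, groups=2 * hidden_per_group=2]
+     shuffled = x.view(1, 1, -1, 2).transpose(2, 3).reshape(1, 1, -1)
+     assert shuffled.flatten().tolist() == [0.0, 2.0, 1.0, 3.0]
+     return shuffled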
658
+
659
+ class SqueezedGRU(nn.Module):
660
+ input_size: Final[int]
661
+ hidden_size: Final[int]
662
+
663
+ def __init__(
664
+ self,
665
+ input_size: int,
666
+ hidden_size: int,
667
+ output_size: Optional[int] = None,
668
+ num_layers: int = 1,
669
+ linear_groups: int = 8,
670
+ batch_first: bool = True,
671
+ gru_skip_op: Optional[Callable[..., torch.nn.Module]] = None,
672
+ linear_act_layer: Callable[..., torch.nn.Module] = nn.Identity,
673
+ ):
674
+ super().__init__()
675
+ self.input_size = input_size
676
+ self.hidden_size = hidden_size
677
+ self.linear_in = nn.Sequential(
678
+ GroupedLinearEinsum(input_size, hidden_size, linear_groups), linear_act_layer()
679
+ )
680
+ self.gru = nn.GRU(hidden_size, hidden_size, num_layers=num_layers, batch_first=batch_first)
681
+ self.gru_skip = gru_skip_op() if gru_skip_op is not None else None
682
+ if output_size is not None:
683
+ self.linear_out = nn.Sequential(
684
+ GroupedLinearEinsum(hidden_size, output_size, linear_groups), linear_act_layer()
685
+ )
686
+ else:
687
+ self.linear_out = nn.Identity()
688
+
689
+ def forward(self, input: Tensor, h=None) -> Tuple[Tensor, Tensor]:
690
+ input = self.linear_in(input)
691
+ x, h = self.gru(input, h)
692
+ if self.gru_skip is not None:
693
+ x = x + self.gru_skip(input)
694
+ x = self.linear_out(x)
695
+ return x, h
696
+
697
+
698
+ class GroupedLinearEinsum(nn.Module):
699
+ input_size: Final[int]
700
+ hidden_size: Final[int]
701
+ groups: Final[int]
702
+
703
+ def __init__(self, input_size: int, hidden_size: int, groups: int = 1):
704
+ super().__init__()
705
+ # self.weight: Tensor
706
+ self.input_size = input_size
707
+ self.hidden_size = hidden_size
708
+ self.groups = groups
709
+ assert input_size % groups == 0
710
+ self.ws = input_size // groups
711
+ self.register_parameter(
712
+ "weight",
713
+ Parameter(
714
+ torch.zeros(groups, input_size // groups, hidden_size // groups), requires_grad=True
715
+ ),
716
+ )
717
+ self.reset_parameters()
718
+
719
+ def reset_parameters(self):
720
+ init.kaiming_uniform_(self.weight, a=math.sqrt(5)) # type: ignore
721
+
722
+ def forward(self, x: Tensor) -> Tensor:
723
+ # x: [..., I]
724
+ x = x.unflatten(-1, (self.groups, self.ws)) # [..., G, I/G]
725
+ x = torch.einsum("...gi,...gih->...gh", x, self.weight) # [..., G, H/G]
726
+ x = x.flatten(2, 3) # [B, T, H]
727
+ return x
728
+
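+ # A minimal sketch (hypothetical helper, not part of the original module) of the grouped
+ # einsum above: each group applies its own [I/G, H/G] weight, i.e. the layer is a
+ # block-diagonal linear map with G blocks.
+ def _sketch_grouped_linear_einsum():
+     B, T, I, H, G = 1, 2, 4, 6, 2
+     x = torch.randn(B, T, I)
+     w = torch.randn(G, I // G, H // G)
+     out = torch.einsum("btgi,gih->btgh", x.unflatten(-1, (G, I // G)), w).flatten(2, 3)
+     ref = torch.cat([x[..., g * (I // G) : (g + 1) * (I // G)] @ w[g] for g in range(G)], -1)
+     assert torch.allclose(out, ref)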
729
+
730
+ class GroupedLinear(nn.Module):
731
+ input_size: Final[int]
732
+ hidden_size: Final[int]
733
+ groups: Final[int]
734
+ shuffle: Final[bool]
735
+
736
+ def __init__(self, input_size: int, hidden_size: int, groups: int = 1, shuffle: bool = True):
737
+ super().__init__()
738
+ assert input_size % groups == 0
739
+ assert hidden_size % groups == 0
740
+ self.groups = groups
741
+ self.input_size = input_size // groups
742
+ self.hidden_size = hidden_size // groups
743
+ if groups == 1:
744
+ shuffle = False
745
+ self.shuffle = shuffle
746
+ self.layers = nn.ModuleList(
747
+ nn.Linear(self.input_size, self.hidden_size) for _ in range(groups)
748
+ )
749
+
750
+ def forward(self, x: Tensor) -> Tensor:
751
+ outputs: List[Tensor] = []
752
+ for i, layer in enumerate(self.layers):
753
+ outputs.append(layer(x[..., i * self.input_size : (i + 1) * self.input_size]))
754
+ output = torch.cat(outputs, dim=-1)
755
+ if self.shuffle:
756
+ orig_shape = output.shape
757
+ output = (
758
+ output.view(-1, self.hidden_size, self.groups).transpose(-1, -2).reshape(orig_shape)
759
+ )
760
+ return output
761
+
762
+
763
+ class LocalSnrTarget(nn.Module):
764
+ def __init__(
765
+ self, ws: int = 20, db: bool = True, ws_ns: Optional[int] = None, target_snr_range=None
766
+ ):
767
+ super().__init__()
768
+ self.ws = self.calc_ws(ws)
769
+ self.ws_ns = self.ws * 2 if ws_ns is None else self.calc_ws(ws_ns)
770
+ self.db = db
771
+ self.range = target_snr_range
772
+
773
+ def calc_ws(self, ws_ms: int) -> int:
774
+ # Calculates the window size in the STFT domain given a window size in ms
775
+ p = ModelParams()
776
+ ws = ws_ms - p.fft_size / p.sr * 1000 # length ms of an fft_window
777
+ ws = 1 + ws / (p.hop_size / p.sr * 1000) # consider hop_size
778
+ return max(int(round(ws)), 1)
779
+
780
+ def forward(self, clean: Tensor, noise: Tensor, max_bin: Optional[int] = None) -> Tensor:
781
+ # clean: [B, 1, T, F]
782
+ # out: [B, T']
783
+ if max_bin is not None:
784
+ clean = as_complex(clean[..., :max_bin])
785
+ noise = as_complex(noise[..., :max_bin])
786
+ return (
787
+ local_snr(clean, noise, window_size=self.ws, db=self.db, window_size_ns=self.ws_ns)[0]
788
+ .clamp(self.range[0], self.range[1])
789
+ .squeeze(1)
790
+ )
791
+
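+ # A worked example (with assumed parameters, not taken from this config) of `calc_ws` above:
+ # with sr=48000, fft_size=960 and hop_size=480, one FFT window spans 20 ms and one hop 10 ms.
+ #   ws_ms=30:  30 - 20 = 10 ms  ->  1 + 10 / 10 = 2   ->  2 STFT frames
+ #   ws_ms=20:  20 - 20 =  0 ms  ->  1 +  0 / 10 = 1   ->  clamped to at least 1 frame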
792
+
793
+ def _local_energy(x: Tensor, ws: int, device: torch.device) -> Tensor:
794
+ if (ws % 2) == 0:
795
+ ws += 1
796
+ ws_half = ws // 2
797
+ x = F.pad(x.pow(2).sum(-1).sum(-1), (ws_half, ws_half, 0, 0))
798
+ w = torch.hann_window(ws, device=device, dtype=x.dtype)
799
+ x = x.unfold(-1, size=ws, step=1) * w
800
+ return torch.sum(x, dim=-1).div(ws)
801
+
802
+
803
+ def local_snr(
804
+ clean: Tensor,
805
+ noise: Tensor,
806
+ window_size: int,
807
+ db: bool = False,
808
+ window_size_ns: Optional[int] = None,
809
+ eps: float = 1e-12,
810
+ ) -> Tuple[Tensor, Tensor, Tensor]:
811
+ # clean shape: [B, C, T, F]
812
+ clean = as_real(clean)
813
+ noise = as_real(noise)
814
+ assert clean.dim() == 5
815
+
816
+ E_speech = _local_energy(clean, window_size, clean.device)
817
+ window_size_ns = window_size if window_size_ns is None else window_size_ns
818
+ E_noise = _local_energy(noise, window_size_ns, clean.device)
819
+
820
+ snr = E_speech / E_noise.clamp_min(eps)
821
+ if db:
822
+ snr = snr.clamp_min(eps).log10().mul(10)
823
+ return snr, E_speech, E_noise
824
+
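+ # A worked example (hypothetical numbers) for the dB conversion above: if the windowed speech
+ # energy is 0.4 and the windowed noise energy is 0.004, then snr = 0.4 / 0.004 = 100 and
+ # 10 * log10(100) = 20 dB local SNR for that frame.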
825
+
826
+ def test_grouped_gru():
827
+ from icecream import ic
828
+
829
+ g = 2 # groups
830
+ h = 4 # hidden_size
831
+ i = 2 # input_size
832
+ b = 1 # batch_size
833
+ t = 5 # time_steps
834
+ m = GroupedGRULayer(i, h, g, batch_first=True)
835
+ ic(m)
836
+ input = torch.randn((b, t, i))
837
+ h0 = m.get_h0(b)
838
+ assert list(h0.shape) == [g, b, h // g]
839
+ out, hout = m(input, h0)
840
+
841
+ # Should be exportable as raw nn.Module
842
+ torch.onnx.export(
843
+ m, (input, h0), "out/grouped.onnx", example_outputs=(out, hout), opset_version=13
844
+ )
845
+ # Should be exportable as traced
846
+ m = torch.jit.trace(m, (input, h0))
847
+ torch.onnx.export(
848
+ m, (input, h0), "out/grouped.onnx", example_outputs=(out, hout), opset_version=13
849
+ )
850
+ # and as scripted module
851
+ m = torch.jit.script(m)
852
+ torch.onnx.export(
853
+ m, (input, h0), "out/grouped.onnx", example_outputs=(out, hout), opset_version=13
854
+ )
855
+
856
+ # now grouped gru
857
+ num = 2
858
+ m = GroupedGRU(i, h, num, g, batch_first=True, shuffle=True)
859
+ ic(m)
860
+ h0 = m.get_h0(b)
861
+ assert list(h0.shape) == [num * g, b, h // g]
862
+ out, hout = m(input, h0)
863
+
864
+ # Should be exportable as traced
865
+ m = torch.jit.trace(m, (input, h0))
866
+ torch.onnx.export(
867
+ m, (input, h0), "out/grouped.onnx", example_outputs=(out, hout), opset_version=13
868
+ )
869
+ # and scripted module
870
+ m = torch.jit.script(m)
871
+ torch.onnx.export(
872
+ m, (input, h0), "out/grouped.onnx", example_outputs=(out, hout), opset_version=13
873
+ )
874
+
875
+
876
+ def test_erb():
877
+ import libdf
878
+ from df.config import config
879
+
880
+ config.use_defaults()
881
+ p = ModelParams()
882
+ n_freq = p.fft_size // 2 + 1
883
+ df_state = libdf.DF(sr=p.sr, fft_size=p.fft_size, hop_size=p.hop_size, nb_bands=p.nb_erb)
884
+ erb = erb_fb(df_state.erb_widths(), p.sr)
885
+ erb_inverse = erb_fb(df_state.erb_widths(), p.sr, inverse=True)
886
+ input = torch.randn((1, 1, 1, n_freq), dtype=torch.complex64)
887
+ input_abs = input.abs().square()
888
+ erb_widths = df_state.erb_widths()
889
+ df_erb = torch.from_numpy(libdf.erb(input.numpy(), erb_widths, False))
890
+ py_erb = torch.matmul(input_abs, erb)
891
+ assert torch.allclose(df_erb, py_erb)
892
+ df_out = torch.from_numpy(libdf.erb_inv(df_erb.numpy(), erb_widths))
893
+ py_out = torch.matmul(py_erb, erb_inverse)
894
+ assert torch.allclose(df_out, py_out)
895
+
896
+
897
+ def test_unit_norm():
898
+ from df.config import config
899
+ from libdf import unit_norm
900
+
901
+ config.use_defaults()
902
+ p = ModelParams()
903
+ b = 2
904
+ F = p.nb_df
905
+ t = 100
906
+ spec = torch.randn(b, 1, t, F, 2)
907
+ alpha = get_norm_alpha(log=False)
908
+ # Expects complex input of shape [C, T, F]
909
+ norm_lib = torch.as_tensor(unit_norm(torch.view_as_complex(spec).squeeze(1).numpy(), alpha))
910
+ m = ExponentialUnitNorm(alpha, F)
911
+ norm_torch = torch.view_as_complex(m(spec).squeeze(1))
912
+ assert torch.allclose(norm_lib.real, norm_torch.real)
913
+ assert torch.allclose(norm_lib.imag, norm_torch.imag)
914
+ assert torch.allclose(norm_lib.abs(), norm_torch.abs())
915
+
916
+
917
+ def test_dfop():
918
+ from df.config import config
919
+
920
+ config.use_defaults()
921
+ p = ModelParams()
922
+ f = p.nb_df
923
+ F = f * 2
924
+ o = p.df_order
925
+ d = p.df_lookahead
926
+ t = 100
927
+ spec = torch.randn(1, 1, t, F, 2)
928
+ coefs = torch.randn(1, t, o, f, 2)
929
+ alpha = torch.randn(1, t, 1)
930
+ dfop = DfOp(df_bins=p.nb_df)
931
+ dfop.set_forward("real_loop")
932
+ out1 = dfop(spec, coefs, alpha)
933
+ dfop.set_forward("real_strided")
934
+ out2 = dfop(spec, coefs, alpha)
935
+ dfop.set_forward("real_unfold")
936
+ out3 = dfop(spec, coefs, alpha)
937
+ dfop.set_forward("complex_strided")
938
+ out4 = dfop(spec, coefs, alpha)
939
+ torch.testing.assert_allclose(out1, out2)
940
+ torch.testing.assert_allclose(out1, out3)
941
+ torch.testing.assert_allclose(out1, out4)
942
+ # This forward method requires external padding/lookahead as well as spectrogram buffer
943
+ # handling, i.e. via a ring buffer. Could be used in real time usage.
944
+ dfop.set_forward("real_one_step")
945
+ spec_padded = spec_pad(spec, o, d, dim=-3)
946
+ out5 = torch.zeros_like(out1)
947
+ for i in range(t):
948
+ out5[:, :, i] = dfop(
949
+ spec_padded[:, :, i : i + o], coefs[:, i].unsqueeze(1), alpha[:, i].unsqueeze(1)
950
+ )
951
+ torch.testing.assert_allclose(out1, out5)
952
+ # Forward method that does the padding/lookahead handling using an internal hidden state.
953
+ dfop.freq_bins = F
954
+ dfop.set_forward("real_hidden_state_loop")
955
+ out6 = dfop(spec, coefs, alpha)
956
+ torch.testing.assert_allclose(out1, out6)
df/multiframe.py ADDED
@@ -0,0 +1,329 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Dict, Final
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import Tensor, nn
7
+
8
+
9
+ class MultiFrameModule(nn.Module, ABC):
10
+ """Multi-frame speech enhancement modules.
11
+
12
+ Signal model and notation:
13
+ Noisy: `x = s + n`
14
+ Enhanced: `y = f(x)`
15
+ Objective: `min ||s - y||`
16
+
17
+ PSD: Power spectral density, notated e.g. as `Rxx` for the noisy PSD.
18
+ IFC: Inter-frame correlation vector: PSD*u, u: selection vector. Notated as `rxx`
19
+ """
20
+
21
+ num_freqs: Final[int]
22
+ frame_size: Final[int]
23
+ need_unfold: Final[bool]
24
+
25
+ def __init__(self, num_freqs: int, frame_size: int, lookahead: int = 0):
26
+ """Multi-Frame filtering module.
27
+
28
+ Args:
29
+ num_freqs (int): Number of frequency bins used for filtering.
30
+ frame_size (int): Frame size in FD domain.
31
+ lookahead (int): Lookahead, may be used to select the output time step. Note: This
32
+ module does not add additional padding according to lookahead!
33
+ """
34
+ super().__init__()
35
+ self.num_freqs = num_freqs
36
+ self.frame_size = frame_size
37
+ self.pad = nn.ConstantPad2d((0, 0, frame_size - 1, 0), 0.0)
38
+ self.need_unfold = frame_size > 1
39
+ self.lookahead = lookahead
40
+
41
+ def spec_unfold(self, spec: Tensor):
42
+ """Pads and unfolds the spectrogram according to frame_size.
43
+
44
+ Args:
45
+ spec (complex Tensor): Spectrogram of shape [B, C, T, F]
46
+ Returns:
47
+ spec (Tensor): Unfolded spectrogram of shape [B, C, T, F, N], where N: frame_size.
48
+ """
49
+ if self.need_unfold:
50
+ return self.pad(spec).unfold(2, self.frame_size, 1)
51
+ return spec.unsqueeze(-1)
52
+
53
+ def forward(self, spec: Tensor, coefs: Tensor):
54
+ """Pads and unfolds the spectrogram and forwards to impl.
55
+
56
+ Args:
57
+ spec (Tensor): Spectrogram of shape [B, C, T, F, 2]
58
+ coefs (Tensor): Coefficients of shape [B, C, T, F, 2]
59
+ """
60
+ spec_u = self.spec_unfold(torch.view_as_complex(spec))
61
+ coefs = torch.view_as_complex(coefs)
62
+ spec_f = spec_u.narrow(-2, 0, self.num_freqs)
63
+ spec_f = self.forward_impl(spec_f, coefs)
64
+ if self.training:
65
+ spec = spec.clone()
66
+ spec[..., : self.num_freqs, :] = torch.view_as_real(spec_f)
67
+ return spec
68
+
69
+ @abstractmethod
70
+ def forward_impl(self, spec: Tensor, coefs: Tensor) -> Tensor:
71
+ """Forward impl taking complex spectrogram and coefficients.
72
+
73
+ Args:
74
+ spec (complex Tensor): Spectrogram of shape [B, C1, T, F, N]
75
+ coefs (complex Tensor): Coefficients [B, C2, T, F]
76
+
77
+ Returns:
78
+ spec (complex Tensor): Enhanced spectrogram of shape [B, C1, T, F]
79
+ """
80
+ ...
81
+
82
+ @abstractmethod
83
+ def num_channels(self) -> int:
84
+ """Return the number of required channels.
85
+
86
+ If multiple inputs are required, then all these should be combined in one Tensor containing
87
+ the summed channels.
88
+ """
89
+ ...
90
+
91
+
92
+ def psd(x: Tensor, n: int) -> Tensor:
93
+ """Compute the PSD correlation matrix Rxx for a spectrogram.
94
+
95
+ That is, `X*conj(X)`, where `*` is the outer product.
96
+
97
+ Args:
98
+ x (complex Tensor): Spectrogram of shape [B, C, T, F]. Will be unfolded with `n` steps over
99
+ the time axis.
100
+
101
+ Returns:
102
+ Rxx (complex Tensor): Correlation matrix of shape [B, C, T, F, N, N]
103
+ """
104
+ x = F.pad(x, (0, 0, n - 1, 0)).unfold(-2, n, 1)
105
+ return torch.einsum("...n,...m->...mn", x, x.conj())
106
+
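+ # A minimal sketch (hypothetical helper, not part of the original module) of `psd` above: for
+ # each time/frequency bin the unfolded vector of the last N frames forms the rank-1 outer
+ # product x * x^H, giving a Hermitian [N, N] block per bin.
+ def _sketch_psd_shapes():
+     B, C, T, F_, N = 1, 1, 10, 4, 3  # F_ to avoid clashing with torch.nn.functional as F
+     x = torch.randn(B, C, T, F_, dtype=torch.complex64)
+     Rxx = psd(x, N)
+     assert Rxx.shape == (B, C, T, F_, N, N)
+     assert torch.allclose(Rxx, Rxx.transpose(-1, -2).conj())  # Hermitian by construction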
107
+
108
+ def df(spec: Tensor, coefs: Tensor) -> Tensor:
109
+ """Deep filter implementation using `torch.einsum`. Requires unfolded spectrogram.
110
+
111
+ Args:
112
+ spec (complex Tensor): Spectrogram of shape [B, C, T, F, N]
113
+ coefs (complex Tensor): Coefficients of shape [B, C, N, T, F]
114
+
115
+ Returns:
116
+ spec (complex Tensor): Spectrogram of shape [B, C, T, F]
117
+ """
118
+ return torch.einsum("...tfn,...ntf->...tf", spec, coefs)
119
+
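+ # A minimal sketch (hypothetical helper, not part of the original module) of `df` above: the
+ # einsum is the filter sum y[t, f] = sum_n spec[t, f, n] * coefs[n, t, f] per batch/channel.
+ def _sketch_df_equivalence():
+     B, C, T, F_, N = 1, 1, 5, 4, 3
+     spec = torch.randn(B, C, T, F_, N, dtype=torch.complex64)
+     coefs = torch.randn(B, C, N, T, F_, dtype=torch.complex64)
+     ref = sum(spec[..., n] * coefs[:, :, n] for n in range(N))
+     assert torch.allclose(df(spec, coefs), ref)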
120
+
121
+ class CRM(MultiFrameModule):
122
+ """Complex ratio mask."""
123
+
124
+ def __init__(self, num_freqs: int, frame_size: int = 1, lookahead: int = 0):
125
+ assert frame_size == 1 and lookahead == 0, (frame_size, lookahead)
126
+ super().__init__(num_freqs, 1)
127
+
128
+ def forward_impl(self, spec: Tensor, coefs: Tensor):
129
+ return spec.squeeze(-1).mul(coefs)
130
+
131
+ def num_channels(self):
132
+ return 2
133
+
134
+
135
+ class DF(MultiFrameModule):
136
+ """Deep Filtering."""
137
+ conj: Final[bool]
138
+
139
+ def __init__(self, num_freqs: int, frame_size: int, lookahead: int = 0, conj: bool = False):
140
+ super().__init__(num_freqs, frame_size, lookahead)
141
+ self.conj = conj
142
+
143
+ def forward_impl(self, spec: Tensor, coefs: Tensor):
144
+ coefs = coefs.view(coefs.shape[0], -1, self.frame_size, *coefs.shape[2:])
145
+ if self.conj:
146
+ coefs = coefs.conj()
147
+ return df(spec, coefs)
148
+
149
+ def num_channels(self):
150
+ return self.frame_size * 2
151
+
152
+
153
+ class MfWf(MultiFrameModule):
154
+ """Multi-frame Wiener filter base module."""
155
+
156
+ def __init__(self, num_freqs: int, frame_size: int, lookahead: int = 0):
157
+ """Multi-frame Wiener Filter.
158
+
159
+ Several implementation methods are available, resulting in different numbers of required input
160
+ coefficient channels.
161
+
162
+ Methods:
163
+ psd_ifc: Predict PSD `Rxx` and IFC `rss`.
164
+ df: Use deep filtering to predict speech and noisy spectrograms. These will be used for
165
+ PSD calculation for Wiener filtering. Alias: `df_sx`
166
+ c: Directly predict Wiener filter coefficients. Computation same as deep filtering.
167
+
168
+ """
169
+ super().__init__(num_freqs, frame_size, lookahead=0)
170
+ self.idx = -lookahead
171
+
172
+ def num_channels(self):
173
+ return self.num_channels
174
+
175
+ @staticmethod
176
+ def solve(Rxx, rss, diag_eps: float = 1e-8, eps: float = 1e-7) -> Tensor:
177
+ return torch.einsum(
178
+ "...nm,...m->...n", torch.inverse(_tik_reg(Rxx, diag_eps, eps)), rss
179
+ ) # [T, F, N]
180
+
181
+ @abstractmethod
182
+ def mfwf(self, spec: Tensor, coefs: Tensor) -> Tensor:
183
+ """Multi-frame Wiener filter impl taking complex spectrogram and coefficients.
184
+
185
+ Coefficients may be split into multiple parts, e.g. for multiple DF coefs or PSDs.
186
+
187
+ Args:
188
+ spec (complex Tensor): Spectrogram of shape [B, C1, T, F, N]
189
+ coefs (complex Tensor): Coefficients [B, C2, T, F]
190
+
191
+ Returns:
192
+ c (complex Tensor): MfWf coefs of shape [B, C1, T, F, N]
193
+ """
194
+ ...
195
+
196
+ def forward_impl(self, spec: Tensor, coefs: Tensor) -> Tensor:
197
+ coefs = self.mfwf(spec, coefs)
198
+ return self.apply_coefs(spec, coefs)
199
+
200
+ @staticmethod
201
+ def apply_coefs(spec: Tensor, coefs: Tensor) -> Tensor:
202
+ # spec: [B, C, T, F, N]
203
+ # coefs: [B, C, T, F, N]
204
+ return torch.einsum("...n,...n->...", spec, coefs)
205
+
206
+
207
+ class MfWfDf(MfWf):
208
+ eps_diag: Final[float]
209
+
210
+ def __init__(
211
+ self,
212
+ num_freqs: int,
213
+ frame_size: int,
214
+ lookahead: int = 0,
215
+ eps_diag: float = 1e-7,
216
+ eps: float = 1e-7,
217
+ ):
218
+ super().__init__(num_freqs, frame_size, lookahead)
219
+ self.eps_diag = eps_diag
220
+ self.eps = eps
221
+
222
+ def num_channels(self):
223
+ # frame_size/df_order * 2 (x/s) * 2 (re/im)
224
+ return self.frame_size * 4
225
+
226
+ def mfwf(self, spec: Tensor, coefs: Tensor) -> Tensor:
227
228
+ df_s, df_x = torch.chunk(coefs, 2, 1) # [B, C, T, F, N]
229
+ df_s = df_s.unflatten(1, (-1, self.frame_size))
230
+ df_x = df_x.unflatten(1, (-1, self.frame_size))
231
+ spec_s = df(spec, df_s) # [B, C, T, F]
232
+ spec_x = df(spec, df_x)
233
+ Rss = psd(spec_s, self.frame_size) # [B, C, T, F, N. N]
234
+ Rxx = psd(spec_x, self.frame_size)
235
+ rss = Rss[..., -1] # TODO: use -1 or self.idx?
236
+ c = self.solve(Rxx, rss, self.eps_diag, self.eps) # [B, C, T, F, N]
237
+ return c
238
+
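+ # A minimal sketch (hypothetical helper, not part of the original module) of the Wiener solve
+ # used above: given a noisy PSD Rxx and a speech IFC rss, the filter is
+ # c = (Rxx + reg * I)^-1 rss, computed here via the `MfWf.solve` helper defined above.
+ def _sketch_mfwf_solve():
+     N = 3
+     a = torch.randn(N, N, dtype=torch.complex64)
+     Rxx = a @ a.conj().transpose(-1, -2) + torch.eye(N, dtype=torch.complex64)  # well conditioned
+     rss = torch.randn(N, dtype=torch.complex64)
+     c = MfWf.solve(Rxx, rss)
+     # Up to the small Tikhonov term, c solves Rxx @ c = rss:
+     assert torch.allclose(Rxx @ c, rss, atol=1e-3)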
239
+
240
+ class MfWfPsd(MfWf):
241
+ """Multi-frame Wiener filter by predicting noisy PSD `Rxx` and speech IFC `rss`."""
242
+
243
+ def num_channels(self):
244
+ # (Rxx + rss) * 2 (re/im)
245
+ return (self.frame_size**2 + self.frame_size) * 2
246
+
247
+ def mfwf(self, spec: Tensor, coefs: Tensor) -> Tensor: # type: ignore
248
+ Rxx, rss = torch.split(coefs.movedim(1, -1), [self.frame_size**2, self.frame_size], -1)
249
+ c = self.solve(Rxx.unflatten(-1, (self.frame_size, self.frame_size)), rss)
250
+ return c
251
+
252
+
253
+ class MfWfC(MfWf):
254
+ """Multi-frame Wiener filter by directly predicting the MfWf coefficients."""
255
+
256
+ def num_channels(self):
257
+ # mfwf coefs * 2 (re/im)
258
+ return self.frame_size * 2
259
+
260
+ def mfwf(self, spec: Tensor, coefs: Tensor) -> Tensor: # type: ignore
261
+ coefs = coefs.unflatten(1, (-1, self.frame_size)).permute(
262
+ 0, 1, 3, 4, 2
263
+ ) # [B, C*N, T, F] -> [B, C, T, F, N]
264
+ return coefs
265
+
266
+
267
+ class MvdrSouden(MultiFrameModule):
268
+ def __init__(self, num_freqs: int, frame_size: int, lookahead: int = 0):
269
+ super().__init__(num_freqs, frame_size, lookahead)
270
+
271
+
272
+ class MvdrEvd(MultiFrameModule):
273
+ def __init__(self, num_freqs: int, frame_size: int, lookahead: int = 0):
274
+ super().__init__(num_freqs, frame_size, lookahead)
275
+
276
+
277
+ class MvdrRtfPower(MultiFrameModule):
278
+ def __init__(self, num_freqs: int, frame_size: int, lookahead: int = 0):
279
+ super().__init__(num_freqs, frame_size, lookahead)
280
+
281
+
282
+ MF_METHODS: Dict[str, MultiFrameModule] = {
283
+ "crm": CRM,
284
+ "df": DF,
285
+ "mfwf_df": MfWfDf,
286
+ "mfwf_df_sx": MfWfDf,
287
+ "mfwf_psd": MfWfPsd,
288
+ "mfwf_psd_ifc": MfWfPsd,
289
+ "mfwf_c": MfWfC,
290
+ }
291
+
292
+
293
+ # From torchaudio
294
+ def _compute_mat_trace(input: torch.Tensor, dim1: int = -1, dim2: int = -2) -> torch.Tensor:
295
+ r"""Compute the trace of a Tensor along ``dim1`` and ``dim2`` dimensions.
296
+ Args:
297
+ input (torch.Tensor): Tensor of dimension `(..., channel, channel)`
298
+ dim1 (int, optional): the first dimension of the diagonal matrix
299
+ (Default: -1)
300
+ dim2 (int, optional): the second dimension of the diagonal matrix
301
+ (Default: -2)
302
+ Returns:
303
+ Tensor: trace of the input Tensor
304
+ """
305
+ assert input.ndim >= 2, "The dimension of the tensor must be at least 2."
306
+ assert (
307
+ input.shape[dim1] == input.shape[dim2]
308
+ ), "The size of ``dim1`` and ``dim2`` must be the same."
309
+ input = torch.diagonal(input, 0, dim1=dim1, dim2=dim2)
310
+ return input.sum(dim=-1)
311
+
312
+
313
+ def _tik_reg(mat: torch.Tensor, reg: float = 1e-7, eps: float = 1e-8) -> torch.Tensor:
314
+ """Perform Tikhonov regularization (only modifying real part).
315
+ Args:
316
+ mat (torch.Tensor): input matrix (..., channel, channel)
317
+ reg (float, optional): regularization factor (Default: ``1e-7``)
318
+ eps (float, optional): a small value added in case the correlation matrix is all-zero (Default: ``1e-8``)
319
+ Returns:
320
+ Tensor: regularized matrix (..., channel, channel)
321
+ """
322
+ # Add eps
323
+ C = mat.size(-1)
324
+ eye = torch.eye(C, dtype=mat.dtype, device=mat.device)
325
+ epsilon = _compute_mat_trace(mat).real[..., None, None] * reg
326
+ # in case that correlation_matrix is all-zero
327
+ epsilon = epsilon + eps
328
+ mat = mat + epsilon * eye[..., :, :]
329
+ return mat
df/utils.py ADDED
@@ -0,0 +1,230 @@
1
+ import collections.abc
2
+ import math
3
+ import os
4
+ import random
5
+ import subprocess
6
+ from socket import gethostname
7
+ from typing import Any, Dict, Optional, Set, Tuple, Union
8
+
9
+ import numpy as np
10
+ import torch
11
+ from loguru import logger
12
+ from torch import Tensor
13
+ #from torch._six import string_classes
14
+ from torch.autograd import Function
15
+ from torch.types import Number
16
+
17
+ from df.config import config
18
+ from df.model import ModelParams
19
+
20
+ try:
21
+ from torchaudio.functional import resample as ta_resample
22
+ except ImportError:
23
+ from torchaudio.compliance.kaldi import resample_waveform as ta_resample # type: ignore
24
+
25
+
26
+ def get_resample_params(method: str) -> Dict[str, Any]:
27
+ params = {
28
+ "sinc_fast": {"resampling_method": "sinc_interpolation", "lowpass_filter_width": 16},
29
+ "sinc_best": {"resampling_method": "sinc_interpolation", "lowpass_filter_width": 64},
30
+ "kaiser_fast": {
31
+ "resampling_method": "kaiser_window",
32
+ "lowpass_filter_width": 16,
33
+ "rolloff": 0.85,
34
+ "beta": 8.555504641634386,
35
+ },
36
+ "kaiser_best": {
37
+ "resampling_method": "kaiser_window",
38
+ "lowpass_filter_width": 16,
39
+ "rolloff": 0.9475937167399596,
40
+ "beta": 14.769656459379492,
41
+ },
42
+ }
43
+ assert method in params.keys(), f"method must be one of {list(params.keys())}"
44
+ return params[method]
45
+
46
+
47
+ def resample(audio: Tensor, orig_sr: int, new_sr: int, method="sinc_fast"):
48
+ params = get_resample_params(method)
49
+ return ta_resample(audio, orig_sr, new_sr, **params)
50
+
51
+
52
+ def get_device():
53
+ s = config("DEVICE", default="", section="train")
54
+ if s == "":
55
+ if torch.cuda.is_available():
56
+ DEVICE = torch.device("cuda:0")
57
+ else:
58
+ DEVICE = torch.device("cpu")
59
+ else:
60
+ DEVICE = torch.device(s)
61
+ return DEVICE
62
+
63
+
64
+ def as_complex(x: Tensor):
65
+ if torch.is_complex(x):
66
+ return x
67
+ if x.shape[-1] != 2:
68
+ raise ValueError(f"Last dimension need to be of length 2 (re + im), but got {x.shape}")
69
+ if x.stride(-1) != 1:
70
+ x = x.contiguous()
71
+ return torch.view_as_complex(x)
72
+
73
+
74
+ def as_real(x: Tensor):
75
+ if torch.is_complex(x):
76
+ return torch.view_as_real(x)
77
+ return x
78
+
79
+
80
+ class angle_re_im(Function):
81
+ """Similar to torch.angle but robustifies the gradient for zero magnitude."""
82
+
83
+ @staticmethod
84
+ def forward(ctx, re: Tensor, im: Tensor):
85
+ ctx.save_for_backward(re, im)
86
+ return torch.atan2(im, re)
87
+
88
+ @staticmethod
89
+ def backward(ctx, grad: Tensor) -> Tuple[Tensor, Tensor]:
90
+ re, im = ctx.saved_tensors
91
+ grad_inv = grad / (re.square() + im.square()).clamp_min_(1e-10)
92
+ return -im * grad_inv, re * grad_inv
93
+
94
+
95
+ class angle(Function):
96
+ """Similar to torch.angle but robustifies the gradient for zero magnitude."""
97
+
98
+ @staticmethod
99
+ def forward(ctx, x: Tensor):
100
+ ctx.save_for_backward(x)
101
+ return torch.atan2(x.imag, x.real)
102
+
103
+ @staticmethod
104
+ def backward(ctx, grad: Tensor):
105
+ (x,) = ctx.saved_tensors
106
+ grad_inv = grad / (x.real.square() + x.imag.square()).clamp_min_(1e-10)
107
+ return torch.view_as_complex(torch.stack((-x.imag * grad_inv, x.real * grad_inv), dim=-1))
108
+
109
+
110
+ def check_finite_module(obj, name="Module", _raise=True) -> Set[str]:
111
+ out: Set[str] = set()
112
+ if isinstance(obj, torch.nn.Module):
113
+ for name, child in obj.named_children():
114
+ out = out | check_finite_module(child, name)
115
+ for name, param in obj.named_parameters():
116
+ out = out | check_finite_module(param, name)
117
+ for name, buf in obj.named_buffers():
118
+ out = out | check_finite_module(buf, name)
119
+ if _raise and len(out) > 0:
120
+ raise ValueError(f"{name} not finite during checkpoint writing including: {out}")
121
+ return out
122
+
123
+
124
+ def make_np(x: Union[Tensor, np.ndarray, Number]) -> np.ndarray:
125
+ """Transforms Tensor to numpy.
126
+ Args:
127
+ x: A torch Tensor, numpy array, or scalar
128
+
129
+ Returns:
130
+ numpy.array: Numpy array
131
+ """
132
+ if isinstance(x, np.ndarray):
133
+ return x
134
+ if np.isscalar(x):
135
+ return np.array([x])
136
+ if isinstance(x, Tensor):
137
+ return x.detach().cpu().numpy()
138
+ raise NotImplementedError(
139
+ "Got {}, but numpy array, scalar, or torch tensor are expected.".format(type(x))
140
+ )
141
+
142
+
143
+ def get_norm_alpha(log: bool = True) -> float:
144
+ p = ModelParams()
145
+ a_ = _calculate_norm_alpha(sr=p.sr, hop_size=p.hop_size, tau=p.norm_tau)
146
+ precision = 3
147
+ a = 1.0
148
+ while a >= 1.0:
149
+ a = round(a_, precision)
150
+ precision += 1
151
+ if log:
152
+ logger.info(f"Running with normalization window alpha = '{a}'")
153
+ return a
154
+
155
+
156
+ def _calculate_norm_alpha(sr: int, hop_size: int, tau: float):
157
+ """Exponential decay factor alpha for a given tau (decay window size [s])."""
158
+ dt = hop_size / sr
159
+ return math.exp(-dt / tau)
160
+
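+ # A worked example (assumed parameters, not taken from this config) of the formula above:
+ # with sr=48000, hop_size=480 and tau=1 s, dt = 480 / 48000 = 0.01 s and
+ # alpha = exp(-0.01 / 1) ~= 0.99005, i.e. roughly a one-second exponential averaging window.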
161
+
162
+ def check_manual_seed(seed: Optional[int] = None):
163
+ """If manual seed is not specified, choose a random one and communicate it to the user."""
164
+ seed = seed or random.randint(1, 10000)
165
+ np.random.seed(seed)
166
+ random.seed(seed)
167
+ torch.manual_seed(seed)
168
+ return seed
169
+
170
+
171
+ def get_git_root():
172
+ git_local_dir = os.path.dirname(os.path.abspath(__file__))
173
+ args = ["git", "-C", git_local_dir, "rev-parse", "--show-toplevel"]
174
+ return subprocess.check_output(args).strip().decode()
175
+
176
+
177
+ def get_commit_hash():
178
+ """Returns the current git commit."""
179
+ try:
180
+ git_dir = get_git_root()
181
+ args = ["git", "-C", git_dir, "rev-parse", "--short", "--verify", "HEAD"]
182
+ commit = subprocess.check_output(args).strip().decode()
183
+ except subprocess.CalledProcessError:
184
+ # probably not in git repo
185
+ commit = None
186
+ return commit
187
+
188
+
189
+ def get_host() -> str:
190
+ return gethostname()
191
+
192
+
193
+ def get_branch_name():
194
+ try:
195
+ git_dir = os.path.dirname(os.path.abspath(__file__))
196
+ args = ["git", "-C", git_dir, "rev-parse", "--abbrev-ref", "HEAD"]
197
+ branch = subprocess.check_output(args).strip().decode()
198
+ except subprocess.CalledProcessError:
199
+ # probably not in git repo
200
+ branch = None
201
+ return branch
202
+
203
+
204
+ # from pytorch/ignite:
205
+ def apply_to_tensor(input_, func):
206
+ """Apply a function to a tensor, or to a mapping or sequence of tensors."""
207
+ if isinstance(input_, torch.nn.Module):
208
+ return [apply_to_tensor(c, func) for c in input_.children()]
209
+ elif isinstance(input_, torch.nn.Parameter):
210
+ return func(input_.data)
211
+ elif isinstance(input_, Tensor):
212
+ return func(input_)
213
+ elif isinstance(input_, str):
214
+ return input_
215
+ elif isinstance(input_, collections.abc.Mapping):
216
+ return {k: apply_to_tensor(sample, func) for k, sample in input_.items()}
217
+ elif isinstance(input_, collections.abc.Iterable):
218
+ return [apply_to_tensor(sample, func) for sample in input_]
219
+ elif input_ is None:
220
+ return input_
221
+ else:
222
+ return input_
223
+
224
+
225
+ def detach_hidden(hidden: Any) -> Any:
226
+ """Cut backpropagation graph.
227
+ Auxiliary function to cut the backpropagation graph by detaching the hidden
228
+ vector.
229
+ """
230
+ return apply_to_tensor(hidden, Tensor.detach)
libdf/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .libdf import *
2
+
3
+ __doc__ = libdf.__doc__
libdf/__init__.pyi ADDED
@@ -0,0 +1,57 @@
1
+ from typing import List, Optional, Union
2
+
3
+ from numpy import ndarray
4
+
5
+ class DF:
6
+ def __init__(
7
+ self,
8
+ sr: int,
9
+ fft_size: int,
10
+ hop_size: int,
11
+ nb_bands: int,
12
+ min_nb_erb_freqs: Optional[int] = 1,
13
+ ):
14
+ """DeepFilter state used for analysis and synthesis.
15
+
16
+ Args:
17
+ sr (int): Sampling rate.
18
+ fft_size (int): Window length used for the Fast Fourier transform.
19
+ hop_size (int): Hop size between two analysis windows. Also called frame size.
20
+ nb_bands (int): Number of ERB bands.
21
+ min_nb_erb_freqs (int): Minimum number of frequency bands per ERB band. Defaults to 1.
22
+ """
23
+ ...
24
+ def analysis(self, input: ndarray) -> ndarray:
25
+ """Analysis of a time-domain signal.
26
+
27
+ Args:
28
+ input (ndarray): 2D real-valued array of shape [C, T].
29
+ Output:
30
+ output (ndarray): 3D complex-valued array of shape [C, T', F], where F is `fft_size / 2 + 1`,
31
+ and T' the original time T divided by `hop_size`.
32
+ """
33
+ ...
34
+ def synthesis(self, input: ndarray) -> ndarray:
35
+ """Synthesis of a frequency-domain signal.
36
+
37
+ Args:
38
+ input (ndarray): 3D complex-valued array of shape [C, T, F].
39
+ Output:
40
+ output (ndarray): 2D real-valued array of shape [C, T].
41
+ """
42
+ ...
43
+ def erb_widths(self) -> ndarray: ...
44
+ def fft_window(self) -> ndarray: ...
45
+ def sr(self) -> int: ...
46
+ def fft_size(self) -> int: ...
47
+ def hop_size(self) -> int: ...
48
+ def nb_erb(self) -> int: ...
49
+ def reset(self) -> None: ...
50
+
51
+ def erb(
52
+ input: ndarray, erb_fb: Union[ndarray, List[int]], db: Optional[bool] = None
53
+ ) -> ndarray: ...
54
+ def erb_inv(input: ndarray, erb_fb: Union[ndarray, List[int]]) -> ndarray: ...
55
+ def erb_norm(erb: ndarray, alpha: float, state: Optional[ndarray] = None) -> ndarray: ...
56
+ def unit_norm(spec: ndarray, alpha: float, state: Optional[ndarray] = None) -> ndarray: ...
57
+ def unit_norm_init(num_freq_bins: int) -> ndarray: ...
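+ # A minimal usage sketch (hypothetical parameter values) for the interface above: `analysis`
+ # turns a real [C, T] waveform into a complex [C, T', F] spectrogram and `synthesis` inverts it.
+ #
+ #   import numpy as np
+ #   from libdf import DF
+ #
+ #   df_state = DF(sr=48000, fft_size=960, hop_size=480, nb_bands=32)
+ #   audio = np.zeros((1, 48000), dtype=np.float32)  # 1 channel, 1 s of silence
+ #   spec = df_state.analysis(audio)                 # complex, shape [1, T', F]
+ #   out = df_state.synthesis(spec)                  # back to real, shape [1, T]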
libdf/py.typed ADDED
File without changes
model_weights/voice_enhance/checkpoints/model_96.ckpt.best ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bb5eccb429e675bb4ec5ec9e280f048bfff9787b40bd3eb835fd11509eb14a3e
3
+ size 9397209
model_weights/voice_enhance/ckpt/pretrained_bak_5805000.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bc7ff82ef75becd495aab2ede3a8220da393a717f178ae9534df355a6173bbca
3
+ size 17090379
model_weights/voice_enhance/ckpt/pretrained_bak_5805000.pt.txt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7dfd48d0da24db35ee4a653d0d36a4104cb26873050a5c3584675eee21937621
3
+ size 69
model_weights/voiceover/freevc-24.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:872360b61e6bbe09bec29810e7ad0d16318e379f6195a7ff3b06e50efb08ad31
3
+ size 1264
model_weights/voiceover/freevc-24.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7b39a86fefbc9ec6e30be8d26ee2a6aa5ffe6d235f6ab15773d01cdf348e5b20
3
+ size 472644351
model_weights/wavlm_models/WavLM-Large.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6fb4b3c3e6aa567f0a997b30855859cb81528ee8078802af439f7b2da0bf100f
3
+ size 1261965425
model_weights/wavlm_models/WavLM-Large.pt.txt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a9836bca8ab0e9d0b4797aa78f41b367800d26cfd25ade7b1edcb35bc3c171e4
3
+ size 52
nnet/__init__.py ADDED
File without changes
nnet/attentions.py ADDED
@@ -0,0 +1,300 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ from nnet import commons
7
+ from nnet.modules import LayerNorm
8
+
9
+
10
+ class Encoder(nn.Module):
11
+ def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4, **kwargs):
12
+ super().__init__()
13
+ self.hidden_channels = hidden_channels
14
+ self.filter_channels = filter_channels
15
+ self.n_heads = n_heads
16
+ self.n_layers = n_layers
17
+ self.kernel_size = kernel_size
18
+ self.p_dropout = p_dropout
19
+ self.window_size = window_size
20
+
21
+ self.drop = nn.Dropout(p_dropout)
22
+ self.attn_layers = nn.ModuleList()
23
+ self.norm_layers_1 = nn.ModuleList()
24
+ self.ffn_layers = nn.ModuleList()
25
+ self.norm_layers_2 = nn.ModuleList()
26
+ for i in range(self.n_layers):
27
+ self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size))
28
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
29
+ self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
30
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
31
+
32
+ def forward(self, x, x_mask):
33
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
34
+ x = x * x_mask
35
+ for i in range(self.n_layers):
36
+ y = self.attn_layers[i](x, x, attn_mask)
37
+ y = self.drop(y)
38
+ x = self.norm_layers_1[i](x + y)
39
+
40
+ y = self.ffn_layers[i](x, x_mask)
41
+ y = self.drop(y)
42
+ x = self.norm_layers_2[i](x + y)
43
+ x = x * x_mask
44
+ return x
45
+
46
+
47
+ class Decoder(nn.Module):
48
+ def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., proximal_bias=False, proximal_init=True, **kwargs):
49
+ super().__init__()
50
+ self.hidden_channels = hidden_channels
51
+ self.filter_channels = filter_channels
52
+ self.n_heads = n_heads
53
+ self.n_layers = n_layers
54
+ self.kernel_size = kernel_size
55
+ self.p_dropout = p_dropout
56
+ self.proximal_bias = proximal_bias
57
+ self.proximal_init = proximal_init
58
+
59
+ self.drop = nn.Dropout(p_dropout)
60
+ self.self_attn_layers = nn.ModuleList()
61
+ self.norm_layers_0 = nn.ModuleList()
62
+ self.encdec_attn_layers = nn.ModuleList()
63
+ self.norm_layers_1 = nn.ModuleList()
64
+ self.ffn_layers = nn.ModuleList()
65
+ self.norm_layers_2 = nn.ModuleList()
66
+ for i in range(self.n_layers):
67
+ self.self_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, proximal_bias=proximal_bias, proximal_init=proximal_init))
68
+ self.norm_layers_0.append(LayerNorm(hidden_channels))
69
+ self.encdec_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout))
70
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
71
+ self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True))
72
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
73
+
74
+ def forward(self, x, x_mask, h, h_mask):
75
+ """
76
+ x: decoder input
77
+ h: encoder output
78
+ """
79
+ self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
80
+ encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
81
+ x = x * x_mask
82
+ for i in range(self.n_layers):
83
+ y = self.self_attn_layers[i](x, x, self_attn_mask)
84
+ y = self.drop(y)
85
+ x = self.norm_layers_0[i](x + y)
86
+
87
+ y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
88
+ y = self.drop(y)
89
+ x = self.norm_layers_1[i](x + y)
90
+
91
+ y = self.ffn_layers[i](x, x_mask)
92
+ y = self.drop(y)
93
+ x = self.norm_layers_2[i](x + y)
94
+ x = x * x_mask
95
+ return x
96
+
97
+
98
+ class MultiHeadAttention(nn.Module):
99
+ def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False):
100
+ super().__init__()
101
+ assert channels % n_heads == 0
102
+
103
+ self.channels = channels
104
+ self.out_channels = out_channels
105
+ self.n_heads = n_heads
106
+ self.p_dropout = p_dropout
107
+ self.window_size = window_size
108
+ self.heads_share = heads_share
109
+ self.block_length = block_length
110
+ self.proximal_bias = proximal_bias
111
+ self.proximal_init = proximal_init
112
+ self.attn = None
113
+
114
+ self.k_channels = channels // n_heads
115
+ self.conv_q = nn.Conv1d(channels, channels, 1)
116
+ self.conv_k = nn.Conv1d(channels, channels, 1)
117
+ self.conv_v = nn.Conv1d(channels, channels, 1)
118
+ self.conv_o = nn.Conv1d(channels, out_channels, 1)
119
+ self.drop = nn.Dropout(p_dropout)
120
+
121
+ if window_size is not None:
122
+ n_heads_rel = 1 if heads_share else n_heads
123
+ rel_stddev = self.k_channels**-0.5
124
+ self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
125
+ self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
126
+
127
+ nn.init.xavier_uniform_(self.conv_q.weight)
128
+ nn.init.xavier_uniform_(self.conv_k.weight)
129
+ nn.init.xavier_uniform_(self.conv_v.weight)
130
+ if proximal_init:
131
+ with torch.no_grad():
132
+ self.conv_k.weight.copy_(self.conv_q.weight)
133
+ self.conv_k.bias.copy_(self.conv_q.bias)
134
+
135
+ def forward(self, x, c, attn_mask=None):
136
+ q = self.conv_q(x)
137
+ k = self.conv_k(c)
138
+ v = self.conv_v(c)
139
+
140
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
141
+
142
+ x = self.conv_o(x)
143
+ return x
144
+
145
+ def attention(self, query, key, value, mask=None):
146
+ # reshape [b, d, t] -> [b, n_h, t, d_k]
147
+ b, d, t_s, t_t = (*key.size(), query.size(2))
148
+ query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
149
+ key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
150
+ value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
151
+
152
+ scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
153
+ if self.window_size is not None:
154
+ assert t_s == t_t, "Relative attention is only available for self-attention."
155
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
156
+ rel_logits = self._matmul_with_relative_keys(query /math.sqrt(self.k_channels), key_relative_embeddings)
157
+ scores_local = self._relative_position_to_absolute_position(rel_logits)
158
+ scores = scores + scores_local
159
+ if self.proximal_bias:
160
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
161
+ scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
162
+ if mask is not None:
163
+ scores = scores.masked_fill(mask == 0, -1e4)
164
+ if self.block_length is not None:
165
+ assert t_s == t_t, "Local attention is only available for self-attention."
166
+ block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
167
+ scores = scores.masked_fill(block_mask == 0, -1e4)
168
+ p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
169
+ p_attn = self.drop(p_attn)
170
+ output = torch.matmul(p_attn, value)
171
+ if self.window_size is not None:
172
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
173
+ value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
174
+ output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
175
+ output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
176
+ return output, p_attn
177
+
178
+ def _matmul_with_relative_values(self, x, y):
179
+ """
180
+ x: [b, h, l, m]
181
+ y: [h or 1, m, d]
182
+ ret: [b, h, l, d]
183
+ """
184
+ ret = torch.matmul(x, y.unsqueeze(0))
185
+ return ret
186
+
187
+ def _matmul_with_relative_keys(self, x, y):
188
+ """
189
+ x: [b, h, l, d]
190
+ y: [h or 1, m, d]
191
+ ret: [b, h, l, m]
192
+ """
193
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
194
+ return ret
195
+
196
+ def _get_relative_embeddings(self, relative_embeddings, length):
197
+ max_relative_position = 2 * self.window_size + 1
198
+ # Pad first before slice to avoid using cond ops.
199
+ pad_length = max(length - (self.window_size + 1), 0)
200
+ slice_start_position = max((self.window_size + 1) - length, 0)
201
+ slice_end_position = slice_start_position + 2 * length - 1
202
+ if pad_length > 0:
203
+ padded_relative_embeddings = F.pad(
204
+ relative_embeddings,
205
+ commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
206
+ else:
207
+ padded_relative_embeddings = relative_embeddings
208
+ used_relative_embeddings = padded_relative_embeddings[:,slice_start_position:slice_end_position]
209
+ return used_relative_embeddings
210
+
211
+ def _relative_position_to_absolute_position(self, x):
212
+ """
213
+ x: [b, h, l, 2*l-1]
214
+ ret: [b, h, l, l]
215
+ """
216
+ batch, heads, length, _ = x.size()
217
+ # Concat columns of pad to shift from relative to absolute indexing.
218
+ x = F.pad(x, commons.convert_pad_shape([[0,0],[0,0],[0,0],[0,1]]))
219
+
220
+ # Concat extra elements so that it adds up to shape (len+1, 2*len-1).
221
+ x_flat = x.view([batch, heads, length * 2 * length])
222
+ x_flat = F.pad(x_flat, commons.convert_pad_shape([[0,0],[0,0],[0,length-1]]))
223
+
224
+ # Reshape and slice out the padded elements.
225
+ x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:]
226
+ return x_final
227
+
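+ # A worked example (hypothetical values) of the pad / flatten / reshape / slice trick above,
+ # for length l = 2: the input x has shape [b, h, 2, 3] with rows [x00, x01, x02] and
+ # [x10, x11, x12] (last index = relative position). After the two pads and the final slice the
+ # result is [[x01, x02], [x10, x11]], i.e. out[i, j] = x[i, j - i + (l - 1)], which lines each
+ # relative offset up with its absolute key position.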
228
+ def _absolute_position_to_relative_position(self, x):
229
+ """
230
+ x: [b, h, l, l]
231
+ ret: [b, h, l, 2*l-1]
232
+ """
233
+ batch, heads, length, _ = x.size()
234
+ # pad along the column dimension
235
+ x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]]))
236
+ x_flat = x.view([batch, heads, length**2 + length*(length -1)])
237
+ # add 0's in the beginning that will skew the elements after reshape
238
+ x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
239
+ x_final = x_flat.view([batch, heads, length, 2*length])[:,:,:,1:]
240
+ return x_final
241
+
242
+ def _attention_bias_proximal(self, length):
243
+ """Bias for self-attention to encourage attention to close positions.
244
+ Args:
245
+ length: an integer scalar.
246
+ Returns:
247
+ a Tensor with shape [1, 1, length, length]
248
+ """
249
+ r = torch.arange(length, dtype=torch.float32)
250
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
251
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
252
+
253
+
254
+ class FFN(nn.Module):
255
+ def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None, causal=False):
256
+ super().__init__()
257
+ self.in_channels = in_channels
258
+ self.out_channels = out_channels
259
+ self.filter_channels = filter_channels
260
+ self.kernel_size = kernel_size
261
+ self.p_dropout = p_dropout
262
+ self.activation = activation
263
+ self.causal = causal
264
+
265
+ if causal:
266
+ self.padding = self._causal_padding
267
+ else:
268
+ self.padding = self._same_padding
269
+
270
+ self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
271
+ self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
272
+ self.drop = nn.Dropout(p_dropout)
273
+
274
+ def forward(self, x, x_mask):
275
+ x = self.conv_1(self.padding(x * x_mask))
276
+ if self.activation == "gelu":
277
+ x = x * torch.sigmoid(1.702 * x)
278
+ else:
279
+ x = torch.relu(x)
280
+ x = self.drop(x)
281
+ x = self.conv_2(self.padding(x * x_mask))
282
+ return x * x_mask
283
+
284
+ def _causal_padding(self, x):
285
+ if self.kernel_size == 1:
286
+ return x
287
+ pad_l = self.kernel_size - 1
288
+ pad_r = 0
289
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
290
+ x = F.pad(x, commons.convert_pad_shape(padding))
291
+ return x
292
+
293
+ def _same_padding(self, x):
294
+ if self.kernel_size == 1:
295
+ return x
296
+ pad_l = (self.kernel_size - 1) // 2
297
+ pad_r = self.kernel_size // 2
298
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
299
+ x = F.pad(x, commons.convert_pad_shape(padding))
300
+ return x