File size: 11,321 Bytes
fc5ed00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
import os
import string
from configparser import ConfigParser
from shlex import shlex
from typing import Any, List, Optional, Tuple, Type, TypeVar, Union

from loguru import logger

T = TypeVar("T")


class DfParams:
    """Signal-processing parameters read from the ``DF`` section of the global `config`.

    All values are resolved when the instance is created, so a configuration must be
    loaded beforehand. Options missing from the config fall back to the defaults below.
    """

    def __init__(self):
        # Sampling rate used for training
        self.sr: int = config("SR", cast=int, default=48_000, section="DF")
        # FFT size in samples
        self.fft_size: int = config("FFT_SIZE", cast=int, default=960, section="DF")
        # STFT hop size in samples
        self.hop_size: int = config("HOP_SIZE", cast=int, default=480, section="DF")
        # Number of ERB bands
        self.nb_erb: int = config("NB_ERB", cast=int, default=32, section="DF")
        # Number of deep filtering bins; DF is applied from the 0th to the nb_df-th frequency bin
        self.nb_df: int = config("NB_DF", cast=int, default=96, section="DF")
        # Normalization decay factor; used for complex and ERB features
        self.norm_tau: float = config("NORM_TAU", cast=float, default=1, section="DF")
        # Local SNR maximum value; ground truth above it will be truncated
        self.lsnr_max: int = config("LSNR_MAX", cast=int, default=35, section="DF")
        # Local SNR minimum value; ground truth below it will be truncated
        self.lsnr_min: int = config("LSNR_MIN", cast=int, default=-15, section="DF")
        # Minimum number of frequency bins per ERB band
        self.min_nb_freqs: int = config("MIN_NB_ERB_FREQS", cast=int, default=2, section="DF")
        # Deep Filtering order
        self.df_order: int = config("DF_ORDER", cast=int, default=5, section="DF")
        # Deep Filtering look-ahead
        self.df_lookahead: int = config("DF_LOOKAHEAD", cast=int, default=0, section="DF")
        # Pad mode (default `input_specf`). Documented modes:
        # - `input`, which pads the input features passed to the model
        # - `output`, which pads the output spectrogram corresponding to `df_lookahead`
        # NOTE(review): the default `input_specf` is not described here; presumably a
        # variant of `input` padding — confirm against the model code.
        self.pad_mode: str = config("PAD_MODE", default="input_specf", section="DF")


class Config:
    """Ini-file backed configuration store, adopted from python-decouple.

    `__call__` resolves an option from, in order: an environment variable of the
    same name, the requested section (as given, then lower-cased), the default
    ``settings`` section, and finally the caller-supplied default. Resolved values
    are written back into the parser unless ``save=False`` is passed, so that
    `save` can persist them to disk.
    """

    DEFAULT_SECTION = "settings"

    def __init__(self):
        # Created by `load`; until then `__call__` raises ValueError.
        self.parser: ConfigParser = None  # type: ignore
        self.path: Optional[str] = ""
        # Tracks whether values were added or changed (see `set`, `overwrite`);
        # `save` is a no-op while this is False.
        self.modified = False
        self.allow_defaults = True

    def load(
        self, path: Optional[str], config_must_exist=False, allow_defaults=True, allow_reload=False
    ):
        """Create a fresh parser, optionally populated from the ini file at `path`.

        Args:
            path: Path to an ini file. ``None`` or a missing file yields an empty
                configuration (unless `config_must_exist` is set).
            config_must_exist: Raise if no file exists at `path`.
            allow_defaults: If False, a missing option requested with ``save=True``
                raises instead of being filled in with the caller's default.
            allow_reload: Permit calling `load` on an already loaded config.

        Raises:
            ValueError: If a config is already loaded and `allow_reload` is False, or
                if `config_must_exist` is set and no file is found at `path`.
        """
        self.allow_defaults = allow_defaults
        if self.parser is not None and not allow_reload:
            raise ValueError("Config already loaded")
        self.parser = ConfigParser()
        self.path = path
        if path is not None and os.path.isfile(path):
            with open(path) as f:
                self.parser.read_file(f)
        elif config_must_exist:
            raise ValueError(f"No config file found at '{path}'.")
        if not self.parser.has_section(self.DEFAULT_SECTION):
            self.parser.add_section(self.DEFAULT_SECTION)
        # Migrate section/option names written by older versions.
        self._fix_clc()
        self._fix_df()

    def use_defaults(self):
        """Load an empty configuration so that every lookup uses its default.

        Raises ValueError if a configuration has already been loaded.
        """
        self.load(path=None, config_must_exist=False)

    def save(self, path: str):
        """Write the configuration to `path` in ini format.

        Does nothing if no value was added or changed. Empty sections are dropped
        before writing.
        """
        if not self.modified:
            logger.debug("Config not modified. No need to overwrite on disk.")
            return
        if self.parser is None:
            self.parser = ConfigParser()
        for section in self.parser.sections():
            if len(self.parser[section]) == 0:
                self.parser.remove_section(section)
        with open(path, mode="w") as f:
            self.parser.write(f)

    def tostr(self, value, cast):
        """Serialize `value` for storage; lists/tuples for a `Csv` cast are delimiter-joined."""
        if isinstance(cast, Csv) and isinstance(value, (tuple, list)):
            return "".join(str(v) + cast.delimiter for v in value)[:-1]
        return str(value)

    def set(self, option: str, value: T, cast: Type[T], section: Optional[str] = None) -> T:
        """Store `value` under `option` in the lower-cased `section` and return `value`.

        `modified` is only flagged when the stored value actually changes.
        """
        section = self.DEFAULT_SECTION if section is None else section
        section = section.lower()
        if not self.parser.has_section(section):
            self.parser.add_section(section)
        if self.parser.has_option(section, option):
            if value == self.cast(self.parser.get(section, option), cast):
                return value
        self.modified = True
        self.parser.set(section, option, self.tostr(value, cast))
        return value

    def __call__(
        self,
        option: str,
        default: Any = None,
        cast: Type[T] = str,
        save: bool = True,
        section: Optional[str] = None,
    ) -> T:
        """Return `option` converted with `cast`, resolved as described on the class.

        Args:
            option: Option name; an environment variable of this name overrides it.
            default: Fallback value. ``None`` means the option is required.
            cast: Converter applied to the value (e.g. ``int``, ``bool``, a `Csv`).
            save: Write the resolved value back into the parser. If False, an option
                read from the config is reset to `default` (or removed).
            section: Section to read from; defaults to ``settings``.

        Raises:
            ValueError: If no config is loaded, a required option is missing, or a
                default would be needed while defaults are not allowed.
        """
        section = self.DEFAULT_SECTION if section is None else section
        value = None
        if self.parser is None:
            raise ValueError("No configuration loaded")
        if not self.parser.has_section(section.lower()):
            self.parser.add_section(section.lower())
        if option in os.environ:
            value = os.environ[option]
            if save:
                # Only the lower-cased section is guaranteed to exist (added above);
                # the section as given (e.g. "DF") may not.
                self.parser.set(section.lower(), option, self.tostr(value, cast))
        elif self.parser.has_option(section, option):
            value = self.read_from_section(section, option, default, cast=cast, save=save)
        elif self.parser.has_option(section.lower(), option):
            value = self.read_from_section(section.lower(), option, default, cast=cast, save=save)
        elif self.parser.has_option(self.DEFAULT_SECTION, option):
            logger.warning(
                f"Couldn't find option {option} in section {section}. "
                "Falling back to default settings section."
            )
            value = self.read_from_section(self.DEFAULT_SECTION, option, cast=cast, save=save)
        elif default is None:
            raise ValueError("Value {} not found.".format(option))
        elif not self.allow_defaults and save:
            raise ValueError(f"Value '{option}' not found in config (defaults not allowed).")
        else:
            value = default
            if save:
                self.set(option, value, cast, section)
        return self.cast(value, cast)

    def cast(self, value, cast):
        """Convert `value` with `cast`; booleans accept true/yes/y/on/1 and false/no/n/off/0.

        Raises:
            ValueError: If `cast` is ``bool`` and `value` is not a recognized literal.
        """
        if cast is bool:
            value = str(value).lower()
            if value in {"true", "yes", "y", "on", "1"}:
                return True  # type: ignore
            elif value in {"false", "no", "n", "off", "0"}:
                return False  # type: ignore
            raise ValueError("Parse error")
        return cast(value)

    def get(self, option: str, cast: Type[T] = str, section: Optional[str] = None) -> T:
        """Return an existing option converted with `cast`, without any fallback.

        The section is looked up as given and then lower-cased, matching where
        `set` and `__call__` store values.

        Raises:
            KeyError: If neither section variant exists, or the option is missing.
        """
        section = self.DEFAULT_SECTION if section is None else section
        for name in (section, section.lower()):
            if self.parser.has_option(name, option):
                return self.cast(self.parser.get(name, option), cast)
        if not (self.parser.has_section(section) or self.parser.has_section(section.lower())):
            raise KeyError(section)
        raise KeyError(option)

    def read_from_section(
        self, section: str, option: str, default: Any = None, cast: Type = str, save: bool = True
    ) -> str:
        """Return the raw string stored for `option` in `section`.

        With ``save=False`` the stored value is replaced by `default` (or removed if
        there is none) so it is not read again at the next training start. With
        ``save=True`` an option found in a non-lower-case section is moved to the
        lower-cased section.

        Raises:
            ValueError: If ``save=False``, a default is given and defaults are not allowed.
        """
        value = self.parser.get(section, option)
        if not save:
            # Set to default or remove to not read it at training start again
            if default is None:
                self.parser.remove_option(section, option)
            elif not self.allow_defaults:
                raise ValueError(f"Value '{option}' not found in config (defaults not allowed).")
            else:
                self.parser.set(section, option, self.tostr(default, cast))
        elif section.lower() != section:
            self.parser.set(section.lower(), option, self.tostr(value, cast))
            self.parser.remove_option(section, option)
            self.modified = True
        return value

    def overwrite(self, section: str, option: str, value: Any):
        """Replace the value of an existing option and mark the config as modified.

        Raises:
            ValueError: If `section` or `option` does not exist.
        """
        if not self.parser.has_section(section):
            raise ValueError(f"Section not found: '{section}'")
        if not self.parser.has_option(section, option):
            raise ValueError(f"Option not found '{option}' in section '{section}'")
        self.modified = True
        self.parser.set(section, option, self.tostr(value, type(value)))

    def _fix_df(self):
        """Move `df_order`/`df_lookahead` from ``[deepfilternet]`` to ``[df]`` (old models)."""
        if self.parser.has_section("deepfilternet") and self.parser.has_section("df"):
            sec_deepfilternet = self.parser["deepfilternet"]
            sec_df = self.parser["df"]
            for key in ("df_order", "df_lookahead"):
                if key in sec_deepfilternet:
                    sec_df[key] = sec_deepfilternet[key]
                    del sec_deepfilternet[key]

    def _fix_clc(self):
        """Rename legacy sections/options for compatibility with old models.

        - ``[train] model = convgru5`` becomes ``deepfilternet`` and the ``[convgru]``
          section (if present) is moved to ``[deepfilternet]``.
        - The ``[clc]`` section is renamed to ``[df]``.
        - Options whose name contains ``clc`` are renamed to use ``df`` instead.
        """
        if (
            not self.parser.has_section("deepfilternet")
            and self.parser.has_section("train")
            # `fallback` avoids NoOptionError for a [train] section without `model`.
            and self.parser.get("train", "model", fallback=None) == "convgru5"
        ):
            self.overwrite("train", "model", "deepfilternet")
            self.parser.add_section("deepfilternet")
            if self.parser.has_section("convgru"):
                self.parser["deepfilternet"] = self.parser["convgru"]
                del self.parser["convgru"]
        if not self.parser.has_section("df") and self.parser.has_section("clc"):
            self.parser["df"] = self.parser["clc"]
            del self.parser["clc"]
        for section in self.parser.sections():
            # Snapshot the items since options are deleted inside the loop.
            for k, v in list(self.parser[section].items()):
                if "clc" in k.lower():
                    self.parser.set(section, k.lower().replace("clc", "df"), v)
                    del self.parser[section][k]

    def __repr__(self):
        parts = []
        for section in self.parser.sections():
            parts.append(f"{section}:\n")
            parts.extend(f"  {k}: {v}\n" for k, v in self.parser[section].items())
        return "".join(parts)


# Module-level configuration singleton used by this module (e.g. `DfParams`).
# It starts without a parser: call `config.load(...)` or `config.use_defaults()`
# before reading values, otherwise `config(...)` raises ValueError.
config = Config()


class Csv(object):
    """
    Produces a csv parser that returns a list of transformed elements. From python-decouple.
    """

    def __init__(
        self, cast: Type[T] = str, delimiter=",", strip=string.whitespace, post_process=list
    ):
        """
        Parameters:
        cast -- callable that transforms the item just before it's added to the list.
        delimiter -- string of delimiter chars passed to shlex; each char separates items.
        strip -- string of non-relevant characters to be passed to str.strip after the split.
        post_process -- callable to post process all casted values. Default is `list`.
        """
        self.cast: Type[T] = cast
        self.delimiter = delimiter
        self.strip = strip
        self.post_process = post_process

    def __call__(self, value: Optional[Union[str, Tuple[T], List[T]]]) -> List[T]:
        """Split `value` into items, strip and cast each, then apply `post_process`.

        A tuple/list (e.g. a default value) is first joined with the delimiter so it
        is parsed exactly like the string form stored in the config file. ``None``
        yields an empty result.

        NOTE: parsing uses shlex in POSIX mode, so quotes, backslash escapes and
        ``#`` comments inside `value` are interpreted by shlex.
        """
        if value is None:
            # shlex(None) would silently read from sys.stdin instead.
            return self.post_process(())
        if isinstance(value, (tuple, list)):
            value = self.delimiter.join(str(v) for v in value)

        def transform(s):
            return self.cast(s.strip(self.strip))

        splitter = shlex(value, posix=True)
        splitter.whitespace = self.delimiter
        splitter.whitespace_split = True

        return self.post_process(transform(s) for s in splitter)