Spaces:
Runtime error
Runtime error
| import os | |
| import string | |
| from configparser import ConfigParser | |
| from shlex import shlex | |
| from typing import Any, List, Optional, Tuple, Type, TypeVar, Union | |
| from loguru import logger | |
# Generic type variable: ties the type of a `cast` callable to the type of the value
# returned by the config accessors and the Csv parser defined in this module.
T = TypeVar("T")
class DfParams:
    """Parameters read from the `DF` section of the global `config` object.

    Every attribute is resolved through a `config(...)` call at construction time,
    so a configuration must be available on `config` before instantiating this class.
    """

    def __init__(self):
        # Sampling rate used for training
        self.sr: int = config("SR", cast=int, default=48_000, section="DF")
        # FFT size in samples
        self.fft_size: int = config("FFT_SIZE", cast=int, default=960, section="DF")
        # STFT Hop size in samples
        self.hop_size: int = config("HOP_SIZE", cast=int, default=480, section="DF")
        # Number of ERB bands
        self.nb_erb: int = config("NB_ERB", cast=int, default=32, section="DF")
        # Number of deep filtering bins; DF is applied from 0th to nb_df-th frequency bins
        self.nb_df: int = config("NB_DF", cast=int, default=96, section="DF")
        # Normalization decay factor; used for complex and erb features
        self.norm_tau: float = config("NORM_TAU", cast=float, default=1, section="DF")
        # Local SNR maximum value, ground truth will be truncated
        self.lsnr_max: int = config("LSNR_MAX", cast=int, default=35, section="DF")
        # Local SNR minimum value, ground truth will be truncated
        self.lsnr_min: int = config("LSNR_MIN", cast=int, default=-15, section="DF")
        # Minimum number of frequency bins per ERB band
        self.min_nb_freqs: int = config("MIN_NB_ERB_FREQS", cast=int, default=2, section="DF")
        # Deep Filtering order
        self.df_order: int = config("DF_ORDER", cast=int, default=5, section="DF")
        # Deep Filtering look-ahead
        self.df_lookahead: int = config("DF_LOOKAHEAD", cast=int, default=0, section="DF")
        # Pad mode. By default, padding will be handled on the input side:
        # - `input`, which pads the input features passed to the model
        # - `output`, which pads the output spectrogram corresponding to `df_lookahead`
        # NOTE(review): the default value `input_specf` is not one of the modes listed
        # above -- confirm the full set of accepted values with the consumer of `pad_mode`.
        self.pad_mode: str = config("PAD_MODE", default="input_specf", section="DF")
class Config:
    """INI-file backed configuration store. Adopted from python-decouple.

    Values are looked up (see `__call__`) in the process environment first, then in
    the loaded INI file, and finally fall back to a caller supplied default. Values
    that were resolved from a default are written back into the parser so that
    `save()` can persist the effective configuration.
    """

    DEFAULT_SECTION = "settings"

    def __init__(self):
        # `parser` stays None until `load()` / `use_defaults()` is called.
        self.parser: ConfigParser = None  # type: ignore
        self.path = ""
        # Set as soon as the in-memory config differs from what was loaded.
        self.modified = False
        self.allow_defaults = True

    def load(
        self, path: Optional[str], config_must_exist=False, allow_defaults=True, allow_reload=False
    ):
        """Load the configuration from `path`.

        Parameters:
            path -- INI file to read. If it is None or not an existing file, an empty
                configuration is used (unless `config_must_exist` is set).
            config_must_exist -- raise if no file exists at `path`.
            allow_defaults -- whether `__call__` may fall back to default values.
            allow_reload -- allow replacing an already loaded configuration.

        Raises:
            ValueError -- if a config is already loaded and `allow_reload` is False,
                or if `config_must_exist` is set and no file was found.
        """
        self.allow_defaults = allow_defaults
        if self.parser is not None and not allow_reload:
            raise ValueError("Config already loaded")
        self.parser = ConfigParser()
        self.path = path
        if path is not None and os.path.isfile(path):
            with open(path) as f:
                self.parser.read_file(f)
        elif config_must_exist:
            raise ValueError(f"No config file found at '{path}'.")
        if not self.parser.has_section(self.DEFAULT_SECTION):
            self.parser.add_section(self.DEFAULT_SECTION)
        # Migrate section/option names used by older config files.
        self._fix_clc()
        self._fix_df()

    def use_defaults(self):
        """Initialise an empty configuration so that every lookup uses its default."""
        self.load(path=None, config_must_exist=False)

    def save(self, path: str):
        """Write the configuration to `path`; does nothing if it was not modified."""
        if not self.modified:
            logger.debug("Config not modified. No need to overwrite on disk.")
            return
        if self.parser is None:
            self.parser = ConfigParser()
        # `sections()` returns a fresh list, so removing while looping is safe.
        for section in self.parser.sections():
            if len(self.parser[section]) == 0:
                self.parser.remove_section(section)
        with open(path, mode="w") as f:
            self.parser.write(f)

    def tostr(self, value, cast):
        """Serialise `value` for storage in the parser.

        Lists/tuples handled by a `Csv` cast are joined with the cast's delimiter;
        everything else is converted with `str()`.
        """
        # Joining with the delimiter (instead of appending it to every item and
        # chopping one trailing character) is also correct for delimiters that are
        # not exactly one character long.
        if isinstance(value, (tuple, list)) and isinstance(cast, Csv):
            return cast.delimiter.join(str(v) for v in value)
        return str(value)

    def set(self, option: str, value: T, cast: Type[T], section: Optional[str] = None) -> T:
        """Store `value` under `option` in the (lower-cased) `section`.

        The config is only marked as modified if the stored value actually changes.
        Returns `value` unchanged.
        """
        section = self.DEFAULT_SECTION if section is None else section
        section = section.lower()
        if not self.parser.has_section(section):
            self.parser.add_section(section)
        if self.parser.has_option(section, option):
            if value == self.cast(self.parser.get(section, option), cast):
                return value
        self.modified = True
        self.parser.set(section, option, self.tostr(value, cast))
        return value

    def __call__(
        self,
        option: str,
        default: Any = None,
        cast: Type[T] = str,
        save: bool = True,
        section: Optional[str] = None,
    ) -> T:
        """Look up `option` and return it converted with `cast`.

        Lookup order: environment variable named `option`, the given `section`, its
        lower-cased variant, the default settings section, and finally `default`.

        Parameters:
            option -- name of the option / environment variable.
            default -- fallback value; None means the option is required.
            cast -- callable converting the raw value (bool is parsed specially).
            save -- whether the resolved value is kept in the config for `save()`.
            section -- section name; defaults to `DEFAULT_SECTION`.

        Raises:
            ValueError -- if no configuration is loaded, the option is missing and
                there is no default, or defaults are not allowed.
        """
        # Get value either from an ENV or from the .ini file
        section = self.DEFAULT_SECTION if section is None else section
        value = None
        if self.parser is None:
            raise ValueError("No configuration loaded")
        section_lower = section.lower()
        if not self.parser.has_section(section_lower):
            self.parser.add_section(section_lower)
        if option in os.environ:
            value = os.environ[option]
            if save:
                # Only the lower-cased section is guaranteed to exist at this point;
                # writing to the raw `section` name raised NoSectionError for e.g. "DF".
                self.parser.set(section_lower, option, self.tostr(value, cast))
        elif self.parser.has_option(section, option):
            value = self.read_from_section(section, option, default, cast=cast, save=save)
        elif self.parser.has_option(section_lower, option):
            value = self.read_from_section(section_lower, option, default, cast=cast, save=save)
        elif self.parser.has_option(self.DEFAULT_SECTION, option):
            logger.warning(
                f"Couldn't find option {option} in section {section}. "
                "Falling back to default settings section."
            )
            value = self.read_from_section(self.DEFAULT_SECTION, option, cast=cast, save=save)
        elif default is None:
            raise ValueError("Value {} not found.".format(option))
        elif not self.allow_defaults and save:
            raise ValueError(f"Value '{option}' not found in config (defaults not allowed).")
        else:
            value = default
            if save:
                self.set(option, value, cast, section)
        return self.cast(value, cast)

    def cast(self, value, cast):
        """Convert `value` with `cast`; booleans are parsed from common string forms.

        Raises:
            ValueError -- if `cast` is bool and the value is not a known boolean string.
        """
        # Do the casting to get the correct type
        if cast is bool:
            value = str(value).lower()
            if value in {"true", "yes", "y", "on", "1"}:
                return True  # type: ignore
            elif value in {"false", "no", "n", "off", "0"}:
                return False  # type: ignore
            raise ValueError("Parse error")
        return cast(value)

    def get(self, option: str, cast: Type[T] = str, section: Optional[str] = None) -> T:
        """Return the stored value of `option` converted with `cast`.

        Unlike `__call__`, this never consults the environment or a default.

        Raises:
            KeyError -- with the section name if the section does not exist, or with
                the option name if the option does not exist.
        """
        section = self.DEFAULT_SECTION if section is None else section
        # `set` and `__call__` store values under the lower-cased section name, so
        # look there as well to stay consistent with what they have written.
        candidates = (section, section.lower())
        for candidate in candidates:
            if self.parser.has_option(candidate, option):
                return self.cast(self.parser.get(candidate, option), cast)
        if not any(self.parser.has_section(candidate) for candidate in candidates):
            raise KeyError(section)
        raise KeyError(option)

    def read_from_section(
        self, section: str, option: str, default: Any = None, cast: Type = str, save: bool = True
    ) -> str:
        """Read the raw string value of `option` from `section`.

        If `save` is False the stored entry is replaced by `default` (or removed when
        there is no default). If `save` is True and the section name is not lower
        case, the entry is moved to the lower-cased section.

        Raises:
            ValueError -- if `save` is False, a default is given, and defaults are
                not allowed.
        """
        value = self.parser.get(section, option)
        if not save:
            # Set to default or remove to not read it at training start again
            if default is None:
                self.parser.remove_option(section, option)
            elif not self.allow_defaults:
                raise ValueError(f"Value '{option}' not found in config (defaults not allowed).")
            else:
                self.parser.set(section, option, self.tostr(default, cast))
        elif section.lower() != section:
            section_lower = section.lower()
            # `__call__` creates this section up front; make direct calls safe as well.
            if not self.parser.has_section(section_lower):
                self.parser.add_section(section_lower)
            self.parser.set(section_lower, option, self.tostr(value, cast))
            self.parser.remove_option(section, option)
            self.modified = True
        return value

    def overwrite(self, section: str, option: str, value: Any):
        """Replace the value of an option that already exists.

        Raises:
            ValueError -- if `section` or `option` does not exist. (These errors used
                to be returned instead of raised, which silently hid the failure.)
        """
        if not self.parser.has_section(section):
            raise ValueError(f"Section not found: '{section}'")
        if not self.parser.has_option(section, option):
            raise ValueError(f"Option not found '{option}' in section '{section}'")
        self.modified = True
        cast = type(value)
        return self.parser.set(section, option, self.tostr(value, cast))

    def _fix_df(self):
        """Renaming of some groups/options for compatibility with old models."""
        if self.parser.has_section("deepfilternet") and self.parser.has_section("df"):
            sec_deepfilternet = self.parser["deepfilternet"]
            sec_df = self.parser["df"]
            # These options are moved from the model section to the `df` section.
            for key in ("df_order", "df_lookahead"):
                if key in sec_deepfilternet:
                    sec_df[key] = sec_deepfilternet[key]
                    del sec_deepfilternet[key]

    def _fix_clc(self):
        """Renaming of some groups/options for compatibility with old models."""
        if (
            not self.parser.has_section("deepfilternet")
            and self.parser.has_section("train")
            # `fallback` keeps a [train] section without a `model` key from raising.
            and self.parser.get("train", "model", fallback=None) == "convgru5"
        ):
            self.overwrite("train", "model", "deepfilternet")
            self.parser.add_section("deepfilternet")
            if self.parser.has_section("convgru"):
                self.parser["deepfilternet"] = self.parser["convgru"]
                del self.parser["convgru"]
        if not self.parser.has_section("df") and self.parser.has_section("clc"):
            self.parser["df"] = self.parser["clc"]
            del self.parser["clc"]
        for section in self.parser.sections():
            # Iterating a section walks a snapshot of its option names, so options
            # can be renamed inside the loop.
            for k, v in self.parser[section].items():
                if "clc" in k.lower():
                    self.parser.set(section, k.lower().replace("clc", "df"), v)
                    del self.parser[section][k]

    def __repr__(self):
        # No configuration loaded yet: nothing to show (previously an AttributeError).
        if self.parser is None:
            return ""
        parts = []
        for section in self.parser.sections():
            parts.append(f"{section}:\n")
            parts.extend(f" {k}: {v}\n" for k, v in self.parser[section].items())
        return "".join(parts)
# Module-level Config instance; DfParams resolves its parameters through it.
# It holds no parser until `load()` or `use_defaults()` has been called.
config = Config()
class Csv(object):
    """
    Produces a csv parser that returns a list of transformed elements. From python-decouple.
    """

    def __init__(
        self, cast: Type[T] = str, delimiter=",", strip=string.whitespace, post_process=list
    ):
        """
        Parameters:
        cast -- callable that transforms the item just before it's added to the list.
        delimiter -- string of delimiters chars passed to shlex.
        strip -- string of non-relevant characters to be passed to str.strip after the split.
        post_process -- callable to post process all casted values. Default is `list`.
        """
        self.cast: Type[T] = cast
        self.delimiter = delimiter
        self.strip = strip
        self.post_process = post_process

    def __call__(self, value: Union[str, Tuple[T], List[T], None]) -> List[T]:
        """The actual transformation.

        Parameters:
        value -- delimiter separated string, or a tuple/list (e.g. a default value)
            whose items are stringified and parsed again. None yields an empty result.

        Returns the result of `post_process` applied to the casted items.
        """
        if value is None:
            # shlex falls back to reading sys.stdin when given None; treat a missing
            # value as "no items" instead.
            return self.post_process(())
        if isinstance(value, (tuple, list)):
            # if default value is a list
            value = self.delimiter.join(str(v) for v in value)

        def transform(s):
            return self.cast(s.strip(self.strip))

        splitter = shlex(value, posix=True)
        # Every delimiter character acts as a separator; anything else, including
        # blanks, stays inside the token and is stripped in `transform`.
        splitter.whitespace = self.delimiter
        splitter.whitespace_split = True
        return self.post_process(transform(s) for s in splitter)