jimbozhang committed on
Commit
5e9bb10
1 Parent(s): 902f4c8

Init commit.

.gitignore ADDED
@@ -0,0 +1,160 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
app.py ADDED
@@ -0,0 +1,40 @@
+ # Copyright 2021-2023 Xiaomi Corporation
+
+ import gradio as gr
+ import torch
+ import torchaudio
+
+ from ced_model.feature_extraction_ced import CedFeatureExtractor
+ from ced_model.modeling_ced import CedForAudioClassification
+
+ model_path = "mispeech/ced-base"
+ feature_extractor = CedFeatureExtractor.from_pretrained(model_path)
+ model = CedForAudioClassification.from_pretrained(model_path)
+
+
+ def process(audio_path: str) -> str:
+     if audio_path is None:
+         return "No audio file uploaded."
+
+     audio, sr = torchaudio.load(audio_path)
+     if sr != 16000:
+         return "The model is trained on 16 kHz audio; please resample your input to 16 kHz mono."
+
+     inputs = feature_extractor(audio, sampling_rate=sr, return_tensors="pt")
+
+     with torch.no_grad():
+         logits = model(**inputs).logits
+
+     predicted_class_ids = torch.argmax(logits, dim=-1).item()
+     predicted_label = model.config.id2label[predicted_class_ids]
+
+     return predicted_label
+
+
+ iface_audio_file = gr.Interface(
+     fn=process,
+     inputs=gr.Audio(sources="upload", type="filepath", streaming=False),
+     outputs="text",
+ )
+ gr.close_all()
+ iface_audio_file.launch()
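
The demo rejects anything that is not sampled at 16 kHz. A minimal sketch (not part of the commit) of preparing an arbitrary file before uploading it to the Space or passing its path to `process`; the file names are placeholders:

```python
# Sketch: convert an arbitrary audio file to 16 kHz mono so that process() accepts it.
# "input.wav" / "input_16k.wav" are placeholder paths, not files from this repo.
import torchaudio

waveform, sr = torchaudio.load("input.wav")      # (channels, num_samples)
waveform = waveform.mean(dim=0, keepdim=True)    # downmix to mono
if sr != 16000:
    waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
torchaudio.save("input_16k.wav", waveform, 16000)
```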
ced_model/__init__.py ADDED
File without changes
ced_model/configuration_ced.py ADDED
@@ -0,0 +1,139 @@
+ # coding=utf-8
+ # Copyright 2023 Xiaomi Corporation and The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """ CED model configuration"""
+
+
+ from transformers import PretrainedConfig
+ from transformers.utils import logging
+ from transformers.utils.hub import cached_file
+
+ logger = logging.get_logger(__name__)
+
+ CED_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+     "mispeech/ced-tiny": "https://huggingface.co/mispeech/ced-tiny/resolve/main/config.json",
+ }
+
+
+ class CedConfig(PretrainedConfig):
+     r"""
+     Configuration class for the CED model.
+
+     Args:
+         name (str, *optional*):
+             Name of a pre-defined configuration. Can be "ced-tiny", "ced-mini", "ced-small" or "ced-base".
+         attn_drop_rate (float, *optional*, defaults to 0.0):
+             Dropout probability for attention weights.
+         depth (int, *optional*, defaults to 12): Number of transformer layers.
+         drop_path_rate (float, *optional*, defaults to 0.0): Drop-path (stochastic depth) rate, taken from timm.
+         drop_rate (float, *optional*, defaults to 0.0):
+             Dropout probability for input embeddings.
+         embed_dim (int, *optional*, defaults to 768):
+             Dimensionality of the audio patch embeddings.
+         eval_avg (str, *optional*, defaults to `"mean"`):
+             Type of pooling to use for evaluation. Can be "mean", "token", "dm" or "logit".
+         mlp_ratio (float, *optional*, defaults to 4.0):
+             Ratio of the feedforward hidden size to the embedding size.
+         num_heads (int, *optional*, defaults to 12): Number of attention heads.
+         outputdim (int, *optional*, defaults to 527): Dimensionality of the output.
+         patch_size (int, *optional*, defaults to 16): Size of the patches.
+         patch_stride (int, *optional*, defaults to 16): Stride of the patches.
+         pooling (str, *optional*, defaults to `"mean"`):
+             Type of pooling to use for the output. Can be "mean", "token", "dm" or "logit".
+         qkv_bias (bool, *optional*, defaults to `True`):
+             Whether to include bias terms in the query, key and value projections.
+         target_length (int, *optional*, defaults to 1012): Number of frames in an audio chunk.
+     """
+
+     model_type = "ced"
+
+     def __init__(
+         self,
+         name=None,
+         attn_drop_rate=0.0,
+         depth=12,
+         drop_path_rate=0.0,
+         drop_rate=0.0,
+         embed_dim=768,
+         eval_avg="mean",
+         mlp_ratio=4.0,
+         num_heads=12,
+         outputdim=527,
+         patch_size=16,
+         patch_stride=16,
+         pooling="mean",
+         qkv_bias=True,
+         target_length=1012,
+         **kwargs,
+     ):
+         r"""
+         TODO: Add docstring
+         """
+
+         super().__init__(**kwargs)
+
+         if name == "ced-tiny":
+             embed_dim = 192
+             num_heads = 3
+         elif name == "ced-mini":
+             embed_dim = 256
+             num_heads = 4
+         elif name == "ced-small":
+             embed_dim = 384
+             num_heads = 6
+         elif name == "ced-base":
+             embed_dim = 768
+             num_heads = 12
+         else:
+             logger.info("No model name specified for CedConfig, using default settings.")
+
+         assert pooling in ("mean", "token", "dm", "logit")
+         self.name = name
+         self.attn_drop_rate = attn_drop_rate
+         self.center = kwargs.get("center", True)
+         self.depth = depth
+         self.drop_path_rate = drop_path_rate
+         self.drop_rate = drop_rate
+         self.embed_dim = embed_dim
+         self.eval_avg = eval_avg
+         self.f_max = kwargs.get("f_max", 8000)
+         self.f_min = kwargs.get("f_min", 0)
+         self.hop_size = kwargs.get("hop_size", 160)
+         self.mlp_ratio = mlp_ratio
+         self.n_fft = kwargs.get("n_fft", 512)
+         self.n_mels = kwargs.get("n_mels", 64)
+         self.num_heads = num_heads
+         self.outputdim = outputdim
+         self.pad_last = kwargs.get("pad_last", True)
+         self.patch_size = patch_size
+         self.patch_stride = patch_stride
+         self.pooling = pooling
+         self.qkv_bias = qkv_bias
+         self.target_length = target_length
+         self.win_size = kwargs.get("win_size", 512)
+
+         if self.outputdim == 527:
+             with open(
+                 cached_file("topel/ConvNeXt-Tiny-AT", "class_labels_indices.csv"), "r"
+             ) as f:
+                 self.id2label = {
+                     int(line.split(",", maxsplit=3)[0]): line.split(",", maxsplit=3)[2]
+                     .replace('"', "")
+                     .strip("\n")
+                     for line in f.readlines()[1:]
+                 }
+                 self.label2id = {v: k for k, v in self.id2label.items()}
+         else:
+             self.id2label = None
+             self.label2id = None
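
How the preset names interact with the explicit arguments can be read directly from `__init__` above: a preset overrides `embed_dim` and `num_heads`, everything else keeps its default. A small sketch (assumes the `ced_model` package is importable and, for `outputdim == 527`, network access to fetch the AudioSet label CSV):

```python
# Sketch: preset names override embed_dim / num_heads, as in CedConfig.__init__ above.
from ced_model.configuration_ced import CedConfig

config = CedConfig(name="ced-tiny")
print(config.embed_dim, config.num_heads)    # 192 3
print(config.n_mels, config.target_length)   # 64 1012
print(len(config.id2label))                  # 527 AudioSet classes (outputdim == 527)
```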
ced_model/feature_extraction_ced.py ADDED
@@ -0,0 +1,121 @@
+ # coding=utf-8
+ # Copyright 2023 Xiaomi Corporation and The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+ Feature extractor class for CED.
+ """
+
+ from typing import Optional, Union
+
+ import numpy as np
+ import torch
+ import torchaudio.transforms as audio_transforms
+
+ from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
+ from transformers.feature_extraction_utils import BatchFeature
+ from transformers.utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class CedFeatureExtractor(SequenceFeatureExtractor):
+     r"""
+     CedFeatureExtractor extracts Mel spectrogram features from audio signals.
+
+     Args:
+         f_min (int, *optional*, defaults to 0): Minimum frequency for the Mel filterbank.
+         sampling_rate (int, *optional*, defaults to 16000):
+             Sampling rate of the input audio signal.
+         win_size (int, *optional*, defaults to 512): Window size for the STFT.
+         center (bool, *optional*, defaults to `True`):
+             Whether to pad the signal on both sides to center it.
+         n_fft (int, *optional*, defaults to 512): Number of FFT points for the STFT.
+         f_max (int, *optional*): Maximum frequency for the Mel filterbank.
+         hop_size (int, *optional*, defaults to 160): Hop size for the STFT.
+         feature_size (int, *optional*, defaults to 64): Number of Mel bands to generate.
+         padding_value (float, *optional*, defaults to 0.0): Value for padding.
+
+     Returns:
+         BatchFeature: A BatchFeature object containing the extracted features.
+     """
+
+     def __init__(
+         self,
+         f_min: int = 0,
+         sampling_rate: int = 16000,
+         win_size: int = 512,
+         center: bool = True,
+         n_fft: int = 512,
+         f_max: Optional[int] = None,
+         hop_size: int = 160,
+         feature_size: int = 64,
+         padding_value: float = 0.0,
+         **kwargs,
+     ):
+         super().__init__(
+             feature_size=feature_size,
+             sampling_rate=sampling_rate,
+             padding_value=padding_value,
+             **kwargs,
+         )
+         self.f_min = f_min
+         self.win_size = win_size
+         self.center = center
+         self.n_fft = n_fft
+         self.f_max = f_max
+         self.hop_size = hop_size
+
+     def __call__(
+         self,
+         x: Union[np.ndarray, torch.Tensor],
+         sampling_rate: Optional[int] = None,
+         return_tensors="pt",
+     ) -> BatchFeature:
+         r"""
+         Extracts Mel spectrogram features from an audio signal tensor.
+
+         Args:
+             x: Input audio signal tensor.
+
+         Returns:
+             BatchFeature: A dictionary containing the extracted features.
+         """
+         if sampling_rate is None:
+             sampling_rate = self.sampling_rate
+
+         if return_tensors != "pt":
+             raise NotImplementedError(
+                 "Only return_tensors='pt' is currently supported."
+             )
+
+         mel_spectrogram = audio_transforms.MelSpectrogram(
+             f_min=self.f_min,
+             sample_rate=sampling_rate,
+             win_length=self.win_size,
+             center=self.center,
+             n_fft=self.n_fft,
+             f_max=self.f_max,
+             hop_length=self.hop_size,
+             n_mels=self.feature_size,
+         )
+         amplitude_to_db = audio_transforms.AmplitudeToDB(top_db=120)
+
+         x = torch.from_numpy(x).float() if isinstance(x, np.ndarray) else x.float()
+         if x.dim() == 1:
+             x = x.unsqueeze(0)
+
+         x = mel_spectrogram(x)
+         x = amplitude_to_db(x)
+         return BatchFeature({"input_values": x})
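
With the defaults above (64 mel bands, 512-point FFT, hop of 160 samples, `center=True`), one second of 16 kHz audio maps to 101 frames. A short sketch of the expected output shape (illustrative, using random audio):

```python
# Sketch: shape of the features produced by CedFeatureExtractor with its defaults.
import torch
from ced_model.feature_extraction_ced import CedFeatureExtractor

extractor = CedFeatureExtractor()                       # 64 mels, hop 160, 16 kHz
waveform = torch.randn(16000)                           # 1 s of fake mono audio
features = extractor(waveform, sampling_rate=16000, return_tensors="pt")
print(features["input_values"].shape)                   # torch.Size([1, 64, 101])
```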
ced_model/modeling_ced.py ADDED
@@ -0,0 +1,575 @@
+ # coding=utf-8
+ # Copyright 2023 Xiaomi Corporation and The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """ PyTorch CED model."""
+
+ import collections
+ import math
+ from functools import partial
+ from typing import Any, Callable, Optional, Tuple, Union
+
+ import torch
+ import torch.utils.checkpoint
+ from torch import nn
+
+ from transformers.modeling_outputs import SequenceClassifierOutput
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.utils import (
+     add_code_sample_docstrings,
+     add_start_docstrings,
+     add_start_docstrings_to_model_forward,
+     logging,
+ )
+ from .configuration_ced import CedConfig
+
+
+ logger = logging.get_logger(__name__)
+
+ _CONFIG_FOR_DOC = "CedConfig"
+ _SEQ_CLASS_EXPECTED_OUTPUT = "'Speech synthesizer'"
+ _SEQ_CLASS_EXPECTED_LOSS = 0.69
+
+ # Audio classification docstring
+ _SEQ_CLASS_CHECKPOINT = "mispeech/ced-tiny"
+
+
+ CED_PRETRAINED_MODEL_ARCHIVE_LIST = [
+     "mispeech/ced-tiny",
+     "mispeech/ced-mini",
+     "mispeech/ced-small",
+     "mispeech/ced-base",
+     # See all CED models at https://huggingface.co/models?search=mispeech%2Fced
+ ]
+
+
+ class CedPreTrainedModel(PreTrainedModel):
+     """
+     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+     models.
+     """
+
+     config_class = CedConfig
+     base_model_prefix = "ced"
+     main_input_name = "input_values"
+     supports_gradient_checkpointing = True
+
+     def _init_weights(self, module):
+         """Initialize the weights"""
+         if isinstance(module, nn.Linear):
+             trunc_normal_(module.weight, std=0.02)
+             if module.bias is not None:
+                 nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.LayerNorm):
+             nn.init.constant_(module.bias, 0)
+             nn.init.constant_(module.weight, 1.0)
+
+
+ Conv_Kernel = Union[int, Tuple[int, int]]
+
+
+ def to_2tuple(x: Any) -> Tuple[Any, Any]:
+     if isinstance(x, collections.abc.Iterable):
+         return x
+     return (x, x)
+
+
+ class CedAudioPatchEmbed(nn.Module):
+     def __init__(
+         self,
+         input_size: Conv_Kernel = 224,
+         patch_size: Conv_Kernel = 16,
+         patch_stride: Conv_Kernel = 16,
+         in_chans: int = 1,
+         embed_dim: int = 768,
+         norm_layer: Optional[Callable] = None,
+         flatten: bool = False,
+     ):
+         super().__init__()
+         self.input_size = to_2tuple(input_size)
+         self.patch_size = to_2tuple(patch_size)
+         self.patch_stride = to_2tuple(patch_stride)
+         self.grid_size = (
+             self.input_size[0] // self.patch_stride[0],
+             self.input_size[1] // self.patch_stride[1],
+         )
+         self.num_patches = self.grid_size[0] * self.grid_size[1]
+         self.flatten = flatten
+
+         self.proj = nn.Conv2d(
+             in_chans, embed_dim, kernel_size=patch_size, stride=patch_stride
+         )
+         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+     def forward(self, x):
+         x = self.proj(x)
+         if self.flatten:
+             x = torch.permute(torch.flatten(x, 2, 3), (0, 2, 1))
+         x = self.norm(x)
+         return x
+
+
+ class CedAttention(nn.Module):
+     def __init__(
+         self,
+         dim,
+         num_heads=8,
+         qkv_bias=False,
+         attn_drop=0.0,
+         proj_drop=0.0,
+         causal: bool = False,
+     ):
+         super().__init__()
+         assert dim % num_heads == 0, "dim should be divisible by num_heads"
+         self.num_heads = num_heads
+         head_dim = dim // num_heads
+         self.scale = head_dim**-0.5
+
+         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+         self.attn_drop = nn.Dropout(attn_drop)
+         self.proj = nn.Linear(dim, dim)
+         self.proj_drop = nn.Dropout(proj_drop)
+         self.causal = causal
+
+     def forward(self, x):
+         B, N, C = x.shape
+         qkv = (
+             self.qkv(x)
+             .reshape(B, N, 3, self.num_heads, C // self.num_heads)
+             .permute(2, 0, 3, 1, 4)
+         )
+         q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
+
+         attn = (q @ k.transpose(-2, -1)) * self.scale
+         # if mask is not None:
+         #     # Mask is a tensor of shape [B, T, T]
+         #     # Different from self.causal == True, the mask might be something like:
+         #     # [False, False, True]
+         #     # [False, False, True]
+         #     # [True, True, True]
+         #     # We use -inf to pad here, since if we would pad by any number, the entries at rows only containing
+         #     # [True, True, True] would lead to weights such as: [0.33,0.33,0.33], which is not correct
+         #     mask_value = torch.as_tensor(-float('inf'))
+         #     print(mask.shape, attn.shape)
+         #     attn = attn.masked_fill(mask, mask_value)
+         if self.causal:
+             mask_value = -torch.finfo(attn.dtype).max
+             i, j = attn.shape[-2:]
+             mask = torch.ones(i, j, device=q.device, dtype=torch.bool).triu(j - i + 1)
+             attn = attn.masked_fill(mask, mask_value)
+         attn = attn.softmax(dim=-1)
+         # Only for the case that a mask with all True entries on a row is passed.
+         # attn = torch.nan_to_num(attn)
+         attn = self.attn_drop(attn)
+
+         x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
+
+
+ class CedMlp(nn.Module):
+     def __init__(
+         self,
+         in_features: int,
+         hidden_features: Optional[int] = None,
+         out_features: Optional[int] = None,
+         act_layer: Callable = nn.GELU,
+         drop: float = 0.0,
+     ):
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.fc1 = nn.Linear(in_features, hidden_features)
+         self.act = act_layer()
+         self.fc2 = nn.Linear(hidden_features, out_features)
+         self.drop = nn.Dropout(drop)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.drop(x)
+         x = self.fc2(x)
+         x = self.drop(x)
+         return x
+
+
+ # Drop path is taken from Timm
+ # https://github.com/huggingface/pytorch-image-models/blob/7c67d6aca992f039eece0af5f7c29a43d48c00e4/timm/models/layers/drop.py#L155
+ class DropPath(nn.Module):
+     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+     def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
+         super(DropPath, self).__init__()
+         self.drop_prob = drop_prob
+         self.scale_by_keep = scale_by_keep
+
+     def forward(self, x):
+         return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+     def extra_repr(self):
+         return f"drop_prob={round(self.drop_prob,3):0.3f}"
+
+
+ def drop_path(
+     x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
+ ):
+     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+     This is the same as the DropConnect impl I (https://github.com/rwightman) created for EfficientNet, etc networks,
+     however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+     See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
+     layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
+     argument.
+
+     """
+     if drop_prob == 0.0 or not training:
+         return x
+     keep_prob = 1 - drop_prob
+     shape = (x.shape[0],) + (1,) * (
+         x.ndim - 1
+     )  # work with diff dim tensors, not just 2D ConvNets
+     random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+     if keep_prob > 0.0 and scale_by_keep:
+         random_tensor.div_(keep_prob)
+     return x * random_tensor
+
+
+ class CedBlock(nn.Module):
+     def __init__(
+         self,
+         dim,
+         num_heads,
+         mlp_ratio=4.0,
+         qkv_bias=False,
+         drop=0.0,
+         attn_drop=0.0,
+         drop_path=0.0,
+         act_layer: Callable = nn.GELU,
+         norm_layer: Callable = nn.LayerNorm,
+         attention_type: Callable = CedAttention,
+         attention_kwargs={},
+         **kwargs,
+     ):
+         super().__init__()
+         self.norm1 = norm_layer(dim)
+         self.attn = attention_type(
+             dim,
+             num_heads=num_heads,
+             qkv_bias=qkv_bias,
+             attn_drop=attn_drop,
+             proj_drop=drop,
+             **attention_kwargs,
+         )
+         self.ls1 = nn.Identity()
+         self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+         self.norm2 = norm_layer(dim)
+         self.mlp = CedMlp(
+             in_features=dim,
+             hidden_features=int(dim * mlp_ratio),
+             act_layer=act_layer,
+             drop=drop,
+         )
+         self.ls2 = nn.Identity()
+         self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+     def forward(self, x):
+         x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
+         x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
+         return x
+
+
+ # Taken from timm
+ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
+     return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+     # Cut & paste from PyTorch official master until it's in a few official releases - RW
+     # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+     def norm_cdf(x):
+         # Computes standard normal cumulative distribution function
+         return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
+
+     with torch.no_grad():
+         # Values are generated by using a truncated uniform distribution and
+         # then using the inverse CDF for the normal distribution.
+         # Get upper and lower cdf values
+         l = norm_cdf((a - mean) / std)
+         u = norm_cdf((b - mean) / std)
+
+         # Uniformly fill tensor with values from [l, u], then translate to
+         # [2l-1, 2u-1].
+         tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+         # Use inverse cdf transform for normal distribution to get truncated
+         # standard normal
+         tensor.erfinv_()
+
+         # Transform to proper mean, std
+         tensor.mul_(std * math.sqrt(2.0))
+         tensor.add_(mean)
+
+         # Clamp to ensure it's in the proper range
+         tensor.clamp_(min=a, max=b)
+         return tensor
+
+
+ CED_START_DOCSTRING = r"""
+
+     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+     library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+     etc.)
+
+     This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+     Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+     and behavior.
+
+     Parameters:
+         config ([`CedConfig`]): Model configuration class with all the parameters of the model.
+             Initializing with a config file does not load the weights associated with the model, only the
+             configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+ """
+
+ CED_INPUTS_DOCSTRING = r"""
+     Args:
+         input_values (`torch.FloatTensor` of shape `(batch_size, n_mels, sequence_length)`):
+             The sequence of audio features extracted from the audio signal. Can be obtained from a raw audio waveform
+             using `~transformers.CedFeatureExtractor.__call__`.
+ """
+
+
+ @add_start_docstrings(
+     "The bare Ced Model transformer outputting raw hidden-states without any specific head on top.",
+     CED_START_DOCSTRING,
+ )
+ class CedModel(CedPreTrainedModel):
+     def __init__(self, config: CedConfig) -> None:
+         super().__init__(config)
+         self.config = config
+         self.name = config.name
+
+         # Allowed length in number of frames, otherwise the positional embedding will throw an error
+         self.maximal_allowed_length = self.config.target_length
+
+         self.init_bn = torch.nn.BatchNorm2d(config.n_mels, momentum=0.01)
+
+         self.patch_embed = CedAudioPatchEmbed(
+             input_size=(config.n_mels, config.target_length),
+             embed_dim=config.embed_dim,
+             patch_size=config.patch_size,
+             flatten=False,
+             patch_stride=config.patch_stride,
+         )
+
+         self.time_pos_embed = nn.Parameter(
+             torch.randn(1, config.embed_dim, 1, self.patch_embed.grid_size[1]) * 0.02
+         )
+         self.freq_pos_embed = nn.Parameter(
+             torch.randn(1, config.embed_dim, self.patch_embed.grid_size[0], 1) * 0.02
+         )
+         norm_layer = partial(nn.LayerNorm, eps=1e-6)
+         act_layer = nn.GELU
+         dpr = [
+             x.item() for x in torch.linspace(0, config.drop_path_rate, config.depth)
+         ]  # stochastic depth decay rule
+         self.pos_drop = nn.Dropout(p=config.drop_rate)
+         self.blocks = nn.Sequential(
+             *[
+                 CedBlock(
+                     dim=config.embed_dim,
+                     num_heads=config.num_heads,
+                     mlp_ratio=config.mlp_ratio,
+                     qkv_bias=config.qkv_bias,
+                     drop=config.drop_rate,
+                     attn_drop=config.attn_drop_rate,
+                     drop_path=dpr[i],
+                     norm_layer=norm_layer,
+                     act_layer=act_layer,
+                     attention_type=CedAttention,
+                 )
+                 for i in range(config.depth)
+             ]
+         )
+         self.norm = norm_layer(config.embed_dim)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.patch_embed(x)
+         _, _, _, t = x.shape
+         x = x + self.time_pos_embed[:, :, :, :t]
+         x = (
+             x + self.freq_pos_embed[:, :, :, :]
+         )  # Just to support __getitem__ in posembed
+
+         # x = rearrange(x, 'b c f t -> b (f t) c')
+         x = torch.permute(torch.flatten(x, 2, 3), (0, 2, 1))
+
+         if self.config.pooling == "token":
+             cls_token = self.cls_token.expand(x.shape[0], -1, -1)
+             cls_token = cls_token + self.token_pos_embed
+             x = torch.cat((cls_token, x), dim=1)
+         x = self.pos_drop(x)
+         x = self.blocks(x)
+         x = self.norm(x)
+         return x
+
+     def forward(self, input_values: torch.Tensor):
+         r"""
+         Runs a forward pass of the CED model as an audio encoder.
+         """
+         x = torch.unsqueeze(input_values, 1)
+
+         x = torch.permute(x, (0, 2, 1, 3))
+         x = self.init_bn(x)
+         x = torch.permute(x, (0, 2, 1, 3))
+
+         if x.shape[-1] > self.maximal_allowed_length:
+             splits = x.split(self.maximal_allowed_length, -1)
+
+             if splits[-1].shape[-1] < self.maximal_allowed_length:
+                 if self.config.pad_last:
+                     pad = torch.zeros(
+                         *x.shape[:-1], self.maximal_allowed_length, device=x.device
+                     )
+                     pad[..., : splits[-1].shape[-1]] = splits[-1]
+                     splits = torch.stack((*splits[:-1], pad), dim=0)
+                 else:
+                     splits = torch.stack(splits[:-1], dim=0)
+             else:
+                 splits = torch.stack(splits, dim=0)
+             n_splits = len(splits)
+             x = torch.flatten(splits, 0, 1)  # spl b c f t -> (spl b) c f t
+             x = self.forward_features(x)
+             x = torch.reshape(
+                 x, (n_splits, -1, x.shape[-2], x.shape[-1])
+             )  # (spl b) t d -> spl b t d, spl=n_splits
+
+             if self.config.eval_avg == "mean":
+                 x = x.mean(0)
+             elif self.config.eval_avg == "max":
+                 x = x.max(0)[0]
+             else:
+                 raise ValueError(f"Unknown eval average function ({self.config.eval_avg})")
+         else:
+             x = self.forward_features(x)
+
+         return SequenceClassifierOutput(logits=x)
+
+
+ @add_start_docstrings(
+     """
+     Ced model with an audio classification head on top (a linear layer on top of the pooled output).
+     """,
+     CED_START_DOCSTRING,
+ )
+ class CedForAudioClassification(CedPreTrainedModel):
+     def __init__(self, config: CedConfig) -> None:
+         super().__init__(config)
+         self.config = config
+
+         self.encoder = CedModel(config)
+
+         # Classifier head
+         self.outputlayer = nn.Sequential(
+             nn.LayerNorm(config.embed_dim),
+             nn.Linear(config.embed_dim, config.outputdim),
+         )
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward_head(self, x: torch.Tensor) -> torch.Tensor:
+         if self.config.pooling == "token":
+             x = x[:, 0]
+             return self.outputlayer(x).sigmoid()
+         elif self.config.pooling == "mean":
+             x = x.mean(1)
+             return self.outputlayer(x).sigmoid()
+         elif self.config.pooling == "logit":
+             x = x.mean(1)
+             return self.outputlayer(x)
+         elif self.config.pooling == "dm":
+             # Unpack using the frequency dimension, which is constant
+             # 'b (f t) d -> b f t d', f=self.encoder.patch_embed.grid_size[0]
+             x = torch.reshape(
+                 x, (x.shape[0], self.encoder.patch_embed.grid_size[0], -1, x.shape[-1])
+             )
+
+             # First pool over frequency, then sigmoid the (B, T, D) output
+             x = self.outputlayer(x.mean(1)).sigmoid()
+             return x.mean(1)
+         else:
+             return x.mean(1)
+
+     @add_start_docstrings_to_model_forward(
+         CED_INPUTS_DOCSTRING.format("batch_size, sequence_length")
+     )
+     @add_code_sample_docstrings(
+         checkpoint=_SEQ_CLASS_CHECKPOINT,
+         output_type=SequenceClassifierOutput,
+         config_class=_CONFIG_FOR_DOC,
+         modality="audio",
+         model_cls="CedForAudioClassification",
+         expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
+         expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
+     )
+     def forward(
+         self, input_values: torch.Tensor, labels: Optional[torch.Tensor] = None
+     ):
+         """
+         Runs a forward pass of the CED model for the audio classification task.
+
+         Examples:
+
+         ```python
+         >>> from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
+         >>> from datasets import load_dataset
+         >>> import torch
+
+         >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
+         >>> dataset = dataset.sort("id")
+         >>> sampling_rate = dataset.features["audio"].sampling_rate
+
+         >>> feature_extractor = AutoFeatureExtractor.from_pretrained("mispeech/ced-tiny")
+         >>> model = AutoModelForAudioClassification.from_pretrained("mispeech/ced-tiny")
+
+         >>> inputs = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
+
+         >>> with torch.no_grad():
+         ...     logits = model(**inputs).logits
+
+         >>> predicted_class_ids = torch.argmax(logits, dim=-1).item()
+         >>> predicted_label = model.config.id2label[predicted_class_ids]
+         >>> predicted_label
+         'Speech synthesizer'
+         ```
+         """
+         last_hidden_states = self.encoder(input_values).logits
+         logits = self.forward_head(last_hidden_states)
+
+         if labels is not None:
+             loss_fct = nn.BCEWithLogitsLoss()
+             labels = nn.functional.one_hot(
+                 labels, num_classes=self.config.outputdim
+             ).float()
+             loss = loss_fct(logits, labels)
+         else:
+             loss = None
+
+         return SequenceClassifierOutput(
+             logits=logits, loss=loss, hidden_states=last_hidden_states
+         )
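
For orientation: with the ced-base defaults (`n_mels=64`, `target_length=1012`, `patch_size=patch_stride=16`), `CedAudioPatchEmbed` turns each mel-spectrogram chunk of at most 1012 frames (about 10.12 s at the 160-sample hop) into a 4 by 63 grid of patch tokens; longer inputs are split into such chunks and averaged according to `eval_avg`. A tiny sketch of that arithmetic, mirroring the `grid_size` computation above:

```python
# Sketch: patch-grid arithmetic of CedAudioPatchEmbed for the ced-base defaults.
n_mels, target_length = 64, 1012        # CedConfig defaults
patch_stride = (16, 16)

grid_size = (n_mels // patch_stride[0], target_length // patch_stride[1])
num_patches = grid_size[0] * grid_size[1]

print(grid_size)     # (4, 63): 4 patches along frequency, 63 along time
print(num_patches)   # 252 transformer tokens per chunk (each of dim embed_dim=768)
```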
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ torch==2.1.1
+ torchaudio==2.1.1
+ transformers==4.35.2