File size: 5,542 Bytes
5e9bb10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# coding=utf-8
# Copyright 2023 Xiaomi Corporation and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" CED model configuration"""


import csv

from transformers import PretrainedConfig
from transformers.utils import logging
from transformers.utils.hub import cached_file

# Module-level logger, following the transformers convention of one logger per module.
logger = logging.get_logger(__name__)

# Map from model identifier on the Hugging Face Hub to its hosted config.json.
CED_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "mispeech/ced-tiny": "https://huggingface.co/mispeech/ced-tiny/resolve/main/config.json",
}


class CedConfig(PretrainedConfig):
    r"""
    Configuration class for the CED model.

    Args:
        name (`str`, *optional*):
            Name of a pre-defined configuration. Can be "ced-tiny", "ced-mini", "ced-small" or
            "ced-base". When given, it overrides `embed_dim` and `num_heads` with the preset values.
        attn_drop_rate (`float`, *optional*, defaults to 0.0):
            Dropout probability for attention weights.
        depth (`int`, *optional*, defaults to 12): Number of transformer layers.
        drop_path_rate (`float`, *optional*, defaults to 0.0): Drop path rate (as in timm).
        drop_rate (`float`, *optional*, defaults to 0.0):
            Dropout probability for input embeddings.
        embed_dim (`int`, *optional*, defaults to 768):
            Dimensionality of the audio patch embeddings.
        eval_avg (`str`, *optional*, defaults to `"mean"`):
            Type of pooling to use for evaluation. Can be "mean", "token", "dm" or "logit".
        mlp_ratio (`float`, *optional*, defaults to 4.0):
            Ratio of hidden size in the feedforward layer to the embedding size.
        num_heads (`int`, *optional*, defaults to 12): Number of attention heads.
        outputdim (`int`, *optional*, defaults to 527): Dimensionality of the output.
        patch_size (`int`, *optional*, defaults to 16): Size of the patches.
        patch_stride (`int`, *optional*, defaults to 16): Stride of the patches.
        pooling (`str`, *optional*, defaults to `"mean"`):
            Type of pooling to use for the output. Can be "mean", "token", "dm" or "logit".
        qkv_bias (`bool`, *optional*, defaults to `True`):
            Whether to include bias terms in the query, key and value projections.
        target_length (`int`, *optional*, defaults to 1012): Frames of an audio chunk.
    """

    model_type = "ced"

    # Preset (embed_dim, num_heads) pairs selectable via the `name` argument.
    _NAMED_PRESETS = {
        "ced-tiny": (192, 3),
        "ced-mini": (256, 4),
        "ced-small": (384, 6),
        "ced-base": (768, 12),
    }

    def __init__(
        self,
        name=None,
        attn_drop_rate=0.0,
        depth=12,
        drop_path_rate=0.0,
        drop_rate=0.0,
        embed_dim=768,
        eval_avg="mean",
        mlp_ratio=4.0,
        num_heads=12,
        outputdim=527,
        patch_size=16,
        patch_stride=16,
        pooling="mean",
        qkv_bias=True,
        target_length=1012,
        **kwargs,
    ):
        r"""
        Build a `CedConfig`. See the class docstring for the argument descriptions.
        Extra feature-extraction options (`center`, `f_max`, `f_min`, `hop_size`,
        `n_fft`, `n_mels`, `pad_last`, `win_size`) may be supplied via `**kwargs`.

        Raises:
            ValueError: If `pooling` is not one of "mean", "token", "dm" or "logit".
        """
        super().__init__(**kwargs)

        # A known preset name overrides embed_dim / num_heads.
        if name in self._NAMED_PRESETS:
            embed_dim, num_heads = self._NAMED_PRESETS[name]
        else:
            logger.info("No model name specified for CedConfig, use default settings.")

        # Validate with an explicit exception: `assert` is stripped under `python -O`.
        if pooling not in ("mean", "token", "dm", "logit"):
            raise ValueError(
                f"`pooling` must be one of 'mean', 'token', 'dm' or 'logit', got {pooling!r}."
            )

        self.name = name
        self.attn_drop_rate = attn_drop_rate
        self.center = kwargs.get("center", True)
        self.depth = depth
        self.drop_path_rate = drop_path_rate
        self.drop_rate = drop_rate
        self.embed_dim = embed_dim
        self.eval_avg = eval_avg
        self.f_max = kwargs.get("f_max", 8000)
        self.f_min = kwargs.get("f_min", 0)
        self.hop_size = kwargs.get("hop_size", 160)
        self.mlp_ratio = mlp_ratio
        self.n_fft = kwargs.get("n_fft", 512)
        self.n_mels = kwargs.get("n_mels", 64)
        self.num_heads = num_heads
        self.outputdim = outputdim
        self.pad_last = kwargs.get("pad_last", True)
        self.patch_size = patch_size
        self.patch_stride = patch_stride
        self.pooling = pooling
        self.qkv_bias = qkv_bias
        self.target_length = target_length
        self.win_size = kwargs.get("win_size", 512)

        # 527 outputs correspond to the AudioSet label set; fetch the label names from
        # a public checkpoint that ships the canonical class_labels_indices.csv.
        # NOTE(review): this downloads a file at config-construction time — confirm intended.
        if self.outputdim == 527:
            with open(
                cached_file("topel/ConvNeXt-Tiny-AT", "class_labels_indices.csv"),
                "r",
                encoding="utf-8",
            ) as f:
                reader = csv.reader(f)
                next(reader)  # skip the header row
                # Columns: index, mid, display_name. Use the csv module so quoted
                # display names containing commas (e.g. "Speech, male") parse correctly;
                # a naive str.split(",") truncates them.
                self.id2label = {int(row[0]): row[2] for row in reader}
            self.label2id = {v: k for k, v in self.id2label.items()}
        else:
            self.id2label = None
            self.label2id = None