binwang committed on
Commit
b0f7c85
1 Parent(s): ed57245

Upload 76 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. __init__.py +0 -0
  3. __pycache__/__init__.cpython-310.pyc +0 -0
  4. __pycache__/model.cpython-310.pyc +0 -0
  5. __pycache__/model.cpython-39.pyc +0 -0
  6. beats/BEATs.py +180 -0
  7. beats/LICENSE_beats +21 -0
  8. beats/Tokenizers.py +172 -0
  9. beats/__init__.py +0 -0
  10. beats/__pycache__/BEATs.cpython-310.pyc +0 -0
  11. beats/__pycache__/BEATs.cpython-39.pyc +0 -0
  12. beats/__pycache__/__init__.cpython-310.pyc +0 -0
  13. beats/__pycache__/__init__.cpython-39.pyc +0 -0
  14. beats/__pycache__/backbone.cpython-310.pyc +0 -0
  15. beats/__pycache__/backbone.cpython-39.pyc +0 -0
  16. beats/__pycache__/modules.cpython-310.pyc +0 -0
  17. beats/__pycache__/modules.cpython-39.pyc +0 -0
  18. beats/backbone.py +783 -0
  19. beats/modules.py +218 -0
  20. beats/quantizer.py +215 -0
  21. beats_path/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt +3 -0
  22. ckpt_path/salmonn_7b_v0.pth +3 -0
  23. ckpt_path/salmonn_v1.pth +3 -0
  24. model.py +262 -0
  25. qformer/LICENSE_Lavis +14 -0
  26. qformer/LICENSE_MiniGPT4 +14 -0
  27. qformer/LICENSE_VideoLlama +28 -0
  28. qformer/Qformer.py +1217 -0
  29. qformer/__pycache__/Qformer.cpython-310.pyc +0 -0
  30. qformer/__pycache__/Qformer.cpython-39.pyc +0 -0
  31. requirements.txt +10 -0
  32. resource/audio_demo/duck.wav +0 -0
  33. resource/audio_demo/excitement.wav +0 -0
  34. resource/audio_demo/gunshots.wav +0 -0
  35. resource/audio_demo/mountain.wav +0 -0
  36. resource/audio_demo/music.wav +0 -0
  37. resource/response_demo/aac.png +0 -0
  38. resource/response_demo/aed.png +0 -0
  39. resource/response_demo/asr.png +0 -0
  40. resource/response_demo/emo.png +0 -0
  41. resource/response_demo/jsac.png +0 -0
  42. resource/response_demo/lyrics.png +0 -0
  43. resource/response_demo/mc.png +0 -0
  44. resource/response_demo/memo.png +0 -0
  45. resource/response_demo/pr.png +0 -0
  46. resource/response_demo/sac.png +0 -0
  47. resource/response_demo/sq.png +0 -0
  48. resource/response_demo/sr.png +0 -0
  49. resource/response_demo/story.png +0 -0
  50. resource/response_demo/title.png +0 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ resource/salmon.png filter=lfs diff=lfs merge=lfs -text
__init__.py ADDED
File without changes
__pycache__/__init__.cpython-310.pyc ADDED
Binary file (171 Bytes). View file
 
__pycache__/model.cpython-310.pyc ADDED
Binary file (6 kB). View file
 
__pycache__/model.cpython-39.pyc ADDED
Binary file (5.88 kB). View file
 
beats/BEATs.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beats
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch.nn import LayerNorm
14
+ import torchaudio.compliance.kaldi as ta_kaldi
15
+
16
+ from beats.backbone import (
17
+ TransformerEncoder,
18
+ )
19
+
20
+ import logging
21
+ from typing import Optional
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
class BEATsConfig:
    """Plain attribute container holding the hyper-parameters of a BEATs model.

    All options are set to their defaults in ``__init__``; when ``cfg`` is
    given, its entries are then copied over those defaults via ``update``.
    """

    def __init__(self, cfg=None):
        """Populate default options, then apply the overrides in ``cfg``.

        Args:
            cfg: optional mapping of option name -> value. ``None`` keeps
                every default.
        """
        # patch embedding
        self.input_patch_size: int = -1  # patch size of patch embedding
        self.embed_dim: int = 512  # patch embedding dimension
        self.conv_bias: bool = False  # include bias in conv encoder

        # transformer encoder
        self.encoder_layers: int = 12  # num encoder layers in the transformer
        self.encoder_embed_dim: int = 768  # encoder embedding dimension
        self.encoder_ffn_embed_dim: int = 3072  # encoder embedding dimension for FFN
        self.encoder_attention_heads: int = 12  # num encoder attention heads
        self.activation_fn: str = "gelu"  # activation function to use

        self.layer_wise_gradient_decay_ratio: float = 1.0  # ratio for layer-wise gradient decay
        self.layer_norm_first: bool = False  # apply layernorm first in the transformer
        self.deep_norm: bool = False  # apply deep_norm first in the transformer

        # dropouts
        self.dropout: float = 0.1  # dropout probability for the transformer
        self.attention_dropout: float = 0.1  # dropout probability for attention weights
        self.activation_dropout: float = 0.0  # dropout probability after activation in FFN
        self.encoder_layerdrop: float = 0.0  # probability of dropping a transformer layer
        self.dropout_input: float = 0.0  # dropout to apply to the input (after feat extr)

        # positional embeddings
        self.conv_pos: int = 128  # number of filters for convolutional positional embeddings
        self.conv_pos_groups: int = 16  # number of groups for convolutional positional embedding

        # relative position embedding
        self.relative_position_embedding: bool = False  # apply relative position embedding
        self.num_buckets: int = 320  # number of buckets for relative position embedding
        self.max_distance: int = 1280  # maximum distance for relative position embedding
        self.gru_rel_pos: bool = False  # apply gated relative position embedding

        # label predictor
        self.finetuned_model: bool = False  # whether the model is a fine-tuned model.
        self.predictor_dropout: float = 0.1  # dropout probability for the predictor
        self.predictor_class: int = 527  # target class number for the predictor

        if cfg is not None:
            self.update(cfg)

    def update(self, cfg: dict):
        """Copy every key/value of ``cfg`` onto this object as an attribute.

        Keys are not validated: unknown keys simply become new attributes.
        """
        self.__dict__.update(cfg)
69
+
70
+
71
class BEATs(nn.Module):
    """BEATs audio model: fbank patch embedding, Transformer encoder and an
    optional label predictor head (built only when ``cfg.finetuned_model``).
    """

    def __init__(
            self,
            cfg: "BEATsConfig",
    ) -> None:
        """Build all sub-modules from ``cfg``.

        Args:
            cfg: configuration object; the attributes read here are
                ``embed_dim``, ``encoder_embed_dim``, ``input_patch_size``,
                ``conv_bias``, ``dropout_input``, ``deep_norm``,
                ``layer_norm_first``, ``finetuned_model``,
                ``predictor_dropout`` and ``predictor_class``. The whole
                object is also handed to ``TransformerEncoder``.
        """
        super().__init__()
        # Lazy %-formatting: the dict is only rendered if INFO is enabled.
        logger.info("BEATs Config: %s", cfg.__dict__)

        self.cfg = cfg

        self.embed = cfg.embed_dim
        # Project patch embeddings to the encoder width only when they differ.
        self.post_extract_proj = (
            nn.Linear(self.embed, cfg.encoder_embed_dim)
            if self.embed != cfg.encoder_embed_dim
            else None
        )

        self.input_patch_size = cfg.input_patch_size
        # Non-overlapping patches: stride equals the kernel size.
        self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size,
                                         stride=self.input_patch_size, bias=cfg.conv_bias)

        self.dropout_input = nn.Dropout(cfg.dropout_input)

        # deep_norm and layer_norm_first are mutually exclusive.
        assert not cfg.deep_norm or not cfg.layer_norm_first
        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.embed)

        if cfg.finetuned_model:
            self.predictor_dropout = nn.Dropout(cfg.predictor_dropout)
            self.predictor = nn.Linear(cfg.encoder_embed_dim, cfg.predictor_class)
        else:
            # NOTE: predictor_dropout is intentionally not created here; it is
            # only used behind a ``self.predictor is not None`` check.
            self.predictor = None

    def forward_padding_mask(
            self,
            features: torch.Tensor,
            padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Shrink ``padding_mask`` along dim 1 to ``features.size(1)`` steps.

        Trailing mask entries that do not divide evenly into
        ``features.size(1)`` groups are dropped. The remaining entries are
        grouped consecutively, and a step is marked padded only when every
        entry in its group is padded.

        Args:
            features: tensor whose ``size(1)`` gives the target length.
            padding_mask: boolean mask with the batch in dim 0.

        Returns:
            Boolean mask of shape ``(padding_mask.size(0), features.size(1))``.
        """
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(
            padding_mask.size(0), features.size(1), -1
        )
        padding_mask = padding_mask.all(-1)
        return padding_mask

    def preprocess(
            self,
            source: torch.Tensor,
            fbank_mean: float = 15.41663,
            fbank_std: float = 6.55582,
    ) -> torch.Tensor:
        """Turn a batch of waveforms into normalised Kaldi fbank features.

        Args:
            source: batch of waveforms, iterated along dim 0. The per-item
                fbanks are stacked, so all items must yield the same shape.
            fbank_mean: mean subtracted from the fbank values.
            fbank_std: the features are divided by ``2 * fbank_std``.

        Returns:
            The stacked, normalised fbank tensor (batch in dim 0).
        """
        fbanks = []
        for waveform in source:
            # Presumably rescales [-1, 1] float audio to the 16-bit integer
            # range expected by the Kaldi-compatible fbank — TODO confirm.
            waveform = waveform.unsqueeze(0) * 2 ** 15
            fbank = ta_kaldi.fbank(waveform, num_mel_bins=128, sample_frequency=16000,
                                   frame_length=25, frame_shift=10)
            fbanks.append(fbank)
        fbank = torch.stack(fbanks, dim=0)
        fbank = (fbank - fbank_mean) / (2 * fbank_std)
        return fbank

    def extract_features(
            self,
            source: torch.Tensor,
            padding_mask: Optional[torch.Tensor] = None,
            fbank_mean: float = 15.41663,
            fbank_std: float = 6.55582,
            feature_only=False,
    ):
        """Encode waveforms and, when available, predict label probabilities.

        Args:
            source: batch of waveforms (see ``preprocess``).
            padding_mask: optional boolean mask over ``source``; it is
                down-sampled twice, first to the fbank length and then to
                the patch-sequence length.
            fbank_mean: forwarded to ``preprocess``.
            fbank_std: forwarded to ``preprocess``.
            feature_only: when true, skip the predictor head even if the
                model has one.

        Returns:
            ``(lprobs, padding_mask)`` when the predictor head exists and
            ``feature_only`` is false, where ``lprobs`` are sigmoid outputs
            averaged over non-padded steps; otherwise ``(x, padding_mask)``
            with ``x`` the encoder output.
        """
        fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std).to(torch.float32)

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(fbank, padding_mask)

        # Add a singleton channel dim for Conv2d, then flatten the patch grid
        # into a sequence: (B, C, H, W) -> (B, C, H*W) -> (B, H*W, C).
        fbank = fbank.unsqueeze(1)
        features = self.patch_embedding(fbank)
        features = features.reshape(features.shape[0], features.shape[1], -1)
        features = features.transpose(1, 2)
        features = self.layer_norm(features)

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(features, padding_mask)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        x = self.dropout_input(features)

        # The per-layer results returned by the encoder are not used here.
        x, _ = self.encoder(
            x,
            padding_mask=padding_mask,
        )

        if not feature_only and self.predictor is not None:
            x = self.predictor_dropout(x)
            logits = self.predictor(x)

            if padding_mask is not None and padding_mask.any():
                # Mean over non-padded steps only: zero padded logits, sum,
                # then divide by the number of valid steps per item.
                logits[padding_mask] = 0
                logits = logits.sum(dim=1)
                logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(logits)
            else:
                logits = logits.mean(dim=1)

            lprobs = torch.sigmoid(logits)

            return lprobs, padding_mask
        else:
            return x, padding_mask
beats/LICENSE_beats ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ The MIT License (MIT)
2
+
3
+ Copyright (c) Microsoft Corporation
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
beats/Tokenizers.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beats
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch.nn import LayerNorm
14
+ import torchaudio.compliance.kaldi as ta_kaldi
15
+
16
+ from beats.backbone import (
17
+ TransformerEncoder,
18
+ )
19
+ from beats.quantizer import (
20
+ NormEMAVectorQuantizer,
21
+ )
22
+
23
+ import logging
24
+ from typing import Optional
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
class TokenizersConfig:
    """Plain attribute container holding the hyper-parameters of a Tokenizers model.

    All options are set to their defaults in ``__init__``; when ``cfg`` is
    given, its entries are then copied over those defaults via ``update``.
    """

    def __init__(self, cfg=None):
        """Populate default options, then apply the overrides in ``cfg``.

        Args:
            cfg: optional mapping of option name -> value. ``None`` keeps
                every default.
        """
        # patch embedding
        self.input_patch_size: int = -1  # patch size of patch embedding
        self.embed_dim: int = 512  # patch embedding dimension
        self.conv_bias: bool = False  # include bias in conv encoder

        # transformer encoder
        self.encoder_layers: int = 12  # num encoder layers in the transformer
        self.encoder_embed_dim: int = 768  # encoder embedding dimension
        self.encoder_ffn_embed_dim: int = 3072  # encoder embedding dimension for FFN
        self.encoder_attention_heads: int = 12  # num encoder attention heads
        self.activation_fn: str = "gelu"  # activation function to use

        self.layer_norm_first: bool = False  # apply layernorm first in the transformer
        self.deep_norm: bool = False  # apply deep_norm first in the transformer

        # dropouts
        self.dropout: float = 0.1  # dropout probability for the transformer
        self.attention_dropout: float = 0.1  # dropout probability for attention weights
        self.activation_dropout: float = 0.0  # dropout probability after activation in FFN
        self.encoder_layerdrop: float = 0.0  # probability of dropping a transformer layer
        self.dropout_input: float = 0.0  # dropout to apply to the input (after feat extr)

        # positional embeddings
        self.conv_pos: int = 128  # number of filters for convolutional positional embeddings
        self.conv_pos_groups: int = 16  # number of groups for convolutional positional embedding

        # relative position embedding
        self.relative_position_embedding: bool = False  # apply relative position embedding
        self.num_buckets: int = 320  # number of buckets for relative position embedding
        self.max_distance: int = 1280  # maximum distance for relative position embedding
        self.gru_rel_pos: bool = False  # apply gated relative position embedding

        # quantizer
        self.quant_n: int = 1024  # codebook number in quantizer
        self.quant_dim: int = 256  # codebook dimension in quantizer

        if cfg is not None:
            self.update(cfg)

    def update(self, cfg: dict):
        """Copy every key/value of ``cfg`` onto this object as an attribute.

        Keys are not validated: unknown keys simply become new attributes.
        """
        self.__dict__.update(cfg)
70
+
71
+
72
class Tokenizers(nn.Module):
    """Acoustic tokenizer: fbank patch embedding, Transformer encoder and a
    vector quantizer that maps every encoder step to a codebook index.
    """

    def __init__(
            self,
            cfg: "TokenizersConfig",
    ) -> None:
        """Build all sub-modules from ``cfg``.

        Args:
            cfg: configuration object; the attributes read here are
                ``embed_dim``, ``encoder_embed_dim``, ``input_patch_size``,
                ``conv_bias``, ``dropout_input``, ``deep_norm``,
                ``layer_norm_first``, ``quant_n`` and ``quant_dim``. The
                whole object is also handed to ``TransformerEncoder``.
        """
        super().__init__()
        # Lazy %-formatting: the dict is only rendered if INFO is enabled.
        logger.info("Tokenizers Config: %s", cfg.__dict__)

        self.cfg = cfg

        self.embed = cfg.embed_dim
        # Project patch embeddings to the encoder width only when they differ.
        self.post_extract_proj = (
            nn.Linear(self.embed, cfg.encoder_embed_dim)
            if self.embed != cfg.encoder_embed_dim
            else None
        )

        self.input_patch_size = cfg.input_patch_size
        # Non-overlapping patches: stride equals the kernel size.
        self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size,
                                         stride=self.input_patch_size, bias=cfg.conv_bias)

        self.dropout_input = nn.Dropout(cfg.dropout_input)

        # deep_norm and layer_norm_first are mutually exclusive.
        assert not cfg.deep_norm or not cfg.layer_norm_first
        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.embed)

        self.quantize = NormEMAVectorQuantizer(
            n_embed=cfg.quant_n, embedding_dim=cfg.quant_dim, beta=1.0, kmeans_init=True, decay=0.99,
        )
        self.quant_n = cfg.quant_n
        # Maps encoder outputs down to the codebook dimension.
        self.quantize_layer = nn.Sequential(
            nn.Linear(cfg.encoder_embed_dim, cfg.encoder_embed_dim),
            nn.Tanh(),
            nn.Linear(cfg.encoder_embed_dim, cfg.quant_dim)  # for quantize
        )

    def forward_padding_mask(
            self,
            features: torch.Tensor,
            padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Shrink ``padding_mask`` along dim 1 to ``features.size(1)`` steps.

        Trailing mask entries that do not divide evenly into
        ``features.size(1)`` groups are dropped. The remaining entries are
        grouped consecutively, and a step is marked padded only when every
        entry in its group is padded.

        Args:
            features: tensor whose ``size(1)`` gives the target length.
            padding_mask: boolean mask with the batch in dim 0.

        Returns:
            Boolean mask of shape ``(padding_mask.size(0), features.size(1))``.
        """
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(
            padding_mask.size(0), features.size(1), -1
        )
        padding_mask = padding_mask.all(-1)
        return padding_mask

    def preprocess(
            self,
            source: torch.Tensor,
            fbank_mean: float = 15.41663,
            fbank_std: float = 6.55582,
    ) -> torch.Tensor:
        """Turn a batch of waveforms into normalised Kaldi fbank features.

        Args:
            source: batch of waveforms, iterated along dim 0. The per-item
                fbanks are stacked, so all items must yield the same shape.
            fbank_mean: mean subtracted from the fbank values.
            fbank_std: the features are divided by ``2 * fbank_std``.

        Returns:
            The stacked, normalised fbank tensor (batch in dim 0).
        """
        fbanks = []
        for waveform in source:
            # Presumably rescales [-1, 1] float audio to the 16-bit integer
            # range expected by the Kaldi-compatible fbank — TODO confirm.
            waveform = waveform.unsqueeze(0) * 2 ** 15
            fbank = ta_kaldi.fbank(waveform, num_mel_bins=128, sample_frequency=16000,
                                   frame_length=25, frame_shift=10)
            fbanks.append(fbank)
        fbank = torch.stack(fbanks, dim=0)
        fbank = (fbank - fbank_mean) / (2 * fbank_std)
        return fbank

    def extract_labels(
            self,
            source: torch.Tensor,
            padding_mask: Optional[torch.Tensor] = None,
            fbank_mean: float = 15.41663,
            fbank_std: float = 6.55582,
    ):
        """Encode waveforms and return the quantizer's codebook indices.

        Args:
            source: batch of waveforms (see ``preprocess``).
            padding_mask: optional boolean mask over ``source``; it is
                down-sampled twice, first to the fbank length and then to
                the patch-sequence length.
            fbank_mean: forwarded to ``preprocess``.
            fbank_std: forwarded to ``preprocess``.

        Returns:
            The third value returned by the quantizer (``embed_ind``).
        """
        fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std)

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(fbank, padding_mask)

        # Add a singleton channel dim for Conv2d, then flatten the patch grid
        # into a sequence: (B, C, H, W) -> (B, C, H*W) -> (B, H*W, C).
        fbank = fbank.unsqueeze(1)
        features = self.patch_embedding(fbank)
        features = features.reshape(features.shape[0], features.shape[1], -1)
        features = features.transpose(1, 2)
        features = self.layer_norm(features)

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(features, padding_mask)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        x = self.dropout_input(features)

        # The per-layer results returned by the encoder are not used here.
        x, _ = self.encoder(
            x,
            padding_mask=padding_mask,
        )

        quantize_input = self.quantize_layer(x)
        # Only the indices are needed; the first two return values are discarded.
        _, _, embed_ind = self.quantize(quantize_input)

        return embed_ind
beats/__init__.py ADDED
File without changes
beats/__pycache__/BEATs.cpython-310.pyc ADDED
Binary file (4.2 kB). View file
 
beats/__pycache__/BEATs.cpython-39.pyc ADDED
Binary file (4.15 kB). View file
 
beats/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (177 Bytes). View file
 
beats/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (156 Bytes). View file
 
beats/__pycache__/backbone.cpython-310.pyc ADDED
Binary file (16.7 kB). View file
 
beats/__pycache__/backbone.cpython-39.pyc ADDED
Binary file (16.5 kB). View file
 
beats/__pycache__/modules.cpython-310.pyc ADDED
Binary file (6.19 kB). View file
 
beats/__pycache__/modules.cpython-39.pyc ADDED
Binary file (6.15 kB). View file
 
beats/backbone.py ADDED
@@ -0,0 +1,783 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beats
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ import math
11
+ import numpy as np
12
+ from typing import Dict, Optional, Tuple
13
+ import torch
14
+ from torch import Tensor, nn
15
+ import torch.nn.functional as F
16
+ from torch.nn import LayerNorm, Parameter
17
+ from beats.modules import (
18
+ GradMultiply,
19
+ SamePad,
20
+ get_activation_fn,
21
+ GLU_Linear,
22
+ quant_noise,
23
+ )
24
+
25
+
26
class TransformerEncoder(nn.Module):
    """Stack of ``TransformerSentenceEncoderLayer`` blocks preceded by a
    convolutional positional embedding.
    """

    def __init__(self, args):
        """Build the encoder from the option object ``args``.

        Attributes read from ``args``: ``dropout``, ``encoder_embed_dim``,
        ``conv_pos``, ``conv_pos_groups``, ``encoder_ffn_embed_dim``,
        ``encoder_attention_heads``, ``attention_dropout``,
        ``activation_dropout``, ``activation_fn``, ``layer_norm_first``,
        ``deep_norm``, ``gru_rel_pos``, ``encoder_layers``,
        ``encoder_layerdrop``, and optionally
        ``relative_position_embedding`` / ``num_buckets`` /
        ``max_distance`` / ``layer_wise_gradient_decay_ratio``.
        """
        super().__init__()

        self.dropout = args.dropout
        self.embedding_dim = args.encoder_embed_dim

        # Convolutional positional embedding over the time axis.
        self.pos_conv = nn.Conv1d(
            self.embedding_dim,
            self.embedding_dim,
            kernel_size=args.conv_pos,
            padding=args.conv_pos // 2,
            groups=args.conv_pos_groups,
        )
        dropout = 0
        std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
        nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
        nn.init.constant_(self.pos_conv.bias, 0)

        self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
        self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())

        # Relative position settings are optional on ``args``.
        if hasattr(args, "relative_position_embedding"):
            self.relative_position_embedding = args.relative_position_embedding
            self.num_buckets = args.num_buckets
            self.max_distance = args.max_distance
        else:
            self.relative_position_embedding = False
            self.num_buckets = 0
            self.max_distance = 0

        self.layers = nn.ModuleList(
            [
                TransformerSentenceEncoderLayer(
                    embedding_dim=self.embedding_dim,
                    ffn_embedding_dim=args.encoder_ffn_embed_dim,
                    num_attention_heads=args.encoder_attention_heads,
                    dropout=self.dropout,
                    attention_dropout=args.attention_dropout,
                    activation_dropout=args.activation_dropout,
                    activation_fn=args.activation_fn,
                    layer_norm_first=args.layer_norm_first,
                    deep_norm=args.deep_norm,
                    has_relative_attention_bias=self.relative_position_embedding,
                    num_buckets=self.num_buckets,
                    max_distance=self.max_distance,
                    gru_rel_pos=args.gru_rel_pos,
                    encoder_layers=args.encoder_layers,
                )
                for _ in range(args.encoder_layers)
            ]
        )
        if self.relative_position_embedding:
            # All layers share layer 0's relative attention bias embedding.
            for i in range(1, args.encoder_layers):
                del self.layers[i].self_attn.relative_attention_bias
                self.layers[i].self_attn.relative_attention_bias = self.layers[0].self_attn.relative_attention_bias

        self.layer_norm_first = args.layer_norm_first
        self.layer_norm = LayerNorm(self.embedding_dim)
        self.layerdrop = args.encoder_layerdrop

        self.apply(init_bert_params)

        if args.deep_norm:
            # Re-initialise attention/FFN weights with the deep-norm gain;
            # this runs after init_bert_params and overrides it for these weights.
            deep_norm_beta = math.pow(8 * args.encoder_layers, -1 / 4)
            for i in range(args.encoder_layers):
                nn.init.xavier_normal_(self.layers[i].self_attn.k_proj.weight, gain=1)
                nn.init.xavier_normal_(self.layers[i].self_attn.v_proj.weight, gain=deep_norm_beta)
                nn.init.xavier_normal_(self.layers[i].self_attn.q_proj.weight, gain=1)
                nn.init.xavier_normal_(self.layers[i].self_attn.out_proj.weight, gain=deep_norm_beta)
                nn.init.xavier_normal_(self.layers[i].fc1.weight, gain=deep_norm_beta)
                nn.init.xavier_normal_(self.layers[i].fc2.weight, gain=deep_norm_beta)

        self.layer_wise_gradient_decay_ratio = getattr(args, "layer_wise_gradient_decay_ratio", 1)

    def forward(self, x, padding_mask=None, layer=None):
        """Run the encoder.

        Args:
            x: input features, batch-first (transposed to time-first internally).
            padding_mask: optional boolean mask indexing padded positions of ``x``.
            layer: optional index of the layer whose output should be returned.

        Returns:
            ``(x, layer_results)`` as produced by ``extract_features``; the
            final layer norm is applied only when ``layer_norm_first`` is set
            and no specific ``layer`` was requested.
        """
        x, layer_results = self.extract_features(x, padding_mask, layer)

        if self.layer_norm_first and layer is None:
            x = self.layer_norm(x)

        return x, layer_results

    def extract_features(self, x, padding_mask=None, tgt_layer=None):
        """Apply positional conv, dropout and the layer stack.

        Args:
            x: input features, batch-first. NOTE: padded positions are zeroed
                in place, so the caller's tensor is modified.
            padding_mask: optional boolean mask indexing padded positions.
            tgt_layer: optional layer index; when given, per-layer outputs
                are collected and iteration stops after that layer.

        Returns:
            ``(x, layer_results)`` where ``x`` is batch-first and
            ``layer_results`` is a list of ``(x, z)`` pairs (time-first),
            populated only when ``tgt_layer`` is not ``None``.
        """

        if padding_mask is not None:
            x[padding_mask] = 0

        # Conv1d expects channels in dim 1, hence the transposes.
        x_conv = self.pos_conv(x.transpose(1, 2))
        x_conv = x_conv.transpose(1, 2)
        x = x + x_conv

        if not self.layer_norm_first:
            x = self.layer_norm(x)

        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        layer_results = []
        z = None
        if tgt_layer is not None:
            layer_results.append((x, z))
        r = None
        pos_bias = None
        for i, layer in enumerate(self.layers):
            if self.layer_wise_gradient_decay_ratio != 1.0:
                x = GradMultiply.apply(x, self.layer_wise_gradient_decay_ratio)
            # LayerDrop: during training a layer is skipped when the draw is
            # not above ``self.layerdrop``; in eval every layer runs.
            dropout_probability = np.random.random()
            if not self.training or (dropout_probability > self.layerdrop):
                x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, pos_bias=pos_bias)
            if tgt_layer is not None:
                layer_results.append((x, z))
            if i == tgt_layer:
                r = x
                break

        if r is not None:
            x = r

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        return x, layer_results
151
+
152
+
153
class TransformerSentenceEncoderLayer(nn.Module):
    """One Transformer encoder layer: self-attention followed by a
    feed-forward block, in either pre-norm (``layer_norm_first``) or
    post-norm form, with an optional deep-norm residual scale.
    """

    def __init__(
            self,
            embedding_dim: int = 768,
            ffn_embedding_dim: int = 3072,
            num_attention_heads: int = 8,
            dropout: float = 0.1,
            attention_dropout: float = 0.1,
            activation_dropout: float = 0.1,
            activation_fn: str = "relu",
            layer_norm_first: bool = False,
            deep_norm: bool = False,
            has_relative_attention_bias: bool = False,
            num_buckets: int = 0,
            max_distance: int = 0,
            rescale_init: bool = False,
            gru_rel_pos: bool = False,
            encoder_layers: int = 0,
    ) -> None:
        """Create the attention, feed-forward, dropout and norm sub-modules.

        Args:
            embedding_dim: model width (input/output size of the layer).
            ffn_embedding_dim: hidden size of the feed-forward block.
            num_attention_heads: number of attention heads.
            dropout: dropout after attention and after the second FFN linear.
            attention_dropout: dropout passed to the attention module.
            activation_dropout: dropout after the FFN activation.
            activation_fn: activation name; ``"glu"`` selects ``GLU_Linear``
                for ``fc1`` instead of ``nn.Linear`` plus an activation.
            layer_norm_first: use pre-norm ordering when true.
            deep_norm: scale residuals by ``(2 * encoder_layers) ** 0.25``
                in the post-norm branch.
            has_relative_attention_bias: forwarded to the attention module.
            num_buckets: forwarded to the attention module.
            max_distance: forwarded to the attention module.
            rescale_init: forwarded to the attention module.
            gru_rel_pos: forwarded to the attention module.
            encoder_layers: total layer count, used only for the deep-norm scale.
        """

        super().__init__()
        self.embedding_dim = embedding_dim
        self.dropout = dropout
        self.activation_dropout = activation_dropout

        self.activation_name = activation_fn
        self.activation_fn = get_activation_fn(activation_fn)
        self.self_attn = MultiheadAttention(
            self.embedding_dim,
            num_attention_heads,
            dropout=attention_dropout,
            self_attention=True,
            has_relative_attention_bias=has_relative_attention_bias,
            num_buckets=num_buckets,
            max_distance=max_distance,
            rescale_init=rescale_init,
            gru_rel_pos=gru_rel_pos,
        )

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(self.activation_dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.layer_norm_first = layer_norm_first

        self.self_attn_layer_norm = LayerNorm(self.embedding_dim)

        if self.activation_name == "glu":
            self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
        else:
            self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)

        self.final_layer_norm = LayerNorm(self.embedding_dim)

        self.deep_norm = deep_norm
        if self.deep_norm:
            self.deep_norm_alpha = math.pow(2 * encoder_layers, 1 / 4)
        else:
            self.deep_norm_alpha = 1

    def forward(
            self,
            x: torch.Tensor,
            self_attn_mask: Optional[torch.Tensor] = None,
            self_attn_padding_mask: Optional[torch.Tensor] = None,
            need_weights: bool = False,
            pos_bias=None
    ):
        """Apply self-attention and the feed-forward block to ``x``.

        Args:
            x: input tensor, passed as query, key and value to the attention.
            self_attn_mask: optional attention mask.
            self_attn_padding_mask: optional key padding mask.
            need_weights: forwarded to the attention module in the post-norm
                branch only; the pre-norm branch always passes ``False``.
            pos_bias: position bias passed to, and returned updated from,
                the attention module.

        Returns:
            ``(x, attn, pos_bias)``: the layer output plus the second and
            third values returned by the attention module.
        """
        residual = x

        if self.layer_norm_first:
            # Pre-norm: normalise before each sub-block, plain residual add.
            x = self.self_attn_layer_norm(x)
            x, attn, pos_bias = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                need_weights=False,
                attn_mask=self_attn_mask,
                position_bias=pos_bias
            )
            x = self.dropout1(x)
            x = residual + x

            residual = x
            x = self.final_layer_norm(x)
            if self.activation_name == "glu":
                x = self.fc1(x)
            else:
                x = self.activation_fn(self.fc1(x))
            x = self.dropout2(x)
            x = self.fc2(x)
            x = self.dropout3(x)
            x = residual + x
        else:
            # Post-norm: normalise after each residual add; the residual is
            # scaled by deep_norm_alpha (1 unless deep_norm is enabled).
            x, attn, pos_bias = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                need_weights=need_weights,
                attn_mask=self_attn_mask,
                position_bias=pos_bias
            )

            x = self.dropout1(x)
            x = residual * self.deep_norm_alpha + x

            x = self.self_attn_layer_norm(x)

            residual = x
            if self.activation_name == "glu":
                x = self.fc1(x)
            else:
                x = self.activation_fn(self.fc1(x))
            x = self.dropout2(x)
            x = self.fc2(x)
            x = self.dropout3(x)
            x = residual * self.deep_norm_alpha + x
            x = self.final_layer_norm(x)

        return x, attn, pos_bias
276
+
277
+
278
+ class MultiheadAttention(nn.Module):
279
+ """Multi-headed attention.
280
+
281
+ See "Attention Is All You Need" for more details.
282
+ """
283
+
284
+ def __init__(
285
+ self,
286
+ embed_dim,
287
+ num_heads,
288
+ kdim=None,
289
+ vdim=None,
290
+ dropout=0.0,
291
+ bias=True,
292
+ add_bias_kv=False,
293
+ add_zero_attn=False,
294
+ self_attention=False,
295
+ encoder_decoder_attention=False,
296
+ q_noise=0.0,
297
+ qn_block_size=8,
298
+ has_relative_attention_bias=False,
299
+ num_buckets=32,
300
+ max_distance=128,
301
+ gru_rel_pos=False,
302
+ rescale_init=False,
303
+ ):
304
+ super().__init__()
305
+ self.embed_dim = embed_dim
306
+ self.kdim = kdim if kdim is not None else embed_dim
307
+ self.vdim = vdim if vdim is not None else embed_dim
308
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
309
+
310
+ self.num_heads = num_heads
311
+ self.dropout_module = nn.Dropout(dropout)
312
+
313
+ self.has_relative_attention_bias = has_relative_attention_bias
314
+ self.num_buckets = num_buckets
315
+ self.max_distance = max_distance
316
+ if self.has_relative_attention_bias:
317
+ self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
318
+
319
+ self.head_dim = embed_dim // num_heads
320
+ self.q_head_dim = self.head_dim
321
+ self.k_head_dim = self.head_dim
322
+ assert (
323
+ self.head_dim * num_heads == self.embed_dim
324
+ ), "embed_dim must be divisible by num_heads"
325
+ self.scaling = self.head_dim ** -0.5
326
+
327
+ self.self_attention = self_attention
328
+ self.encoder_decoder_attention = encoder_decoder_attention
329
+
330
+ assert not self.self_attention or self.qkv_same_dim, (
331
+ "Self-attention requires query, key and " "value to be of the same size"
332
+ )
333
+
334
+ k_bias = True
335
+ if rescale_init:
336
+ k_bias = False
337
+
338
+ k_embed_dim = embed_dim
339
+ q_embed_dim = embed_dim
340
+
341
+ self.k_proj = quant_noise(
342
+ nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size
343
+ )
344
+ self.v_proj = quant_noise(
345
+ nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
346
+ )
347
+ self.q_proj = quant_noise(
348
+ nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size
349
+ )
350
+
351
+ self.out_proj = quant_noise(
352
+ nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
353
+ )
354
+
355
+ if add_bias_kv:
356
+ self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
357
+ self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
358
+ else:
359
+ self.bias_k = self.bias_v = None
360
+
361
+ self.add_zero_attn = add_zero_attn
362
+
363
+ self.gru_rel_pos = gru_rel_pos
364
+ if self.gru_rel_pos:
365
+ self.grep_linear = nn.Linear(self.q_head_dim, 8)
366
+ self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
367
+
368
+ self.reset_parameters()
369
+
370
def reset_parameters(self):
    """Re-initialize the projection weights, optional biases and position table.

    The q/k/v projections use Xavier-uniform initialization. When query, key
    and value share the same dimensionality the gain is reduced to
    ``1 / sqrt(2)``; the scaled initialization was empirically observed to
    converge much better. The output projection always uses the default gain
    and its bias, if present, is zeroed. ``bias_k`` / ``bias_v`` and the
    relative-attention-bias embedding use Xavier-normal when they exist.
    """
    qkv_gain = 1 / math.sqrt(2) if self.qkv_same_dim else 1.0
    # Order (k, v, q) is kept so the RNG is consumed exactly as before.
    for projection in (self.k_proj, self.v_proj, self.q_proj):
        nn.init.xavier_uniform_(projection.weight, gain=qkv_gain)

    nn.init.xavier_uniform_(self.out_proj.weight)
    out_bias = self.out_proj.bias
    if out_bias is not None:
        nn.init.constant_(out_bias, 0.0)

    for extra_bias in (self.bias_k, self.bias_v):
        if extra_bias is not None:
            nn.init.xavier_normal_(extra_bias)

    if self.has_relative_attention_bias:
        nn.init.xavier_normal_(self.relative_attention_bias.weight)
391
+
392
+ def _relative_positions_bucket(self, relative_positions, bidirectional=True):
393
+ num_buckets = self.num_buckets
394
+ max_distance = self.max_distance
395
+ relative_buckets = 0
396
+
397
+ if bidirectional:
398
+ num_buckets = num_buckets // 2
399
+ relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
400
+ relative_positions = torch.abs(relative_positions)
401
+ else:
402
+ relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
403
+
404
+ max_exact = num_buckets // 2
405
+ is_small = relative_positions < max_exact
406
+
407
+ relative_postion_if_large = max_exact + (
408
+ torch.log(relative_positions.float() / max_exact)
409
+ / math.log(max_distance / max_exact)
410
+ * (num_buckets - max_exact)
411
+ ).to(torch.long)
412
+ relative_postion_if_large = torch.min(
413
+ relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
414
+ )
415
+
416
+ relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
417
+ return relative_buckets
418
+
419
def compute_bias(self, query_length, key_length):
    """Look up the relative position bias for every (query, key) pair.

    Args:
        query_length: number of query positions.
        key_length: number of key positions.

    Returns:
        The embedding lookup permuted so the embedding dimension comes first,
        i.e. a tensor of shape ``(embedding_dim, query_length, key_length)``.
    """
    query_index = torch.arange(query_length, dtype=torch.long)[:, None]
    key_index = torch.arange(key_length, dtype=torch.long)[None, :]
    # Offset of each key position relative to each query position.
    buckets = self._relative_positions_bucket(key_index - query_index, bidirectional=True)
    bias_table = self.relative_attention_bias
    buckets = buckets.to(bias_table.weight.device)
    return bias_table(buckets).permute([2, 0, 1])
431
+
432
def forward(
        self,
        query,
        key: Optional[Tensor],
        value: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        need_weights: bool = True,
        static_kv: bool = False,
        attn_mask: Optional[Tensor] = None,
        before_softmax: bool = False,
        need_head_weights: bool = False,
        position_bias: Optional[Tensor] = None
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
    """Input shape: Time x Batch x Channel

    Args:
        query: tensor of shape `(tgt_len, batch, embed_dim)`.
        key: tensor of shape `(src_len, batch, kdim)` or None. In
            self-attention mode the keys are projected from *query*.
        value: tensor whose first two dimensions must equal
            `(src_len, batch)` whenever *key* is given.
        key_padding_mask (ByteTensor, optional): mask to exclude
            keys that are pads, of shape `(batch, src_len)`, where
            padding elements are indicated by 1s.
        incremental_state (dict, optional): cache holding the keys,
            values and padding mask of previous decoding steps.
        need_weights (bool, optional): return the attention weights,
            averaged over heads (default: True).
        static_kv (bool, optional): reuse the cached key/value as-is
            instead of appending to them (encoder-decoder attention only).
        attn_mask (ByteTensor, optional): typically used to
            implement causal attention, where the mask prevents the
            attention from looking forward in time (default: None).
        before_softmax (bool, optional): return the raw attention
            weights and values before the attention softmax.
        need_head_weights (bool, optional): return the attention
            weights for each head. Implies *need_weights*. Default:
            return the average attention weights over all heads.
        position_bias (Tensor, optional): precomputed relative position
            bias; computed here when the module owns a bias table and
            none is supplied.

    Returns:
        `(attn, attn_weights, position_bias)` where *attn* has shape
        `(tgt_len, batch, embed_dim)` and *attn_weights* is None unless
        *need_weights* is set, in which case it is
        `(batch, tgt_len, src_len)` (head-averaged) or
        `(num_heads, batch, tgt_len, src_len)` (*need_head_weights*).
        With *before_softmax* the tuple is
        `(raw_attn_weights, v, position_bias)` instead.
    """
    if need_head_weights:
        need_weights = True

    is_tpu = query.device.type == "xla"

    tgt_len, bsz, embed_dim = query.size()
    src_len = tgt_len
    assert embed_dim == self.embed_dim
    assert list(query.size()) == [tgt_len, bsz, embed_dim]
    if key is not None:
        src_len, key_bsz, _ = key.size()
        if not torch.jit.is_scripting():
            assert key_bsz == bsz
            assert value is not None
            # The previous `assert src_len, bsz == value.shape[:2]` was parsed
            # as `assert src_len, <message>` and therefore never checked value.
            assert value.size(0) == src_len and value.size(1) == bsz, (
                "value must share its first two dimensions (src_len, bsz) with key"
            )

    if self.has_relative_attention_bias and position_bias is None:
        position_bias = self.compute_bias(tgt_len, src_len)
        position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)

    if incremental_state is not None:
        saved_state = self._get_input_buffer(incremental_state)
        if saved_state is not None and "prev_key" in saved_state:
            # previous time steps are cached - no need to recompute
            # key and value if they are static
            if static_kv:
                assert self.encoder_decoder_attention and not self.self_attention
                key = value = None
    else:
        saved_state = None

    if self.self_attention:
        q = self.q_proj(query)
        k = self.k_proj(query)
        v = self.v_proj(query)
    elif self.encoder_decoder_attention:
        # encoder-decoder attention
        q = self.q_proj(query)
        if key is None:
            assert value is None
            k = v = None
        else:
            k = self.k_proj(key)
            v = self.v_proj(key)

    else:
        assert key is not None and value is not None
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)
    q *= self.scaling
    # q is additionally shrunk by `alpha` here and the logits are multiplied
    # back by `alpha` after their row-wise maximum has been subtracted below.
    # NOTE(review): presumably this keeps q @ k^T in a numerically safe range
    # for low-precision training - confirm against the upstream BEATs code.
    alpha = 32
    q *= 1 / alpha

    if self.bias_k is not None:
        assert self.bias_v is not None
        k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
        v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
        if attn_mask is not None:
            attn_mask = torch.cat(
                [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
            )
        if key_padding_mask is not None:
            key_padding_mask = torch.cat(
                [
                    key_padding_mask,
                    key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
                ],
                dim=1,
            )

    # Fold the heads into the batch dimension: (bsz * num_heads, len, head_dim).
    q = (
        q.contiguous()
        .view(tgt_len, bsz * self.num_heads, self.q_head_dim)
        .transpose(0, 1)
    )
    if k is not None:
        k = (
            k.contiguous()
            .view(-1, bsz * self.num_heads, self.k_head_dim)
            .transpose(0, 1)
        )
    if v is not None:
        v = (
            v.contiguous()
            .view(-1, bsz * self.num_heads, self.head_dim)
            .transpose(0, 1)
        )

    if saved_state is not None:
        # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
        if "prev_key" in saved_state:
            _prev_key = saved_state["prev_key"]
            assert _prev_key is not None
            prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
            if static_kv:
                k = prev_key
            else:
                assert k is not None
                k = torch.cat([prev_key, k], dim=1)
            src_len = k.size(1)
        if "prev_value" in saved_state:
            _prev_value = saved_state["prev_value"]
            assert _prev_value is not None
            prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
            if static_kv:
                v = prev_value
            else:
                assert v is not None
                v = torch.cat([prev_value, v], dim=1)
        prev_key_padding_mask: Optional[Tensor] = None
        if "prev_key_padding_mask" in saved_state:
            prev_key_padding_mask = saved_state["prev_key_padding_mask"]
        assert k is not None and v is not None
        key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
            key_padding_mask=key_padding_mask,
            prev_key_padding_mask=prev_key_padding_mask,
            batch_size=bsz,
            src_len=k.size(1),
            static_kv=static_kv,
        )

        saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
        saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
        saved_state["prev_key_padding_mask"] = key_padding_mask
        # In this branch incremental_state is never None
        assert incremental_state is not None
        incremental_state = self._set_input_buffer(incremental_state, saved_state)
    assert k is not None
    assert k.size(1) == src_len

    # This is part of a workaround to get around fork/join parallelism
    # not supporting Optional types.
    if key_padding_mask is not None and key_padding_mask.dim() == 0:
        key_padding_mask = None

    if key_padding_mask is not None:
        assert key_padding_mask.size(0) == bsz
        assert key_padding_mask.size(1) == src_len

    if self.add_zero_attn:
        assert v is not None
        src_len += 1
        k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
        v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
        if attn_mask is not None:
            attn_mask = torch.cat(
                [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
            )
        if key_padding_mask is not None:
            key_padding_mask = torch.cat(
                [
                    key_padding_mask,
                    torch.zeros(key_padding_mask.size(0), 1).type_as(
                        key_padding_mask
                    ),
                ],
                dim=1,
            )

    attn_weights = torch.bmm(q, k.transpose(1, 2))
    # Subtracting the row maximum before undoing the 1/alpha scaling leaves the
    # softmax of these logits unchanged.
    attn_weights = (attn_weights - attn_weights.max(dim=-1, keepdim=True)[0]) * alpha
    attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

    assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

    if attn_mask is not None:
        attn_mask = attn_mask.unsqueeze(0)
        attn_weights += attn_mask

    if key_padding_mask is not None:
        # don't attend to padding symbols
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
        if not is_tpu:
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                float("-inf"),
            )
        else:
            attn_weights = attn_weights.transpose(0, 2)
            attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
            attn_weights = attn_weights.transpose(0, 2)
        attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

    if before_softmax:
        return attn_weights, v, position_bias

    if position_bias is not None:
        attn_mask_rel_pos = position_bias
        if self.gru_rel_pos == 1:
            # Undo the scaling applied to q above before feeding it to the gate.
            query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) * alpha / self.scaling
            _B, _H, _L, __ = query_layer.size()
            gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
                _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
            gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
            attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, tgt_len, 1) * position_bias

        attn_mask_rel_pos = attn_mask_rel_pos.view(attn_weights.size())

        attn_weights = attn_weights + attn_mask_rel_pos

    attn_weights_float = F.softmax(
        attn_weights, dim=-1
    )
    attn_weights = attn_weights_float.type_as(attn_weights)
    attn_probs = self.dropout_module(attn_weights)

    assert v is not None
    attn = torch.bmm(attn_probs, v)
    assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
    attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    attn = self.out_proj(attn)
    attn_weights: Optional[Tensor] = None
    if need_weights:
        attn_weights = attn_weights_float.view(
            bsz, self.num_heads, tgt_len, src_len
        ).transpose(1, 0)
        if not need_head_weights:
            # average attention weights over heads
            attn_weights = attn_weights.mean(dim=0)

    return attn, attn_weights, position_bias
685
+
686
+ @staticmethod
687
+ def _append_prev_key_padding_mask(
688
+ key_padding_mask: Optional[Tensor],
689
+ prev_key_padding_mask: Optional[Tensor],
690
+ batch_size: int,
691
+ src_len: int,
692
+ static_kv: bool,
693
+ ) -> Optional[Tensor]:
694
+ # saved key padding masks have shape (bsz, seq_len)
695
+ if prev_key_padding_mask is not None and static_kv:
696
+ new_key_padding_mask = prev_key_padding_mask
697
+ elif prev_key_padding_mask is not None and key_padding_mask is not None:
698
+ new_key_padding_mask = torch.cat(
699
+ [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
700
+ )
701
+ # During incremental decoding, as the padding token enters and
702
+ # leaves the frame, there will be a time when prev or current
703
+ # is None
704
+ elif prev_key_padding_mask is not None:
705
+ if src_len > prev_key_padding_mask.size(1):
706
+ filler = torch.zeros(
707
+ (batch_size, src_len - prev_key_padding_mask.size(1)),
708
+ device=prev_key_padding_mask.device,
709
+ )
710
+ new_key_padding_mask = torch.cat(
711
+ [prev_key_padding_mask.float(), filler.float()], dim=1
712
+ )
713
+ else:
714
+ new_key_padding_mask = prev_key_padding_mask.float()
715
+ elif key_padding_mask is not None:
716
+ if src_len > key_padding_mask.size(1):
717
+ filler = torch.zeros(
718
+ (batch_size, src_len - key_padding_mask.size(1)),
719
+ device=key_padding_mask.device,
720
+ )
721
+ new_key_padding_mask = torch.cat(
722
+ [filler.float(), key_padding_mask.float()], dim=1
723
+ )
724
+ else:
725
+ new_key_padding_mask = key_padding_mask.float()
726
+ else:
727
+ new_key_padding_mask = prev_key_padding_mask
728
+ return new_key_padding_mask
729
+
730
+ def _get_input_buffer(
731
+ self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
732
+ ) -> Dict[str, Optional[Tensor]]:
733
+ result = self.get_incremental_state(incremental_state, "attn_state")
734
+ if result is not None:
735
+ return result
736
+ else:
737
+ empty_result: Dict[str, Optional[Tensor]] = {}
738
+ return empty_result
739
+
740
+ def _set_input_buffer(
741
+ self,
742
+ incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
743
+ buffer: Dict[str, Optional[Tensor]],
744
+ ):
745
+ return self.set_incremental_state(incremental_state, "attn_state", buffer)
746
+
747
def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
    """Return ``attn_weights`` unchanged; the length and batch arguments are unused here."""
    unmasked = attn_weights
    return unmasked
749
+
750
+
751
def init_bert_params(module):
    """
    Re-initialize ``module`` in place with BERT-style weights.

    This overrides the default initialization depending on the module type:
        1. ``nn.Linear``: weight drawn from a normal distribution
           (mean 0, std 0.02); the bias, if present, is zeroed.
        2. ``nn.Embedding``: weight drawn from the same normal distribution;
           the row at ``padding_idx``, if set, is zeroed.
        3. ``MultiheadAttention``: the q/k/v projection weights are drawn
           from the same normal distribution.

    The three checks are independent; modules of any other type are left
    untouched.
    """

    def _fill_normal(tensor):
        # with FSDP, module params will be on CUDA, so we cast them back to CPU
        # so that the RNG is consistent with and without FSDP
        sampled = tensor.cpu().normal_(mean=0.0, std=0.02)
        tensor.copy_(sampled.to(tensor.device))

    if isinstance(module, nn.Linear):
        _fill_normal(module.weight.data)
        linear_bias = module.bias
        if linear_bias is not None:
            linear_bias.data.zero_()
    if isinstance(module, nn.Embedding):
        _fill_normal(module.weight.data)
        pad_index = module.padding_idx
        if pad_index is not None:
            module.weight.data[pad_index].zero_()
    if isinstance(module, MultiheadAttention):
        # Order (q, k, v) is kept so the RNG is consumed exactly as before.
        for projection in (module.q_proj, module.k_proj, module.v_proj):
            _fill_normal(projection.weight.data)
beats/modules.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beats
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ import math
11
+ import warnings
12
+ import torch
13
+ from torch import Tensor, nn
14
+ import torch.nn.functional as F
15
+
16
+
17
class GradMultiply(torch.autograd.Function):
    """Autograd function that passes values through and rescales the gradient.

    ``forward`` returns a tensor with the same values as ``x``; ``backward``
    multiplies the incoming gradient by ``scale`` and returns no gradient for
    ``scale`` itself.
    """

    @staticmethod
    def forward(ctx, x, scale):
        # Remember the factor for the backward pass.
        ctx.scale = scale
        return x.new(x)

    @staticmethod
    def backward(ctx, grad):
        scaled_grad = grad * ctx.scale
        return scaled_grad, None
27
+
28
+
29
class SamePad(nn.Module):
    """Trim trailing elements from the last dimension of a 3-D tensor.

    The number of removed elements is ``kernel_size - 1`` when ``causal`` is
    set; otherwise it is 1 for even kernel sizes and 0 for odd ones.
    """

    def __init__(self, kernel_size, causal=False):
        super().__init__()
        if causal:
            trim = kernel_size - 1
        elif kernel_size % 2 == 0:
            trim = 1
        else:
            trim = 0
        self.remove = trim

    def forward(self, x):
        if self.remove <= 0:
            return x
        return x[:, :, : -self.remove]
41
+
42
+
43
class Swish(nn.Module):
    """Swish activation: ``x * sigmoid(x)``."""

    def __init__(self):
        super().__init__()
        self.act = torch.nn.Sigmoid()

    def forward(self, x):
        gate = self.act(x)
        return x * gate
50
+
51
+
52
class GLU_Linear(nn.Module):
    """Linear projection followed by a gated linear unit.

    The input is projected to ``2 * output_dim`` features; the first half is
    multiplied by the (optionally activated) second half.

    Args:
        input_dim: size of the last dimension of the input.
        output_dim: size of the last dimension of the output.
        glu_type: gate activation, one of ``"sigmoid"``, ``"swish"``,
            ``"relu"``, ``"gelu"`` or ``"bilinear"`` (no activation).
        bias_in_glu: whether the linear projection has a bias.

    Raises:
        ValueError: if ``glu_type`` is not one of the supported names.
            Previously an unknown name was accepted silently and only failed
            with an AttributeError on the first forward pass.
    """

    def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
        super(GLU_Linear, self).__init__()

        self.glu_type = glu_type
        self.output_dim = output_dim

        if glu_type == "sigmoid":
            self.glu_act = torch.nn.Sigmoid()
        elif glu_type == "swish":
            self.glu_act = Swish()
        elif glu_type == "relu":
            self.glu_act = torch.nn.ReLU()
        elif glu_type == "gelu":
            self.glu_act = torch.nn.GELU()
        elif glu_type != "bilinear":
            raise ValueError("glu_type {} not supported".format(glu_type))

        self.linear = nn.Linear(input_dim, output_dim * 2, bool(bias_in_glu))

    def forward(self, x):
        # The feature dimension must be last; callers holding channel-first
        # tensors (e.g. Conv1d outputs) need to transpose before calling.
        x = self.linear(x)

        # Ellipsis indexing keeps the 3-D behaviour and also accepts inputs
        # with any other number of leading dimensions.
        value = x[..., 0:self.output_dim]
        gate = x[..., self.output_dim:self.output_dim * 2]

        if self.glu_type == "bilinear":
            return value * gate
        return value * self.glu_act(gate)
83
+
84
+
85
def gelu_accurate(x):
    """Tanh approximation of GELU.

    Computes ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``.
    The ``sqrt(2 / pi)`` constant is cached on the function object as ``_a``
    the first time the function runs.
    """
    try:
        coeff = gelu_accurate._a
    except AttributeError:
        coeff = gelu_accurate._a = math.sqrt(2 / math.pi)
    inner = coeff * (x + 0.044715 * torch.pow(x, 3))
    return 0.5 * x * (1 + torch.tanh(inner))
91
+
92
+
93
def gelu(x: torch.Tensor) -> torch.Tensor:
    """Apply GELU in float32 and cast the result back to the type of ``x``."""
    activated = torch.nn.functional.gelu(x.float())
    return activated.type_as(x)
95
+
96
+
97
def get_activation_fn(activation: str):
    """Return the activation callable registered under the name ``activation``.

    ``"gelu_fast"`` is a deprecated alias of ``"gelu_accurate"`` and emits a
    warning. ``"linear"`` and ``"glu"`` both map to the identity function.

    Raises:
        RuntimeError: if the name is not supported.
    """
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return gelu
    if activation == "gelu_fast":
        warnings.warn(
            "--activation-fn=gelu_fast has been renamed to gelu_accurate"
        )
        return gelu_accurate
    if activation == "gelu_accurate":
        return gelu_accurate
    if activation == "tanh":
        return torch.tanh
    if activation in ("linear", "glu"):
        return lambda x: x
    raise RuntimeError("--activation-fn {} not supported".format(activation))
119
+
120
+
121
def quant_noise(module, p, block_size):
    """
    Wraps modules and applies quantization noise to the weights for
    subsequent quantization with Iterative Product Quantization as
    described in "Training with Quantization Noise for Extreme Model Compression"

    Args:
        - module: nn.Module
        - p: amount of Quantization Noise
        - block_size: size of the blocks for subsequent quantization with iPQ

    Returns:
        The same ``module`` object. When ``p > 0`` a forward pre-hook has been
        registered on it; when ``p <= 0`` it is returned untouched.

    Remarks:
        - Module weights must have the right sizes wrt the block size
        - Only Linear, Embedding and Conv2d modules are supported for the moment
        - For more detail on how to quantize by blocks with convolutional weights,
          see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
        - We implement the simplest form of noise here as stated in the paper
          which consists in randomly dropping blocks
        - While the module is in training mode the hook overwrites
          ``module.weight.data`` on every forward call
    """

    # if no quantization noise, don't register hook
    if p <= 0:
        return module

    # supported modules
    assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))

    # test whether module.weight has the right sizes wrt block_size
    is_conv = module.weight.ndim == 4

    # 2D matrix
    if not is_conv:
        assert (
            module.weight.size(1) % block_size == 0
        ), "Input features must be a multiple of block sizes"

    # 4D matrix
    else:
        # 1x1 convolutions
        if module.kernel_size == (1, 1):
            assert (
                module.in_channels % block_size == 0
            ), "Input channels must be a multiple of block sizes"
        # regular convolutions
        else:
            k = module.kernel_size[0] * module.kernel_size[1]
            assert k % block_size == 0, "Kernel size must be a multiple of block size"

    def _forward_pre_hook(mod, input):
        # no noise for evaluation
        if not mod.training:
            return

        weight = mod.weight
        if not is_conv or mod.kernel_size == (1, 1):
            # Linear/Embedding weights are (out, in) and 1x1 conv weights are
            # (out, in, 1, 1): drop blocks of `block_size` consecutive inputs.
            rows, cols = weight.size(0), weight.size(1)
            mask = torch.zeros(cols // block_size * rows, device=weight.device)
            mask.bernoulli_(p)
            # The mask must have exactly the weight's shape. The former 2-D
            # mask for 1x1 convolutions was broadcast against the 4-D weight
            # by masked_fill, replacing the weight with a wrongly shaped tensor.
            mask = mask.repeat_interleave(block_size, -1).view(weight.shape)
        else:
            # regular convolutions: drop whole kernels
            mask = torch.zeros(
                weight.size(0), weight.size(1), device=weight.device
            )
            mask.bernoulli_(p)
            mask = (
                mask.unsqueeze(2)
                .unsqueeze(3)
                .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
            )

        # scale weights and apply mask
        mask = mask.to(
            torch.bool
        )  # x.bool() is not currently supported in TorchScript
        s = 1 / (1 - p)
        mod.weight.data = s * weight.masked_fill(mask, 0)

    module.register_forward_pre_hook(_forward_pre_hook)
    return module
beats/quantizer.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beats
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on VQGAN code bases
7
+ # https://github.com/CompVis/taming-transformers
8
+ # --------------------------------------------------------
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ import torch.distributed as distributed
14
+
15
+ try:
16
+ from einops import rearrange, repeat
17
+ except ImportError:
18
+ pass
19
+
20
+
21
def l2norm(t):
    """Scale ``t`` to unit L2 norm along its last dimension."""
    normalized = F.normalize(t, dim=-1, p=2)
    return normalized
23
+
24
+
25
def ema_inplace(moving_avg, new, decay):
    """Update ``moving_avg`` in place to ``decay * moving_avg + (1 - decay) * new``."""
    fresh_weight = 1 - decay
    moving_avg.data.mul_(decay).add_(new, alpha=fresh_weight)
27
+
28
+
29
def sample_vectors(samples, num):
    """Pick ``num`` rows from ``samples`` at random.

    Rows are drawn without replacement when enough rows are available and
    with replacement otherwise.
    """
    available = samples.shape[0]
    target_device = samples.device

    if available < num:
        picks = torch.randint(0, available, (num,), device=target_device)
    else:
        picks = torch.randperm(available, device=target_device)[:num]

    return samples[picks]
38
+
39
+
40
def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False):
    """Run k-means on the rows of ``samples``.

    Args:
        samples: 2-D tensor with one sample per row.
        num_clusters: number of centroids.
        num_iters: number of assignment/update rounds; must be at least 1,
            otherwise the returned bin counts are undefined.
        use_cosine_sim: assign by largest dot product and L2-normalize the
            updated centroids, instead of assigning by smallest squared
            Euclidean distance.

    Returns:
        ``(means, bins)``: the centroids, and the number of samples assigned
        to each centroid in the last round.

    This implementation uses only torch operations, so it keeps working when
    the optional ``einops`` import at the top of the file has failed.
    """
    dim, dtype = samples.shape[-1], samples.dtype

    means = sample_vectors(samples, num_clusters)

    for _ in range(num_iters):
        if use_cosine_sim:
            dists = samples @ means.t()
        else:
            # (n, 1, d) - (1, c, d) -> pairwise differences of shape (n, c, d)
            diffs = samples.unsqueeze(1) - means.unsqueeze(0)
            dists = -(diffs ** 2).sum(dim=-1)

        buckets = dists.max(dim=-1).indices
        bins = torch.bincount(buckets, minlength=num_clusters)
        zero_mask = bins == 0
        # Avoid dividing by zero for centroids that received no samples.
        bins_min_clamped = bins.masked_fill(zero_mask, 1)

        # Sum the samples assigned to each centroid, then average.
        new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
        new_means.index_add_(0, buckets, samples)
        new_means = new_means / bins_min_clamped[..., None]

        if use_cosine_sim:
            new_means = l2norm(new_means)

        # Empty clusters keep their previous centroid.
        means = torch.where(zero_mask[..., None], means, new_means)

    return means, bins
68
+
69
+
70
class EmbeddingEMA(nn.Module):
    """Codebook whose entries are maintained by moving averages, not gradients.

    Args:
        num_tokens: number of codebook entries.
        codebook_dim: dimensionality of each entry.
        decay: EMA decay used by the ``*_ema_update`` methods.
        eps: smoothing constant used by ``weight_update``.
        kmeans_init: when True (and no init path is given) the codebook starts
            at zero and is filled by ``init_embed_`` on first use; otherwise
            it starts from L2-normalized Gaussian noise.
        codebook_init_path: optional path of a saved codebook tensor to start
            from; takes precedence over ``kmeans_init``.
    """

    def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5, kmeans_init=True, codebook_init_path=''):
        super().__init__()
        self.num_tokens = num_tokens
        self.codebook_dim = codebook_dim
        self.decay = decay
        self.eps = eps

        if codebook_init_path != '':
            print(f"load init codebook weight from {codebook_init_path}")
            # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
            loaded_codebook = torch.load(codebook_init_path, map_location='cpu')
            weight = loaded_codebook.clone()
            already_initted = True
        else:
            if kmeans_init:
                weight = torch.zeros(num_tokens, codebook_dim)
            else:
                weight = l2norm(torch.randn(num_tokens, codebook_dim))
            already_initted = not kmeans_init
        # One-element flag: non-zero once the codebook holds usable values.
        self.register_buffer('initted', torch.Tensor([already_initted]))

        self.weight = nn.Parameter(weight, requires_grad=False)
        self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False)
        self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False)
        self.update = True

    @torch.jit.ignore
    def init_embed_(self, data):
        """Fill the codebook from ``data`` with cosine k-means, once."""
        if self.initted:
            return
        print("Performing Kemans init for codebook")
        centroids, counts = kmeans(data, self.num_tokens, 10, use_cosine_sim=True)
        self.weight.data.copy_(centroids)
        self.cluster_size.data.copy_(counts)
        self.initted.data.copy_(torch.Tensor([True]))

    def forward(self, embed_id):
        """Return the codebook rows selected by ``embed_id``."""
        return F.embedding(embed_id, self.weight)

    def cluster_size_ema_update(self, new_cluster_size):
        """Blend ``new_cluster_size`` into ``cluster_size`` with weight ``1 - decay``."""
        fresh_weight = 1 - self.decay
        self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=fresh_weight)

    def embed_avg_ema_update(self, new_embed_avg):
        """Blend ``new_embed_avg`` into ``embed_avg`` with weight ``1 - decay``."""
        fresh_weight = 1 - self.decay
        self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=fresh_weight)

    def weight_update(self, num_tokens):
        """Set the codebook to ``embed_avg`` divided by the smoothed cluster sizes."""
        total = self.cluster_size.sum()
        smoothed_sizes = (
            (self.cluster_size + self.eps) / (total + num_tokens * self.eps) * total
        )
        # normalize embedding average with smoothed cluster size
        self.weight.data.copy_(self.embed_avg / smoothed_sizes.unsqueeze(1))
124
+
125
+
126
def norm_ema_inplace(moving_avg, new, decay):
    """EMA-update ``moving_avg`` in place, then rescale it to unit L2 norm along the last dim."""
    target = moving_avg.data
    target.mul_(decay).add_(new, alpha=(1 - decay))
    target.copy_(l2norm(target))
129
+
130
+
131
class NormEMAVectorQuantizer(nn.Module):
    """Vector quantizer over L2-normalized inputs with an EMA-updated codebook.

    Args:
        n_embed: number of codebook entries.
        embedding_dim: dimensionality of each entry.
        beta: weight of the returned loss term.
        decay: EMA decay for the codebook and the usage statistics.
        eps: forwarded to ``EmbeddingEMA``.
        statistic_code_usage: keep an EMA of how often each code is selected
            in the ``cluster_size`` buffer.
        kmeans_init: forwarded to ``EmbeddingEMA``.
        codebook_init_path: forwarded to ``EmbeddingEMA``.
    """

    def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
                 statistic_code_usage=True, kmeans_init=False, codebook_init_path=''):
        super().__init__()
        self.codebook_dim = embedding_dim
        self.num_tokens = n_embed
        self.beta = beta
        self.decay = decay

        self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps, kmeans_init, codebook_init_path)

        self.statistic_code_usage = statistic_code_usage
        if statistic_code_usage:
            self.register_buffer('cluster_size', torch.zeros(n_embed))
        # The reduction function is chosen once, at construction time.
        if distributed.is_available() and distributed.is_initialized():
            print("ddp is enable, so use ddp_reduce to sync the statistic_code_usage for each gpu!")
            self.all_reduce_fn = distributed.all_reduce
        else:
            self.all_reduce_fn = nn.Identity()

    def reset_cluster_size(self, device):
        """Reset the code-usage statistics to zero on ``device`` (no-op when disabled)."""
        if self.statistic_code_usage:
            self.register_buffer('cluster_size', torch.zeros(self.num_tokens))
            self.cluster_size = self.cluster_size.to(device)

    def forward(self, z):
        """Quantize ``z`` against the codebook.

        Args:
            z: tensor whose last dimension is ``embedding_dim``.

        Returns:
            ``(z_q, loss, encoding_indices)``: the quantized tensor (same
            shape as ``z``, gradients passed straight through to ``z``), the
            ``beta``-weighted MSE between the detached codes and the
            normalized input, and the flat index of the selected code for
            every input vector.
        """
        z = l2norm(z)
        z_flattened = z.reshape(-1, self.codebook_dim)

        self.embedding.init_embed_(z_flattened)

        # Squared Euclidean distance to every code: |z|^2 + |e|^2 - 2 z.e
        d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
            self.embedding.weight.pow(2).sum(dim=1) - 2 * \
            torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight)

        encoding_indices = torch.argmin(d, dim=1)

        z_q = self.embedding(encoding_indices).view(z.shape)

        encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)

        # `cluster_size` is only registered when statistic_code_usage is set;
        # touching it unconditionally raised AttributeError otherwise.
        if not self.training and self.statistic_code_usage:
            with torch.no_grad():
                cluster_size = encodings.sum(0)
                self.all_reduce_fn(cluster_size)
                ema_inplace(self.cluster_size, cluster_size, self.decay)

        if self.training and self.embedding.update:
            # EMA cluster size
            bins = encodings.sum(0)
            self.all_reduce_fn(bins)

            if self.statistic_code_usage:
                ema_inplace(self.cluster_size, bins, self.decay)

            zero_mask = (bins == 0)
            # Avoid dividing by zero for codes that were not selected.
            bins = bins.masked_fill(zero_mask, 1.)

            embed_sum = z_flattened.t() @ encodings
            self.all_reduce_fn(embed_sum)

            embed_normalized = (embed_sum / bins.unsqueeze(0)).t()
            embed_normalized = l2norm(embed_normalized)

            # Unused codes keep their current value.
            embed_normalized = torch.where(zero_mask[..., None], self.embedding.weight,
                                           embed_normalized)
            norm_ema_inplace(self.embedding.weight, embed_normalized, self.decay)

        # compute loss for embedding
        loss = self.beta * F.mse_loss(z_q.detach(), z)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        return z_q, loss, encoding_indices
beats_path/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e5815275a04b6885e7b8af63d120b29bffae2cd2225cf4915e1ec6d819d3022c
3
+ size 363145291
ckpt_path/salmonn_7b_v0.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2cb2782495b2e3f487222763a30b53b02f727d49059201cc5fa88a7a1fd9dff9
3
+ size 362638989
ckpt_path/salmonn_v1.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:709c665b25ef05b48985584ec31d6f15018b754abf47b9c33ed9a278285bbae0
3
+ size 400466533
model.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (2023) Tsinghua University, Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import torch
17
+ import soundfile as sf
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ from peft import LoraConfig, TaskType, get_peft_model
21
+ from transformers import (
22
+ WhisperFeatureExtractor,
23
+ WhisperModel,
24
+ LlamaForCausalLM,
25
+ LlamaTokenizer
26
+ )
27
+ import librosa
28
+ import sys
29
+ sys.path.append('examples/SALMONN_7B/')
30
+
31
+ from beats.BEATs import BEATsConfig, BEATs
32
+ from qformer.Qformer import BertConfig, BertLMHeadModel
33
+
34
class SALMONN(nn.Module):
    """Inference wrapper that feeds speech/audio features to a LLaMA-style LM.

    The pipeline built in ``__init__`` and run in ``generate`` is:
    Whisper encoder + BEATs encoder -> LayerNorm on each -> concatenation on
    the feature axis -> fixed-size windows -> Q-Former with learned query
    tokens -> linear projection to the LM hidden size -> text generation.
    """

    def __init__(
        self,
        ckpt,
        whisper_path,
        beats_path,
        vicuna_path,
        speech_qformer_token_num=1,
        speech_qformer_layer=2,
        lora=True,
        lora_alpha=32,
        lora_rank=8,
        lora_dropout=0.1,
        second_per_frame=0.333333,
        second_stride=0.333333,
        low_resource=False
    ):
        """Build all sub-modules and load the SALMONN checkpoint.

        Args:
            ckpt: Path to a checkpoint whose ``'model'`` entry is a state dict;
                it is loaded non-strictly into this module.
            whisper_path: Passed to ``from_pretrained`` for the Whisper
                feature extractor and model.
            beats_path: Path to a BEATs checkpoint with ``'cfg'`` and
                ``'model'`` entries.
            vicuna_path: Passed to ``from_pretrained`` for the LLaMA model and
                tokenizer.
            speech_qformer_token_num: Number of learned query tokens per window.
            speech_qformer_layer: Number of hidden layers in the Q-Former.
            lora: Wrap the language model with a LoRA adapter when true.
            lora_alpha: LoRA ``lora_alpha``.
            lora_rank: LoRA rank ``r``.
            lora_dropout: LoRA dropout.
            second_per_frame: Window length used when splitting encoder frames
                (see ``_window_frames``).
            second_stride: Window stride used when splitting encoder frames.
            low_resource: Load the language model in 8-bit / float16 instead
                of float32.
        """
        super().__init__()

        # Whisper: feature extractor plus the encoder half of the model.
        self.feature_extractor = WhisperFeatureExtractor.from_pretrained(whisper_path)
        self.speech_encoder = WhisperModel.from_pretrained(whisper_path).encoder
        self.ln_speech = nn.LayerNorm(self.speech_encoder.config.d_model)

        # BEATs: restored from its own checkpoint and frozen.
        self.beats_ckpt = beats_path
        beats_checkpoint = torch.load(self.beats_ckpt, map_location='cpu')
        beats_cfg = BEATsConfig(beats_checkpoint['cfg'])
        beats = BEATs(beats_cfg)
        beats.load_state_dict(beats_checkpoint['model'])
        self.beats = beats
        self.ln_audio = nn.LayerNorm(self.beats.cfg.encoder_embed_dim)
        for param in self.beats.parameters():
            param.requires_grad = False
        self.beats.eval()

        # Q-Former over the concatenated Whisper + BEATs features.
        self.speech_Qformer, self.speech_query_tokens = self.init_speech_Qformer(
            speech_qformer_token_num,
            self.speech_encoder.config.d_model + self.beats.cfg.encoder_embed_dim,
            speech_qformer_layer,
        )
        self.second_per_frame = second_per_frame
        self.second_stride = second_stride

        # Language model (Vicuna / LLaMA).
        if low_resource:
            self.llama_model = LlamaForCausalLM.from_pretrained(
                vicuna_path,
                torch_dtype=torch.float16,
                load_in_8bit=True,
                device_map="auto"
            )
        else:
            self.llama_model = LlamaForCausalLM.from_pretrained(
                vicuna_path,
                torch_dtype=torch.float32,
                device_map="auto",
            )

        # Optional LoRA adapter; inference_mode=True as this class only generates.
        self.lora = lora
        if lora:
            self.peft_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM,
                inference_mode=True,
                r=lora_rank,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
                target_modules=None,
            )
            self.llama_model = get_peft_model(self.llama_model, self.peft_config)

        # Tokenizer, with an explicit pad token and right-side padding.
        self.llama_tokenizer = LlamaTokenizer.from_pretrained(vicuna_path, use_fast=False)
        self.llama_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.llama_tokenizer.padding_side = "right"

        # Projection from the Q-Former hidden size to the LM hidden size.
        self.speech_llama_proj = nn.Linear(
            self.speech_Qformer.config.hidden_size, self.llama_model.config.hidden_size)

        # Load the SALMONN weights. map_location='cpu' (as for the BEATs
        # checkpoint above) keeps loading independent of the device the
        # checkpoint was saved from; load_state_dict copies onto each
        # parameter's own device.
        ckpt_dict = torch.load(ckpt, map_location='cpu')['model']
        self.load_state_dict(ckpt_dict, strict=False)

    @staticmethod
    def _load_waveform(wav_path, audio_array, sampling_rate):
        """Return a 16 kHz waveform from a file path or an in-memory array.

        A file is reduced to its first channel and to at most 30 seconds
        before resampling. An in-memory array is only resampled.

        Raises:
            ValueError: If no usable audio input was supplied.
        """
        if wav_path:
            wav, sr = sf.read(wav_path)
            if len(wav.shape) == 2:
                # Multi-channel file: keep the first channel only.
                wav = wav[:, 0]
            if len(wav) > 30 * sr:
                wav = wav[: 30 * sr]
            if sr != 16000:
                wav = librosa.resample(wav, orig_sr=sr, target_sr=16000, res_type="fft")
            return wav
        if sampling_rate:
            if audio_array is None:
                raise ValueError("`audio_array` must be provided together with `sampling_rate`.")
            return librosa.resample(
                audio_array, orig_sr=sampling_rate, target_sr=16000, res_type="fft")
        raise ValueError(
            "No audio input: pass `wav_path`, or `audio_array` together with `sampling_rate`.")

    def _window_frames(self, speech_embeds):
        """Split ``(B, T, C)`` features into windows of shape ``(B * L, kernel, C)``.

        Window length and stride are ``second_per_frame`` and ``second_stride``
        scaled by ``T / 30.0``.
        NOTE(review): the 30.0 presumably reflects T frames spanning a
        30-second input -- confirm against the encoder configuration.
        """
        B, T, C = speech_embeds.shape
        kernel = round(T * self.second_per_frame / 30.0)
        stride = round(T * self.second_stride / 30.0)
        # (B, T, C) -> (B, C, 1, T) so that unfold slides along the time axis.
        as_image = speech_embeds.transpose(1, 2).unsqueeze(2)
        unfolded = F.unfold(
            as_image, kernel_size=(1, kernel), dilation=1, padding=0, stride=(1, stride))
        L = unfolded.shape[-1]
        windows = unfolded.view(B, -1, kernel, L)
        windows = torch.permute(windows, [0, 3, 2, 1])
        return windows.reshape(-1, kernel, C)

    @torch.no_grad()
    def generate(
        self,
        prompt=None,
        audio_array=None,
        sampling_rate=None,
        wav_path=None,
        prompt_pattern="USER: <Speech><SpeechHere></Speech> {}\nASSISTANT:",
        device='cuda',
        max_length=300,  # 150 as default, Bin changed to 300
        num_beams=4,
        do_sample=True,
        min_length=1,
        top_p=0.9,
        repetition_penalty=1.0,
        length_penalty=1.0,
        temperature=1.0,
    ):
        """Generate text for one audio clip and a text prompt.

        Gradients are disabled: the method only returns decoded text.

        Args:
            prompt: Text inserted into ``prompt_pattern`` via ``str.format``.
            audio_array: Waveform array, used when ``wav_path`` is not given.
            sampling_rate: Sampling rate of ``audio_array``.
            wav_path: Path to an audio file; takes precedence over
                ``audio_array``.
            prompt_pattern: Template containing ``<SpeechHere>`` where the
                audio embeddings are spliced in, and ``{}`` for ``prompt``.
            device: Device the audio features are moved to.
            max_length, num_beams, do_sample, min_length, top_p,
            repetition_penalty, length_penalty, temperature: Forwarded to the
                language model's ``generate``.

        Returns:
            The result of the tokenizer's ``batch_decode`` on the generated ids.

        Raises:
            ValueError: If neither ``wav_path`` nor ``audio_array`` with
                ``sampling_rate`` is supplied.
        """
        wav = self._load_waveform(wav_path, audio_array, sampling_rate)

        # whisper
        spectrogram = self.feature_extractor(
            wav, return_tensors="pt", sampling_rate=16000
        ).input_features.to(device)  # [1, 80, 3000]
        speech_embeds = self.speech_encoder(spectrogram, return_dict=True).last_hidden_state

        # beats
        raw_wav = torch.from_numpy(wav).to(device).unsqueeze(0)
        audio_padding_mask = torch.zeros(raw_wav.shape, device=device).bool()
        audio_embeds, _ = self.beats.extract_features(
            raw_wav, padding_mask=audio_padding_mask, feature_only=True)

        # auditory embeds: normalise each stream, align the BEATs length to
        # the Whisper length (a negative pad crops), then concatenate features.
        speech_embeds = self.ln_speech(speech_embeds)
        audio_embeds = self.ln_audio(audio_embeds)
        audio_embeds = F.pad(
            audio_embeds, (0, 0, 0, speech_embeds.size(1) - audio_embeds.size(1)))
        speech_embeds = torch.cat([speech_embeds, audio_embeds], dim=-1)

        # split frames
        B = speech_embeds.size(0)
        speech_embeds = self._window_frames(speech_embeds)
        speech_atts = torch.ones(
            speech_embeds.size()[:-1], dtype=torch.long, device=speech_embeds.device)

        # Qformer
        query_tokens = self.speech_query_tokens.expand(speech_embeds.shape[0], -1, -1)
        query_output = self.speech_Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=speech_embeds,
            encoder_attention_mask=speech_atts,
            return_dict=True,
        )
        speech_embeds = self.speech_llama_proj(query_output.last_hidden_state)
        # Fold the per-window outputs back into one sequence per batch item.
        speech_embeds = speech_embeds.view(B, -1, speech_embeds.size(2)).contiguous()

        # USER: <Speech>speech_embeds<Speech> prompt\nASSISTANT:
        # The PEFT wrapper adds one extra `.model` level around the LM.
        if self.lora:
            embed_tokens = self.llama_model.model.model.embed_tokens
        else:
            embed_tokens = self.llama_model.model.embed_tokens

        def embed_text(text):
            ids = self.llama_tokenizer(
                text,
                return_tensors="pt",
                add_special_tokens=False
            ).to(speech_embeds.device).input_ids
            return embed_tokens(ids)

        # Split once so a user prompt that itself contains the placeholder
        # text does not break the unpacking.
        prompt_left, prompts_right = prompt_pattern.format(prompt).split('<SpeechHere>', 1)
        prompt_left_embeds = embed_text(prompt_left)
        prompt_right_embeds = embed_text(prompts_right)

        bos_embeds = embed_tokens(
            torch.ones(
                [1, 1],
                dtype=torch.long,
                device=device,
            ) * self.llama_tokenizer.bos_token_id
        )

        embeds = torch.cat(
            [bos_embeds, prompt_left_embeds, speech_embeds, prompt_right_embeds], dim=1)
        atts = torch.ones(embeds.size()[:-1], dtype=torch.long).to(embeds.device)

        # generate
        output = self.llama_model.generate(
            inputs_embeds=embeds,
            max_length=max_length,
            num_beams=num_beams,
            do_sample=do_sample,
            min_length=min_length,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            temperature=temperature,
            attention_mask=atts,
            bos_token_id=self.llama_tokenizer.bos_token_id,
            eos_token_id=self.llama_tokenizer.eos_token_id,
            pad_token_id=self.llama_tokenizer.pad_token_id
        )

        output_text = self.llama_tokenizer.batch_decode(
            output, add_special_tokens=False, skip_special_tokens=True)

        return output_text

    def init_speech_Qformer(self, num_query_token, speech_width, num_hidden_layers=2):
        """Create the Q-Former and its learned query tokens.

        Args:
            num_query_token: Number of query tokens (also stored as
                ``query_length`` on the config).
            speech_width: Feature width of the encoder states the Q-Former
                cross-attends to.
            num_hidden_layers: Number of Q-Former layers.

        Returns:
            ``(Qformer, query_tokens)`` where ``query_tokens`` is an
            ``nn.Parameter`` of shape ``(1, num_query_token, hidden_size)``
            initialised from a normal distribution with the config's
            ``initializer_range`` as standard deviation.
        """
        encoder_config = BertConfig()
        encoder_config.num_hidden_layers = num_hidden_layers
        encoder_config.encoder_width = speech_width
        # Cross-attention in every layer (frequency 1).
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = 1
        encoder_config.query_length = num_query_token
        Qformer = BertLMHeadModel(config=encoder_config)
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
        return Qformer, query_tokens
qformer/LICENSE_Lavis ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ BSD 3-Clause License
2
+
3
+ Copyright (c) 2022 Salesforce, Inc.
4
+ All rights reserved.
5
+
6
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
7
+
8
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
9
+
10
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
11
+
12
+ 3. Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
13
+
14
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
qformer/LICENSE_MiniGPT4 ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ BSD 3-Clause License
2
+
3
+ Copyright 2023 Deyao Zhu
4
+ All rights reserved.
5
+
6
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
7
+
8
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
9
+
10
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
11
+
12
+ 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
13
+
14
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
qformer/LICENSE_VideoLlama ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ BSD 3-Clause License
2
+
3
+ Copyright (c) 2023, Multilingual NLP Team at Alibaba DAMO Academy
4
+
5
+ Redistribution and use in source and binary forms, with or without
6
+ modification, are permitted provided that the following conditions are met:
7
+
8
+ 1. Redistributions of source code must retain the above copyright notice, this
9
+ list of conditions and the following disclaimer.
10
+
11
+ 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ this list of conditions and the following disclaimer in the documentation
13
+ and/or other materials provided with the distribution.
14
+
15
+ 3. Neither the name of the copyright holder nor the names of its
16
+ contributors may be used to endorse or promote products derived from
17
+ this software without specific prior written permission.
18
+
19
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
qformer/Qformer.py ADDED
@@ -0,0 +1,1217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Adapted from salesforce@LAVIS. Below is the original copyright:
3
+ * Copyright (c) 2023, salesforce.com, inc.
4
+ * All rights reserved.
5
+ * SPDX-License-Identifier: BSD-3-Clause
6
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
7
+ * By Junnan Li
8
+ * Based on huggingface code base
9
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
10
+ """
11
+
12
+ import math
13
+ import os
14
+ import warnings
15
+ from dataclasses import dataclass
16
+ from typing import Optional, Tuple, Dict, Any
17
+
18
+ import torch
19
+ from torch import Tensor, device, dtype, nn
20
+ import torch.utils.checkpoint
21
+ from torch import nn
22
+ from torch.nn import CrossEntropyLoss
23
+ import torch.nn.functional as F
24
+
25
+ from transformers.activations import ACT2FN
26
+ from transformers.file_utils import (
27
+ ModelOutput,
28
+ )
29
+ from transformers.modeling_outputs import (
30
+ BaseModelOutputWithPastAndCrossAttentions,
31
+ BaseModelOutputWithPoolingAndCrossAttentions,
32
+ CausalLMOutputWithCrossAttentions,
33
+ MaskedLMOutput,
34
+ MultipleChoiceModelOutput,
35
+ NextSentencePredictorOutput,
36
+ QuestionAnsweringModelOutput,
37
+ SequenceClassifierOutput,
38
+ TokenClassifierOutput,
39
+ )
40
+ from transformers.modeling_utils import (
41
+ PreTrainedModel,
42
+ apply_chunking_to_forward,
43
+ find_pruneable_heads_and_indices,
44
+ prune_linear_layer,
45
+ )
46
+ from transformers.utils import logging
47
+ from transformers.models.bert.configuration_bert import BertConfig
48
+
49
+ logger = logging.get_logger(__name__)
50
+
51
+
52
class BertEmbeddings(nn.Module):
    """Build input embeddings from token ids and/or pre-computed query embeddings.

    Token ids are embedded as word + (absolute) position embeddings; query
    embeddings, when given, are placed in front of them. The result is
    layer-normalised and passed through dropout.
    """

    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size
        self.word_embeddings = nn.Embedding(
            config.vocab_size, hidden, padding_idx=config.pad_token_id
        )
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, hidden)

        # The attribute name is deliberately not snake_case: it matches the
        # TensorFlow variable name so TensorFlow checkpoints can be loaded.
        self.LayerNorm = nn.LayerNorm(hidden, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # (1, max_position_embeddings) index row, stored as a buffer so it is
        # serialised with the module.
        all_positions = torch.arange(config.max_position_embeddings).expand((1, -1))
        self.register_buffer("position_ids", all_positions)
        self.position_embedding_type = getattr(
            config, "position_embedding_type", "absolute"
        )

        self.config = config

    def forward(
        self,
        input_ids=None,
        position_ids=None,
        query_embeds=None,
        past_key_values_length=0,
    ):
        """Return the normalised embeddings.

        With ``input_ids`` only: word (+ position) embeddings.
        With ``query_embeds`` only: the query embeddings themselves.
        With both: query embeddings concatenated before the token embeddings
        along dimension 1.
        """
        has_tokens = input_ids is not None
        seq_length = input_ids.size()[1] if has_tokens else 0

        if position_ids is None:
            first = past_key_values_length
            position_ids = self.position_ids[:, first : first + seq_length].clone()

        if not has_tokens:
            embeddings = query_embeds
        else:
            embeddings = self.word_embeddings(input_ids)
            if self.position_embedding_type == "absolute":
                embeddings = embeddings + self.position_embeddings(position_ids)
            if query_embeds is not None:
                embeddings = torch.cat((query_embeds, embeddings), dim=1)

        return self.dropout(self.LayerNorm(embeddings))
110
+
111
+
112
class BertSelfAttention(nn.Module):
    """Multi-head attention used for both self- and cross-attention.

    When built with ``is_cross_attention`` true, keys and values are projected
    from ``config.encoder_width`` features; otherwise from ``config.hidden_size``.
    """

    def __init__(self, config, is_cross_attention):
        super().__init__()
        self.config = config
        indivisible = config.hidden_size % config.num_attention_heads != 0
        if indivisible and not hasattr(config, "embedding_size"):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Keys/values read the encoder width only in the cross-attention case.
        kv_width = config.encoder_width if is_cross_attention else config.hidden_size
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(kv_width, self.all_head_size)
        self.value = nn.Linear(kv_width, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(
            config, "position_embedding_type", "absolute"
        )
        if self.position_embedding_type in ("relative_key", "relative_key_query"):
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(
                2 * config.max_position_embeddings - 1, self.attention_head_size
            )
        # When set to True, cross-attention maps and their gradients are kept.
        self.save_attention = False

    def save_attn_gradients(self, attn_gradients):
        """Store attention gradients (used as a tensor hook in ``forward``)."""
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        """Return the gradients stored by ``save_attn_gradients``."""
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        """Store an attention-probability tensor."""
        self.attention_map = attention_map

    def get_attention_map(self):
        """Return the tensor stored by ``save_attention_map``."""
        return self.attention_map

    def transpose_for_scores(self, x):
        """Reshape ``(..., all_head_size)`` to heads and move heads to dim 1."""
        head_shape = (self.num_attention_heads, self.attention_head_size)
        x = x.view(*(x.size()[:-1] + head_shape))
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Run attention and return ``(context[, probs], (key, value))``.

        Passing ``encoder_hidden_states`` selects cross-attention: keys and
        values come from the encoder states and ``encoder_attention_mask``
        replaces ``attention_mask``. Otherwise a ``past_key_value`` pair, if
        given, is prepended to the new keys/values along dimension 2.
        """
        is_cross_attention = encoder_hidden_states is not None

        kv_input = encoder_hidden_states if is_cross_attention else hidden_states
        key_layer = self.transpose_for_scores(self.key(kv_input))
        value_layer = self.transpose_for_scores(self.value(kv_input))
        if is_cross_attention:
            # The mask must hide the encoder's padding positions.
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)

        query_layer = self.transpose_for_scores(self.query(hidden_states))

        present_key_value = (key_layer, value_layer)

        # Raw scores: dot product of queries with keys.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type in ("relative_key", "relative_key_query"):
            seq_length = hidden_states.size()[1]
            rows = torch.arange(
                seq_length, dtype=torch.long, device=hidden_states.device
            ).view(-1, 1)
            cols = torch.arange(
                seq_length, dtype=torch.long, device=hidden_states.device
            ).view(1, -1)
            distance = rows - cols
            positional_embedding = self.distance_embedding(
                distance + self.max_position_embeddings - 1
            )
            # fp16 compatibility
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)

            if self.position_embedding_type == "relative_key":
                attention_scores = attention_scores + torch.einsum(
                    "bhld,lrd->bhlr", query_layer, positional_embedding
                )
            elif self.position_embedding_type == "relative_key_query":
                from_query = torch.einsum(
                    "bhld,lrd->bhlr", query_layer, positional_embedding
                )
                from_key = torch.einsum(
                    "bhrd,lrd->bhlr", key_layer, positional_embedding
                )
                attention_scores = attention_scores + from_query + from_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Additive mask (per the original note, precomputed for all
            # layers in BertModel's forward).
            attention_scores = attention_scores + attention_mask

        # Scores -> probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        if is_cross_attention and self.save_attention:
            self.save_attention_map(attention_probs)
            attention_probs.register_hook(self.save_attn_gradients)

        # Dropout here removes whole attended-to tokens; this follows the
        # original Transformer paper.
        attention_probs_dropped = self.dropout(attention_probs)

        if head_mask is not None:
            attention_probs_dropped = attention_probs_dropped * head_mask

        context_layer = torch.matmul(attention_probs_dropped, value_layer)

        # Merge the heads back into a single feature axis.
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        merged_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*merged_shape)

        outputs = (context_layer,)
        if output_attentions:
            outputs = outputs + (attention_probs,)
        return outputs + (present_key_value,)
277
+
278
+
279
class BertSelfOutput(nn.Module):
    """Output projection of an attention block with residual + LayerNorm."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Project ``hidden_states``, apply dropout, add ``input_tensor``, normalise."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
291
+
292
+
293
class BertAttention(nn.Module):
    """Attention sub-layer: ``BertSelfAttention`` followed by ``BertSelfOutput``."""

    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.self = BertSelfAttention(config, is_cross_attention)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads from the projections of this layer."""
        if len(heads) == 0:
            return
        attn = self.self
        heads, index = find_pruneable_heads_and_indices(
            heads,
            attn.num_attention_heads,
            attn.attention_head_size,
            self.pruned_heads,
        )

        # Shrink the q/k/v projections and the matching input side of the
        # output projection.
        attn.query = prune_linear_layer(attn.query, index)
        attn.key = prune_linear_layer(attn.key, index)
        attn.value = prune_linear_layer(attn.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Keep the bookkeeping consistent with the smaller projections.
        attn.num_attention_heads = attn.num_attention_heads - len(heads)
        attn.all_head_size = attn.attention_head_size * attn.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Return ``(attention_output, *extras)``.

        ``extras`` are whatever the inner attention returns after its context
        tensor (attention probabilities when requested, then the key/value pair).
        """
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        context, extras = self_outputs[0], self_outputs[1:]
        attention_output = self.output(context, hidden_states)
        return (attention_output,) + extras
348
+
349
+
350
class BertIntermediate(nn.Module):
    """First half of the feed-forward block: linear expansion + activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        activation = config.hidden_act
        # A string names an entry of ACT2FN; anything else is used as the
        # activation callable directly.
        self.intermediate_act_fn = (
            ACT2FN[activation] if isinstance(activation, str) else activation
        )

    def forward(self, hidden_states):
        """Apply the dense layer and then the activation."""
        return self.intermediate_act_fn(self.dense(hidden_states))
363
+
364
+
365
class BertOutput(nn.Module):
    """Second half of the feed-forward block with residual + LayerNorm."""

    def __init__(self, config):
        super().__init__()
        out_width = config.hidden_size
        self.dense = nn.Linear(config.intermediate_size, out_width)
        self.LayerNorm = nn.LayerNorm(out_width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Project back to the hidden size, apply dropout, add ``input_tensor``, normalise."""
        reduced = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(reduced + input_tensor)
377
+
378
+
379
class BertLayer(nn.Module):
    """One transformer layer with separate feed-forward paths for queries and text.

    Self-attention runs over the whole sequence. When ``query_length > 0`` the
    first ``query_length`` positions optionally go through cross-attention (only
    on layers that have it) and then through the query feed-forward network;
    any remaining positions use the regular feed-forward network.
    """

    def __init__(self, config, layer_num):
        super().__init__()
        self.config = config
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = BertAttention(config)
        self.layer_num = layer_num

        # Cross-attention is inserted every ``cross_attention_freq`` layers,
        # and only when the config enables it at all.
        wants_cross_attention = (
            self.config.add_cross_attention
            and layer_num % self.config.cross_attention_freq == 0
        )
        if wants_cross_attention:
            self.crossattention = BertAttention(
                config, is_cross_attention=self.config.add_cross_attention
            )
            self.has_cross_attention = True
        else:
            self.has_cross_attention = False

        # Feed-forward network for the non-query positions.
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

        # Feed-forward network for the query positions.
        self.intermediate_query = BertIntermediate(config)
        self.output_query = BertOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        query_length=0,
    ):
        """Run the layer.

        Returns a tuple ``(layer_output, *attention_extras, present_key_value)``
        where ``attention_extras`` are the middle elements returned by the
        self-attention module followed by those of the cross-attention module
        (when it ran), and ``present_key_value`` is the last element returned by
        the self-attention module.
        """
        # Only the first two entries of the cached tuple belong to self-attention.
        cached_self_kv = None if past_key_value is None else past_key_value[:2]
        self_attn_out = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=cached_self_kv,
        )
        attention_output = self_attn_out[0]
        extras = self_attn_out[1:-1]
        present_key_value = self_attn_out[-1]

        if query_length <= 0:
            # No query positions: the whole sequence uses the regular feed-forward.
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                attention_output,
            )
            return (layer_output,) + extras + (present_key_value,)

        query_states = attention_output[:, :query_length, :]

        if self.has_cross_attention:
            assert (
                encoder_hidden_states is not None
            ), "encoder_hidden_states must be given for cross-attention layers"
            cross_attn_out = self.crossattention(
                query_states,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                output_attentions=output_attentions,
            )
            query_states = cross_attn_out[0]
            # add cross attentions if we output attention weights
            extras = extras + cross_attn_out[1:-1]

        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk_query,
            self.chunk_size_feed_forward,
            self.seq_len_dim,
            query_states,
        )
        if attention_output.shape[1] > query_length:
            # Positions after the queries go through the regular feed-forward
            # and are concatenated back along the sequence dimension.
            text_output = apply_chunking_to_forward(
                self.feed_forward_chunk,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                attention_output[:, query_length:, :],
            )
            layer_output = torch.cat([layer_output, text_output], dim=1)

        return (layer_output,) + extras + (present_key_value,)

    def feed_forward_chunk(self, attention_output):
        """Feed-forward (with residual) for non-query positions."""
        return self.output(self.intermediate(attention_output), attention_output)

    def feed_forward_chunk_query(self, attention_output):
        """Feed-forward (with residual) for query positions."""
        return self.output_query(
            self.intermediate_query(attention_output), attention_output
        )
486
+
487
+
488
class BertEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` :class:`BertLayer` modules."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(
            [BertLayer(config, i) for i in range(config.num_hidden_layers)]
        )

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        query_length=0,
    ):
        """Run every layer in order and collect the requested side outputs.

        Args:
            hidden_states: Input to the first layer.
            attention_mask: Passed unchanged to every layer.
            head_mask: Optional per-layer sequence; entry ``i`` goes to layer ``i``.
            encoder_hidden_states: Passed to every layer (used by cross-attention).
            encoder_attention_mask: Passed to every layer.
            past_key_values: Optional per-layer sequence of cached key/values.
            use_cache: When truthy, collect the last element of each layer's
                output as the new cache.
            output_attentions: When truthy, collect attention outputs.
            output_hidden_states: When truthy, collect the input of every layer
                plus the final output.
            return_dict: Return a ``BaseModelOutputWithPastAndCrossAttentions``
                instead of a tuple of the non-``None`` values.
            query_length: Number of leading query positions, forwarded to layers.

        Returns:
            Model output object, or a tuple of the non-``None`` entries among
            (last hidden state, cache, all hidden states, self-attentions,
            cross-attentions).
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = (
            () if output_attentions and self.config.add_cross_attention else None
        )

        next_decoder_cache = () if use_cache else None

        for i in range(self.config.num_hidden_layers):
            layer_module = self.layer[i]
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if getattr(self.config, "gradient_checkpointing", False) and self.training:

                if use_cache:
                    logger.warning(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                # Bind this layer's cache entry as a default argument: the
                # checkpointed function may be re-run later, after the loop
                # variable has moved on to another layer.
                def create_custom_forward(module, layer_past=past_key_value):
                    def custom_forward(*inputs):
                        return module(
                            *inputs, layer_past, output_attentions, query_length
                        )

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    query_length,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                # A layer only emits a cross-attention entry when it owns a
                # cross-attention module and query positions were supplied;
                # otherwise index 2 is the cached key/value, not an attention map.
                ran_cross_attention = (
                    getattr(layer_module, "has_cross_attention", False)
                    and query_length > 0
                )
                if all_cross_attentions is not None and ran_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
591
+
592
+
593
class BertPooler(nn.Module):
    """Pool a sequence by transforming the hidden state of its first position."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        """Return ``tanh(dense(hidden_states[:, 0]))``."""
        # "Pooling" here simply means picking the first token's hidden state.
        leading_state = hidden_states[:, 0]
        return self.activation(self.dense(leading_state))
606
+
607
+
608
class BertPredictionHeadTransform(nn.Module):
    """Dense -> activation -> LayerNorm transform applied before the LM decoder."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # ``hidden_act`` is either a key into ACT2FN or an object used directly
        # as the activation.
        activation = config.hidden_act
        self.transform_act_fn = (
            ACT2FN[activation] if isinstance(activation, str) else activation
        )
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        """Return ``LayerNorm(transform_act_fn(dense(hidden_states)))``."""
        activated = self.transform_act_fn(self.dense(hidden_states))
        return self.LayerNorm(activated)
623
+
624
+
625
class BertLMPredictionHead(nn.Module):
    """Language-model head: prediction transform followed by a vocabulary projection."""

    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The decoder is built without its own bias; a separate output-only
        # bias (one value per vocabulary entry) is created below.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Link the two so the bias is correctly resized with
        # `resize_token_embeddings`.
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        """Return vocabulary scores ``decoder(transform(hidden_states))``."""
        return self.decoder(self.transform(hidden_states))
643
+
644
+
645
class BertOnlyMLMHead(nn.Module):
    """Thin wrapper exposing a :class:`BertLMPredictionHead` as ``predictions``."""

    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output):
        """Return the prediction scores computed by ``self.predictions``."""
        return self.predictions(sequence_output)
653
+
654
+
655
class BertPreTrainedModel(PreTrainedModel):
    """
    Abstract base class that handles weight initialization and provides a simple
    interface for downloading and loading pretrained models.
    """

    config_class = BertConfig
    base_model_prefix = "bert"
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
        """Initialize the weights of a single sub-module in place."""
        if isinstance(module, nn.LayerNorm):
            # LayerNorm starts out as the identity transform.
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            return

        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal
            # for initialization, cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)

        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
676
+
677
+
678
class BertModel(BertPreTrainedModel):
    """
    BERT backbone that accepts token ids, pre-computed query embeddings, or both.

    The model can behave as an encoder (with only self-attention) as well as a
    decoder, following the architecture described in `Attention is all you need
    <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki
    Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia
    Polosukhin.

    Decoder-style (causal) masking is selected per call through the
    :obj:`is_decoder` argument of :meth:`forward`. When the configuration sets
    :obj:`add_cross_attention`, an :obj:`encoder_hidden_states` input is expected
    in the forward pass whenever query embeddings are supplied.
    """

    def __init__(self, config, add_pooling_layer=False):
        super().__init__(config)
        self.config = config

        self.embeddings = BertEmbeddings(config)

        self.encoder = BertEncoder(config)

        self.pooler = BertPooler(config) if add_pooling_layer else None

        self.init_weights()

    def get_input_embeddings(self):
        """Return the word-embedding module."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the word-embedding module."""
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def get_extended_attention_mask(
        self,
        attention_mask: Tensor,
        input_shape: Tuple[int],
        device: device,
        is_decoder: bool,
        has_query: bool = False,
    ) -> Tensor:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            attention_mask (:obj:`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (:obj:`Tuple[int]`):
                The shape of the input to the model.
            device: (:obj:`torch.device`):
                The device of the input to the model.
            is_decoder (:obj:`bool`):
                Whether to combine a causal mask with a 2D padding mask.
            has_query (:obj:`bool`):
                When the attention mask is longer than the causal mask, whether
                to also prepend rows for the prefix positions (UniLM style).

        Returns:
            :obj:`torch.Tensor` The additive extended attention mask, cast to :obj:`self.dtype`:
            ``0.0`` at positions to attend to and ``-10000.0`` at masked positions.

        Raises:
            ValueError: If :obj:`attention_mask` is neither 2D nor 3D.
        """
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            if is_decoder:
                batch_size, seq_length = input_shape

                seq_ids = torch.arange(seq_length, device=device)
                causal_mask = (
                    seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
                    <= seq_ids[None, :, None]
                )

                # add a prefix ones mask to the causal mask
                # causal and attention masks must have same type with pytorch version < 1.3
                causal_mask = causal_mask.to(attention_mask.dtype)

                if causal_mask.shape[1] < attention_mask.shape[1]:
                    prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
                    if has_query:  # UniLM style attention mask
                        # Prefix rows are zero over the causal columns.
                        causal_mask = torch.cat(
                            [
                                torch.zeros(
                                    (batch_size, prefix_seq_len, seq_length),
                                    device=device,
                                    dtype=causal_mask.dtype,
                                ),
                                causal_mask,
                            ],
                            dim=1,
                        )
                    # Every row may attend to all prefix columns.
                    causal_mask = torch.cat(
                        [
                            torch.ones(
                                (batch_size, causal_mask.shape[1], prefix_seq_len),
                                device=device,
                                dtype=causal_mask.dtype,
                            ),
                            causal_mask,
                        ],
                        dim=-1,
                    )
                extended_attention_mask = (
                    causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
                )
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
                    input_shape, attention_mask.shape
                )
            )

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(
            dtype=self.dtype
        )  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        is_decoder=False,
    ):
        r"""
        encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder. A list of such tensors is also accepted.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
            A list of masks is also accepted; defaults to all ones when omitted.
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`). Passed to the encoder as given.
        query_embeds (`optional`):
            Pre-computed query embeddings; required when :obj:`input_ids` is :obj:`None`. Its size along
            dimension 1 is forwarded to the encoder as ``query_length``.
        is_decoder (:obj:`bool`):
            When :obj:`True`, a causal attention mask is built from the shape of :obj:`input_ids`.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # NOTE: ``use_cache`` is deliberately not defaulted from the config here;
        # the caller's value (possibly None) is forwarded as-is.

        if input_ids is None:
            assert (
                query_embeds is not None
            ), "You have to specify query_embeds when input_ids is None"

        # Length already covered by the cache, not counting the query positions.
        past_key_values_length = (
            past_key_values[0][0].shape[2] - self.config.query_length
            if past_key_values is not None
            else 0
        )

        query_length = query_embeds.shape[1] if query_embeds is not None else 0

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            query_embeds=query_embeds,
            past_key_values_length=past_key_values_length,
        )

        input_shape = embedding_output.size()[:-1]
        batch_size, seq_length = input_shape
        device = embedding_output.device

        if attention_mask is None:
            attention_mask = torch.ones(
                ((batch_size, seq_length + past_key_values_length)), device=device
            )

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if is_decoder:
            # The causal part is sized from ``input_ids`` rather than from the
            # embedding output; extra leading mask positions are handled as a
            # prefix inside ``get_extended_attention_mask``.
            extended_attention_mask = self.get_extended_attention_mask(
                attention_mask,
                input_ids.shape,
                device,
                is_decoder,
                has_query=(query_embeds is not None),
            )
        else:
            extended_attention_mask = self.get_extended_attention_mask(
                attention_mask, input_shape, device, is_decoder
            )

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if encoder_hidden_states is not None:
            if isinstance(encoder_hidden_states, list):
                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[
                    0
                ].size()
            else:
                (
                    encoder_batch_size,
                    encoder_sequence_length,
                    _,
                ) = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)

            if isinstance(encoder_attention_mask, list):
                encoder_extended_attention_mask = [
                    self.invert_attention_mask(mask) for mask in encoder_attention_mask
                ]
            else:
                if encoder_attention_mask is None:
                    # No mask given: attend to every encoder position.
                    encoder_attention_mask = torch.ones(
                        encoder_hidden_shape, device=device
                    )
                encoder_extended_attention_mask = self.invert_attention_mask(
                    encoder_attention_mask
                )
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            query_length=query_length,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = (
            self.pooler(sequence_output) if self.pooler is not None else None
        )

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
967
+
968
+
969
class BertLMHeadModel(BertPreTrainedModel):
    """:class:`BertModel` with a language-modeling head for next-token prediction."""

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        """Return the vocabulary projection layer of the LM head."""
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        """Replace the vocabulary projection layer of the LM head."""
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        past_key_values=None,
        use_cache=True,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        return_logits=False,
        is_decoder=True,
        reduction="mean",
    ):
        r"""
        encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
            ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``.
            Supplying labels forces ``use_cache`` to :obj:`False`.
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
            When given, :obj:`query_embeds` is ignored.
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).
        return_logits (:obj:`bool`):
            If :obj:`True`, return only the prediction scores with the last position dropped.
        reduction (:obj:`str`):
            Reduction passed to the cross-entropy loss. With ``"none"`` the per-token losses are summed per sample.
        Returns:
        Example::
            >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
            >>> import torch
            >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
            >>> config = BertConfig.from_pretrained("bert-base-cased")
            >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
            >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
            >>> outputs = model(**inputs)
            >>> prediction_logits = outputs.logits
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        if labels is not None:
            use_cache = False
        if past_key_values is not None:
            query_embeds = None

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            query_embeds=query_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
        )

        sequence_output = outputs[0]
        if query_embeds is not None:
            # Drop the leading query positions; only the remaining positions
            # are scored against the vocabulary.
            sequence_output = outputs[0][:, query_embeds.shape[1] :, :]

        prediction_scores = self.cls(sequence_output)

        if return_logits:
            return prediction_scores[:, :-1, :].contiguous()

        lm_loss = None
        if labels is not None:
            # we are doing next-token prediction; shift prediction scores and input ids by one
            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
            lm_loss = loss_fct(
                shifted_prediction_scores.view(-1, self.config.vocab_size),
                labels.view(-1),
            )
            if reduction == "none":
                lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((lm_loss,) + output) if lm_loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=lm_loss,
            logits=prediction_scores,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        query_embeds,
        past=None,
        attention_mask=None,
        past_key_values=None,
        **model_kwargs
    ):
        """Build the keyword arguments for one generation step.

        The cache may arrive under the legacy name ``past`` or under
        ``past_key_values`` (the name used by newer transformers releases);
        ``past`` takes precedence when both are given. When a cache is present
        only the last token of ``input_ids`` is forwarded.
        """
        if past is None:
            past = past_key_values

        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_ids.shape)
        # The query positions are always attended to, so prepend ones for them.
        query_mask = input_ids.new_ones(query_embeds.shape[:-1])
        attention_mask = torch.cat([query_mask, attention_mask], dim=-1)

        # cut decoder_input_ids if past is used
        if past is not None:
            input_ids = input_ids[:, -1:]

        return {
            "input_ids": input_ids,
            "query_embeds": query_embeds,
            "attention_mask": attention_mask,
            "past_key_values": past,
            "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
            "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
            "is_decoder": True,
        }

    def _reorder_cache(self, past, beam_idx):
        """Reorder every cached tensor along dimension 0 according to ``beam_idx``."""
        reordered_past = ()
        for layer_past in past:
            reordered_past += (
                tuple(
                    past_state.index_select(0, beam_idx) for past_state in layer_past
                ),
            )
        return reordered_past
1130
+
1131
+
1132
class BertForMaskedLM(BertPreTrainedModel):
    """:class:`BertModel` with a masked-language-modeling head."""

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        """Return the vocabulary projection layer of the LM head."""
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        """Replace the vocabulary projection layer of the LM head."""
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        return_logits=False,
        is_decoder=False,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
            config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
            (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
        return_logits (:obj:`bool`):
            If :obj:`True`, return only the prediction scores.
        """

        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            query_embeds=query_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
        )

        # Default to the full sequence so the head also works without query
        # embeddings; previously ``sequence_output`` was left unassigned then.
        sequence_output = outputs[0]
        if query_embeds is not None:
            # Drop the leading query positions before scoring.
            sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
        prediction_scores = self.cls(sequence_output)

        if return_logits:
            return prediction_scores

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()  # -100 index = padding token
            masked_lm_loss = loss_fct(
                prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
            )

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return (
                ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
            )

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
qformer/__pycache__/Qformer.cpython-310.pyc ADDED
Binary file (30.7 kB). View file
 
qformer/__pycache__/Qformer.cpython-39.pyc ADDED
Binary file (30.9 kB). View file
 
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.0.1
2
+ torchaudio==2.0.2
3
+ peft==0.3.0
4
+ soundfile
5
+ librosa
6
+ transformers==4.41.0
7
+ sentencepiece==0.1.97
8
+ accelerate==0.20.3
9
+ bitsandbytes==0.35.0
10
+ gradio==3.23.0
resource/audio_demo/duck.wav ADDED
Binary file (640 kB). View file
 
resource/audio_demo/excitement.wav ADDED
Binary file (40.4 kB). View file
 
resource/audio_demo/gunshots.wav ADDED
Binary file (320 kB). View file
 
resource/audio_demo/mountain.wav ADDED
Binary file (79.1 kB). View file
 
resource/audio_demo/music.wav ADDED
Binary file (639 kB). View file
 
resource/response_demo/aac.png ADDED
resource/response_demo/aed.png ADDED
resource/response_demo/asr.png ADDED
resource/response_demo/emo.png ADDED
resource/response_demo/jsac.png ADDED
resource/response_demo/lyrics.png ADDED
resource/response_demo/mc.png ADDED
resource/response_demo/memo.png ADDED
resource/response_demo/pr.png ADDED
resource/response_demo/sac.png ADDED
resource/response_demo/sq.png ADDED
resource/response_demo/sr.png ADDED
resource/response_demo/story.png ADDED
resource/response_demo/title.png ADDED