Lakoc committed on
Commit 1b9475a
1 Parent(s): 4d69b96

Upload JointCTCAttentionEncoderDecoder

auto_wrappers.py ADDED
@@ -0,0 +1,132 @@
import copy
import os

from transformers import AutoConfig, AutoModelForCTC, PretrainedConfig
from transformers.dynamic_module_utils import (
    get_class_from_dynamic_module,
    resolve_trust_remote_code,
)
from transformers.models.auto.auto_factory import _get_model_class

from .extractors import Conv2dFeatureExtractor


class FeatureExtractionInitModifier(type):
    def __new__(cls, name, bases, dct):
        # Create the class using the original definition
        new_cls = super().__new__(cls, name, bases, dct)

        # Save the original __init__ method
        original_init = new_cls.__init__

        # Modify the __init__ method dynamically
        def new_init(self, *args, **kwargs):
            original_init(self, *args, **kwargs)
            if self.config.expect_2d_input:
                getattr(self, self.base_model_prefix).feature_extractor = Conv2dFeatureExtractor(self.config)

        # Replace the __init__ method with the modified version
        new_cls.__init__ = new_init

        return new_cls


class CustomAutoModelForCTC(AutoModelForCTC):
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        config = kwargs.pop("config", None)
        trust_remote_code = kwargs.pop("trust_remote_code", None)
        kwargs["_from_auto"] = True
        hub_kwargs_names = [
            "cache_dir",
            "code_revision",
            "force_download",
            "local_files_only",
            "proxies",
            "resume_download",
            "revision",
            "subfolder",
            "use_auth_token",
        ]
        hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
        if not isinstance(config, PretrainedConfig):
            kwargs_orig = copy.deepcopy(kwargs)
            # ensure not to pollute the config object with torch_dtype="auto" - since it's
            # meaningless in the context of the config object - torch.dtype values are acceptable
            if kwargs.get("torch_dtype", None) == "auto":
                _ = kwargs.pop("torch_dtype")

            config, kwargs = AutoConfig.from_pretrained(
                pretrained_model_name_or_path,
                return_unused_kwargs=True,
                trust_remote_code=trust_remote_code,
                **hub_kwargs,
                **kwargs,
            )

            # if torch_dtype=auto was passed here, ensure to pass it on
            if kwargs_orig.get("torch_dtype", None) == "auto":
                kwargs["torch_dtype"] = "auto"

        has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map
        has_local_code = type(config) in cls._model_mapping.keys()
        trust_remote_code = resolve_trust_remote_code(
            trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
        )
        if has_remote_code and trust_remote_code:
            class_ref = config.auto_map[cls.__name__]
            model_class = get_class_from_dynamic_module(
                class_ref, pretrained_model_name_or_path, **hub_kwargs, **kwargs
            )
            model_class = FeatureExtractionInitModifier(model_class.__name__, (model_class,), {})
            _ = hub_kwargs.pop("code_revision", None)
            if os.path.isdir(pretrained_model_name_or_path):
                model_class.register_for_auto_class(cls.__name__)
            else:
                cls.register(config.__class__, model_class, exist_ok=True)
            return model_class.from_pretrained(
                pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
            )
        elif type(config) in cls._model_mapping.keys():
            model_class = _get_model_class(config, cls._model_mapping)
            model_class = FeatureExtractionInitModifier(model_class.__name__, (model_class,), {})
            return model_class.from_pretrained(
                pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
            )
        raise ValueError(
            f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
            f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
        )

    @classmethod
    def from_config(cls, config, **kwargs):
        trust_remote_code = kwargs.pop("trust_remote_code", None)
        has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map
        has_local_code = type(config) in cls._model_mapping.keys()
        trust_remote_code = resolve_trust_remote_code(
            trust_remote_code, config._name_or_path, has_local_code, has_remote_code
        )

        if has_remote_code and trust_remote_code:
            class_ref = config.auto_map[cls.__name__]
            if "--" in class_ref:
                repo_id, class_ref = class_ref.split("--")
            else:
                repo_id = config.name_or_path
            model_class = get_class_from_dynamic_module(class_ref, repo_id, **kwargs)
            if os.path.isdir(config._name_or_path):
                model_class.register_for_auto_class(cls.__name__)
            else:
                cls.register(config.__class__, model_class, exist_ok=True)
            _ = kwargs.pop("code_revision", None)
            model_class = FeatureExtractionInitModifier(model_class.__name__, (model_class,), {})
            return model_class._from_config(config, **kwargs)
        elif type(config) in cls._model_mapping.keys():
            model_class = _get_model_class(config, cls._model_mapping)
            model_class = FeatureExtractionInitModifier(model_class.__name__, (model_class,), {})
            return model_class._from_config(config, **kwargs)

        raise ValueError(
            f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
            f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
        )
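A minimal usage sketch of the wrapper above (the checkpoint id is a placeholder, not part of this repository): `from_pretrained` resolves the config, wraps the resolved model class with `FeatureExtractionInitModifier`, and that metaclass swaps in `Conv2dFeatureExtractor` whenever `config.expect_2d_input` is set.

# Hedged sketch; "some-org/some-ebranchformer-ctc" is a hypothetical checkpoint id.
from auto_wrappers import CustomAutoModelForCTC

encoder = CustomAutoModelForCTC.from_pretrained(
    "some-org/some-ebranchformer-ctc",  # placeholder checkpoint
    trust_remote_code=True,
)
# If the loaded config has expect_2d_input=True, the patched __init__ has already replaced
# <base_model>.feature_extractor with Conv2dFeatureExtractor(config), so the model consumes
# (batch, frames, num_mel_bins) filterbank features instead of raw waveforms.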
config.json ADDED
@@ -0,0 +1,278 @@
{
  "_name_or_path": "/mnt/matylda5/ipoloka/IS24_DeCRED/multidomain/normalised_data/ED_base/checkpoint-233632",
  "architectures": ["JointCTCAttentionEncoderDecoder"],
  "auto_map": {
    "AutoConfig": "configuration_decred.JointCTCAttentionEncoderDecoderConfig",
    "AutoModelForSpeechSeq2Seq": "modeling_decred.JointCTCAttentionEncoderDecoder"
  },
  "ctc_weight": 0.3,
  "decoder": {
    "_name_or_path": "Lakoc/gpt2_8l_512h",
    "activation_function": "gelu_new",
    "add_cross_attention": true,
    "architectures": null,
    "attn_pdrop": 0.1,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": 50256,
    "chunk_size_feed_forward": 0,
    "cross_attention_hidden_size": null,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "early_stopping": false,
    "embd_pdrop": 0.1,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": 50256,
    "exponential_decay_length_penalty": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "id2label": {"0": "LABEL_0", "1": "LABEL_1"},
    "initializer_range": 0.02,
    "is_decoder": true,
    "is_encoder_decoder": false,
    "label2id": {"LABEL_0": 0, "LABEL_1": 1},
    "layer_norm_epsilon": 1e-05,
    "length_penalty": 1.0,
    "max_length": 20,
    "min_length": 0,
    "model_type": "gpt2",
    "n_embd": 512,
    "n_head": 4,
    "n_inner": 2048,
    "n_layer": 8,
    "n_positions": 1024,
    "no_repeat_ngram_size": 0,
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_return_sequences": 1,
    "output_attentions": false,
    "output_hidden_size": 512,
    "output_hidden_states": false,
    "output_scores": false,
    "pad_token_id": null,
    "pos_emb_fixed": true,
    "prefix": null,
    "problem_type": null,
    "pruned_heads": {},
    "remove_invalid_values": false,
    "reorder_and_upcast_attn": false,
    "repetition_penalty": 1.0,
    "resid_pdrop": 0.1,
    "return_dict": true,
    "return_dict_in_generate": false,
    "scale_attn_by_inverse_layer_idx": false,
    "scale_attn_weights": true,
    "sep_token_id": null,
    "summary_activation": null,
    "summary_first_dropout": 0.1,
    "summary_proj_to_labels": true,
    "summary_type": "cls_index",
    "summary_use_proj": true,
    "suppress_tokens": null,
    "task_specific_params": null,
    "temperature": 1.0,
    "tf_legacy_loss": false,
    "tie_encoder_decoder": false,
    "tie_word_embeddings": false,
    "tokenizer_class": null,
    "top_k": 50,
    "top_p": 1.0,
    "torch_dtype": null,
    "torchscript": false,
    "typical_p": 1.0,
    "use_bfloat16": false,
    "use_cache": true,
    "vocab_size": 5000
  },
  "decoder_pos_emb_fixed": true,
  "decoder_start_token_id": 0,
  "decoder_vocab_size": 5000,
  "encoder": {
    "_name_or_path": "Lakoc/ebranchformer_16l_512h",
    "activation_dropout": 0.1,
    "adapter_attn_dim": null,
    "adapter_kernel_size": 3,
    "adapter_stride": 2,
    "add_adapter": false,
    "add_cross_attention": false,
    "apply_spec_augment": false,
    "apply_time_warp": false,
    "architectures": null,
    "attention_dropout": 0.1,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": 1,
    "chunk_size_feed_forward": 0,
    "classifier_proj_size": 256,
    "codevector_dim": 256,
    "conformer_conv_dropout": 0.1,
    "contrastive_logits_temperature": 0.1,
    "conv_bias": false,
    "conv_depthwise_kernel_size": 31,
    "conv_dim": [512, 512],
    "conv_kernel": [3, 3],
    "conv_stride": [2, 2],
    "cross_attention_hidden_size": null,
    "csgu_activation": "identity",
    "csgu_conv_dropout": 0.1,
    "csgu_kernel_size": 31,
    "csgu_use_linear_after_conv": false,
    "ctc_loss_reduction": "mean",
    "ctc_zero_infinity": true,
    "decoder_start_token_id": null,
    "diversity_loss_weight": 0.1,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "do_stable_layer_norm": false,
    "early_stopping": false,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": 2,
    "expect_2d_input": true,
    "exponential_decay_length_penalty": null,
    "fe_position_embeddings": true,
    "feat_extract_activation": "gelu",
    "feat_extract_norm": "group",
    "feat_proj_dropout": 0.0,
    "feat_quantizer_dropout": 0.0,
    "final_dropout": 0.1,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "hidden_act": "gelu",
    "hidden_dropout": 0.1,
    "hidden_size": 512,
    "id2label": {"0": "LABEL_0", "1": "LABEL_1"},
    "initializer_range": 0.02,
    "intermediate_size": 2048,
    "is_causal": false,
    "is_decoder": false,
    "is_encoder_decoder": false,
    "label2id": {"LABEL_0": 0, "LABEL_1": 1},
    "layer_norm_eps": 1e-05,
    "layerdrop": 0.0,
    "length_penalty": 1.0,
    "mask_feature_length": 10,
    "mask_feature_min_masks": 0,
    "mask_feature_prob": 0.0,
    "mask_time_length": 10,
    "mask_time_min_masks": 2,
    "mask_time_prob": 0.05,
    "max_length": 20,
    "max_source_positions": 1024,
    "merge_conv_kernel": 31,
    "min_length": 0,
    "model_type": "wav2vec2-ebranchformer",
    "no_repeat_ngram_size": 0,
    "num_adapter_layers": 3,
    "num_attention_heads": 4,
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_codevector_groups": 2,
    "num_codevectors_per_group": 320,
    "num_conv_pos_embedding_groups": 16,
    "num_conv_pos_embeddings": 128,
    "num_feat_extract_layers": 2,
    "num_hidden_layers": 16,
    "num_mel_bins": 80,
    "num_negatives": 100,
    "num_return_sequences": 1,
    "output_attentions": false,
    "output_hidden_size": 512,
    "output_hidden_states": false,
    "output_scores": false,
    "pad_token_id": 3,
    "position_embeddings_type": "relative",
    "prefix": null,
    "problem_type": null,
    "proj_codevector_dim": 256,
    "pruned_heads": {},
    "remove_invalid_values": false,
    "repetition_penalty": 1.0,
    "return_dict": true,
    "return_dict_in_generate": false,
    "rotary_embedding_base": 10000,
    "second_dim_input_size": 80,
    "sep_token_id": null,
    "suppress_tokens": null,
    "task_specific_params": null,
    "tdnn_dilation": [1, 2, 3, 1, 1],
    "tdnn_dim": [512, 512, 512, 512, 1500],
    "tdnn_kernel": [5, 3, 3, 1, 1],
    "temperature": 1.0,
    "tf_legacy_loss": false,
    "tie_encoder_decoder": false,
    "tie_word_embeddings": true,
    "time_warp_mode": "bicubic",
    "time_warp_window": 5,
    "tokenizer_class": null,
    "top_k": 50,
    "top_p": 1.0,
    "torch_dtype": null,
    "torchscript": false,
    "typical_p": 1.0,
    "use_bfloat16": false,
    "use_fbanks": true,
    "use_macaron_ff": true,
    "use_weighted_layer_sum": false,
    "vocab_size": 5000,
    "xvector_output_dim": 512
  },
  "encoder_ctc_loss_reduction": "mean",
  "encoder_expect_2d_input": true,
  "encoder_layerdrop": 0.0,
  "encoder_pad_token_id": 3,
  "encoder_second_dim_input_size": 80,
  "encoder_vocab_size": 5000,
  "is_encoder_decoder": true,
  "lsm_factor": 0.1,
  "model_type": "joint_aed_ctc_speech-encoder-decoder",
  "pad_token_id": 3,
  "shared_lm_head": false,
  "tie_word_embeddings": false,
  "tokenizer_class": "<class 'transformers.tokenization_utils_fast.PreTrainedTokenizerFast'>",
  "torch_dtype": "float32",
  "transformers_version": "4.39.3"
}
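Given the `auto_map` above, the checkpoint is meant to be loaded through the Auto classes with remote code enabled. A hedged sketch (the repository id below is a placeholder for wherever this commit lives):

from transformers import AutoConfig, AutoModelForSpeechSeq2Seq

repo_id = "Lakoc/DeCRED_base"  # placeholder: substitute the actual repository id of this upload
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForSpeechSeq2Seq.from_pretrained(repo_id, trust_remote_code=True)
# Values taken from the config shown above: CTC weight 0.3, 16 encoder layers, 8 decoder layers.
print(config.ctc_weight, config.encoder.num_hidden_layers, config.decoder.n_layer)  # 0.3 16 8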
configuration_decred.py ADDED
@@ -0,0 +1,23 @@
from transformers import AutoConfig, AutoModelForCausalLM, SpeechEncoderDecoderConfig

from .auto_wrappers import CustomAutoModelForCTC
from .e_branchformer import Wav2Vec2EBranchformerConfig, Wav2Vec2EBranchformerForCTC
from .multi_head_gpt2 import GPT2LMMultiHeadModel, GPT2MultiHeadConfig
from .residual_clasiffier_gpt2 import (
    GPT2ResidualsLMHeadConfig,
    GPT2ResidualsLMHeadModel,
)

AutoConfig.register("gpt2-multi-head", GPT2MultiHeadConfig)
AutoModelForCausalLM.register(GPT2MultiHeadConfig, GPT2LMMultiHeadModel)

AutoConfig.register("gpt2-residuals-head", GPT2ResidualsLMHeadConfig)
AutoModelForCausalLM.register(GPT2ResidualsLMHeadConfig, GPT2ResidualsLMHeadModel)

AutoConfig.register("wav2vec2-ebranchformer", Wav2Vec2EBranchformerConfig)
CustomAutoModelForCTC.register(Wav2Vec2EBranchformerConfig, Wav2Vec2EBranchformerForCTC)


class JointCTCAttentionEncoderDecoderConfig(SpeechEncoderDecoderConfig):
    model_type = "joint_aed_ctc_speech-encoder-decoder"
    is_composition = True
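Importing this module has the side effect of registering the custom config and model types with the Auto classes. A brief hedged sketch of what that enables:

from transformers import AutoConfig

# After the registrations above, the custom model type resolves through the Auto machinery.
cfg = AutoConfig.for_model("wav2vec2-ebranchformer", hidden_size=512, num_hidden_layers=16)
print(type(cfg).__name__)  # Wav2Vec2EBranchformerConfig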
ctc_scorer.py ADDED
@@ -0,0 +1,365 @@
# pylint: skip-file
# Copied from: https://github.com/espnet/espnet/blob/master/espnet/nets/ctc_prefix_score.py
import torch
from transformers import LogitsProcessor


class CTCPrefixScoreTH(object):
    """Batch processing of CTCPrefixScore

    which is based on Algorithm 2 in WATANABE et al.
    "HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
    but extended to efficiently compute the label probabilities for multiple
    hypotheses simultaneously
    See also Seki et al. "Vectorized Beam Search for CTC-Attention-Based
    Speech Recognition," In INTERSPEECH (pp. 3825-3829), 2019.
    """

    def __init__(self, x, xlens, blank, eos, margin=0):
        """Construct CTC prefix scorer

        :param torch.Tensor x: input label posterior sequences (B, T, O)
        :param torch.Tensor xlens: input lengths (B,)
        :param int blank: blank label id
        :param int eos: end-of-sequence id
        :param int margin: margin parameter for windowing (0 means no windowing)
        """
        # In the comment lines,
        # we assume T: input_length, B: batch size, W: beam width, O: output dim.
        self.logzero = -10000000000.0
        self.blank = blank
        self.eos = eos
        self.batch = x.size(0)
        self.input_length = x.size(1)
        self.odim = x.size(2)
        self.dtype = x.dtype
        self.device = torch.device("cuda:%d" % x.get_device()) if x.is_cuda else torch.device("cpu")
        # Pad the rest of posteriors in the batch
        # TODO(takaaki-hori): need a better way without for-loops
        for i, l in enumerate(xlens):
            if l < self.input_length:
                x[i, l:, :] = self.logzero
                x[i, l:, blank] = 0
        # Reshape input x
        xn = x.transpose(0, 1)  # (B, T, O) -> (T, B, O)
        xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)
        self.x = torch.stack([xn, xb])  # (2, T, B, O)
        self.end_frames = torch.as_tensor(xlens) - 1

        # Setup CTC windowing
        self.margin = margin
        if margin > 0:
            self.frame_ids = torch.arange(self.input_length, dtype=self.dtype, device=self.device)
        # Base indices for index conversion
        self.idx_bh = None
        self.idx_b = torch.arange(self.batch, device=self.device)
        self.idx_bo = (self.idx_b * self.odim).unsqueeze(1)

    def __call__(self, y, state, scoring_ids=None, att_w=None):
        """Compute CTC prefix scores for next labels

        :param list y: prefix label sequences
        :param tuple state: previous CTC state
        :param torch.Tensor att_w: attention weights to decide CTC window
        :return new_state, ctc_local_scores (BW, O)
        """

        # print(self.tokenizer.batch_decode(y))
        output_length = len(y[0]) - 1  # ignore sos
        last_ids = [yi[-1] for yi in y]  # last output label ids
        n_bh = len(last_ids)  # batch * hyps
        n_hyps = n_bh // self.batch  # assuming each utterance has the same # of hyps
        self.scoring_num = scoring_ids.size(-1) if scoring_ids is not None else 0
        # prepare state info
        if state is None:
            r_prev = torch.full(
                (self.input_length, 2, self.batch, n_hyps),
                self.logzero,
                dtype=self.dtype,
                device=self.device,
            )
            r_prev[:, 1] = torch.cumsum(self.x[0, :, :, self.blank], 0).unsqueeze(2)
            r_prev = r_prev.view(-1, 2, n_bh)
            s_prev = 0.0
            f_min_prev = 0
            f_max_prev = 1
        else:
            r_prev, s_prev, f_min_prev, f_max_prev = state

        # select input dimensions for decred_scoring
        if self.scoring_num > 0:
            scoring_idmap = torch.full((n_bh, self.odim), -1, dtype=torch.long, device=self.device)
            snum = self.scoring_num
            if self.idx_bh is None or n_bh > len(self.idx_bh):
                self.idx_bh = torch.arange(n_bh, device=self.device).view(-1, 1)
            scoring_idmap[self.idx_bh[:n_bh], scoring_ids] = torch.arange(snum, device=self.device)
            scoring_idx = (scoring_ids + self.idx_bo.repeat(1, n_hyps).view(-1, 1)).view(-1)
            x_ = torch.index_select(self.x.view(2, -1, self.batch * self.odim), 2, scoring_idx).view(2, -1, n_bh, snum)
        else:
            scoring_ids = None
            scoring_idmap = None
            snum = self.odim
            x_ = self.x.unsqueeze(3).repeat(1, 1, 1, n_hyps, 1).view(2, -1, n_bh, snum)

        # new CTC forward probs are prepared as a (T x 2 x BW x S) tensor
        # that corresponds to r_t^n(h) and r_t^b(h) in a batch.
        r = torch.full(
            (self.input_length, 2, n_bh, snum),
            self.logzero,
            dtype=self.dtype,
            device=self.device,
        )
        if output_length == 0:
            r[0, 0] = x_[0, 0]

        r_sum = torch.logsumexp(r_prev, 1)
        log_phi = r_sum.unsqueeze(2).repeat(1, 1, snum)
        if scoring_ids is not None:
            for idx in range(n_bh):
                pos = scoring_idmap[idx, last_ids[idx]]
                if pos >= 0:
                    log_phi[:, idx, pos] = r_prev[:, 1, idx]
        else:
            for idx in range(n_bh):
                log_phi[:, idx, last_ids[idx]] = r_prev[:, 1, idx]

        # decide start and end frames based on attention weights
        if att_w is not None and self.margin > 0:
            f_arg = torch.matmul(att_w, self.frame_ids)
            f_min = max(int(f_arg.min().cpu()), f_min_prev)
            f_max = max(int(f_arg.max().cpu()), f_max_prev)
            start = min(f_max_prev, max(f_min - self.margin, output_length, 1))
            end = min(f_max + self.margin, self.input_length)
        else:
            f_min = f_max = 0
            start = max(output_length, 1)
            end = self.input_length

        if start > end:
            return torch.full_like(s_prev, self.logzero), (
                r,
                torch.full_like(s_prev, self.logzero),
                f_min,
                f_max,
                scoring_idmap,
            )

        # compute forward probabilities log(r_t^n(h)) and log(r_t^b(h))
        for t in range(start, end):
            rp = r[t - 1]
            rr = torch.stack([rp[0], log_phi[t - 1], rp[0], rp[1]]).view(2, 2, n_bh, snum)
            r[t] = torch.logsumexp(rr, 1) + x_[:, t]

        # compute log prefix probabilities log(psi)
        log_phi_x = torch.cat((log_phi[0].unsqueeze(0), log_phi[:-1]), dim=0) + x_[0]
        if scoring_ids is not None:
            log_psi = torch.full((n_bh, self.odim), self.logzero, dtype=self.dtype, device=self.device)
            log_psi_ = torch.logsumexp(
                torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0),
                dim=0,
            )
            for si in range(n_bh):
                log_psi[si, scoring_ids[si]] = log_psi_[si]
        else:
            log_psi = torch.logsumexp(
                torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0),
                dim=0,
            )

        # for si in range(n_bh):
        #     log_psi[si, self.eos] = r_sum[self.end_frames[si // n_hyps], si]

        # exclude blank probs
        log_psi[:, self.blank] = self.logzero

        token_scores = log_psi - s_prev
        token_scores[token_scores == 0] = self.logzero

        return token_scores, (r, log_psi, f_min, f_max, scoring_idmap)

    def index_select_state(self, state, best_ids):
        """Select CTC states according to best ids

        :param state    : CTC state
        :param best_ids : index numbers selected by beam pruning (B, W)
        :return selected_state
        """
        r, s, f_min, f_max, scoring_idmap = state
        # convert ids to BHO space
        n_bh = len(s)
        n_hyps = n_bh // self.batch
        vidx = (best_ids + (self.idx_b * (n_hyps * self.odim)).view(-1, 1)).view(-1)
        # select hypothesis scores
        s_new = torch.index_select(s.view(-1), 0, vidx)
        s_new = s_new.view(-1, 1).repeat(1, self.odim).view(n_bh, self.odim)
        # convert ids to BHS space (S: scoring_num)
        if scoring_idmap is not None:
            snum = self.scoring_num
            hyp_idx = (best_ids // self.odim + (self.idx_b * n_hyps).view(-1, 1)).view(-1)
            label_ids = torch.fmod(best_ids, self.odim).view(-1)
            score_idx = scoring_idmap[hyp_idx, label_ids]
            score_idx[score_idx == -1] = 0
            vidx = score_idx + hyp_idx * snum
        else:
            snum = self.odim
        # select forward probabilities
        r_new = torch.index_select(r.view(-1, 2, n_bh * snum), 2, vidx).view(-1, 2, n_bh)
        return r_new, s_new, f_min, f_max

    def extend_prob(self, x):
        """Extend CTC prob.

        :param torch.Tensor x: input label posterior sequences (B, T, O)
        """

        if self.x.shape[1] < x.shape[1]:  # self.x (2,T,B,O); x (B,T,O)
            # Pad the rest of posteriors in the batch
            # TODO(takaaki-hori): need a better way without for-loops
            xlens = [x.size(1)]
            for i, l in enumerate(xlens):
                if l < self.input_length:
                    x[i, l:, :] = self.logzero
                    x[i, l:, self.blank] = 0
            tmp_x = self.x
            xn = x.transpose(0, 1)  # (B, T, O) -> (T, B, O)
            xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)
            self.x = torch.stack([xn, xb])  # (2, T, B, O)
            self.x[:, : tmp_x.shape[1], :, :] = tmp_x
            self.input_length = x.size(1)
            self.end_frames = torch.as_tensor(xlens) - 1

    def extend_state(self, state):
        """Compute CTC prefix state.


        :param state    : CTC state
        :return ctc_state
        """

        if state is None:
            # nothing to do
            return state
        else:
            r_prev, s_prev, f_min_prev, f_max_prev = state

            r_prev_new = torch.full(
                (self.input_length, 2),
                self.logzero,
                dtype=self.dtype,
                device=self.device,
            )
            start = max(r_prev.shape[0], 1)
            r_prev_new[0:start] = r_prev
            for t in range(start, self.input_length):
                r_prev_new[t, 1] = r_prev_new[t - 1, 1] + self.x[0, t, :, self.blank]

            return (r_prev_new, s_prev, f_min_prev, f_max_prev)


class CTCRescorerLogitsProcessor(LogitsProcessor):
    def __init__(
        self,
        encoder_logits: torch.FloatTensor,
        encoder_output_lens: torch.LongTensor,
        pad_token_id: int,
        eos_token_id: int,
        ctc_margin: int,
        ctc_weight: float,
        num_beams: int,
        space_token_id: int,
        apply_eos_space_trick: bool,
        eos_space_trick_weight: float,
        debug: bool = False,
    ):
        super().__init__()
        # reduce_lens_by = (encoder_logits.argmax(dim=-1) == eos_token_id).sum(dim=-1)
        # encoder_output_lens = encoder_output_lens - reduce_lens_by
        self.pad_token_id = pad_token_id
        self.ctc_prefix_scorer = CTCPrefixScoreTH(
            torch.nn.functional.log_softmax(encoder_logits, dim=-1),
            encoder_output_lens,
            pad_token_id,
            eos_token_id,
            ctc_margin,
        )
        self.ctc_weight = ctc_weight
        self.ctc_states = None
        self.num_beams = num_beams
        self.eos_token_id = eos_token_id
        self.apply_eos_space_trick = apply_eos_space_trick
        self.space_token_id = space_token_id
        self.eos_space_trick_weight = eos_space_trick_weight
        self.debug = debug

    @staticmethod
    def analyze_predictions(
        scores, ctc_scores, next_token_scores, input_ids, k=10, tokenizer="Lakoc/english_corpus_uni5000_normalized"
    ):
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(tokenizer)
        best_att_ids = scores.topk(k=k, dim=1)
        best_ctc_ids = ctc_scores.topk(k=k, dim=1)
        best_ids = next_token_scores.topk(k=k, dim=1)

        def print_prediction(best_ids, name):
            new_tensor = torch.zeros((best_ids.indices.shape[0], best_ids.indices.shape[1] * 2), dtype=torch.long)
            new_tensor[:, 0::2] = best_ids.indices
            new_tensor[:, 1::2] = 4976
            print(f"{name}:")
            for index, (next_ids, scores) in enumerate(zip(tokenizer.batch_decode(new_tensor), best_ids.values)):
                print(f"HYP {index}:\n{next_ids} {scores}")

        print(f"PREFIX:")
        for index, prefix in enumerate(tokenizer.batch_decode(input_ids)):
            print(f"HYP {index}:\n{prefix}")
        print_prediction(best_att_ids, "ATT_SCORES")
        print()
        print_prediction(best_ctc_ids, "CTC_SCORES")
        print()
        print(f"CTC_EOS: {ctc_scores[:, 1]}")
        print_prediction(best_ids, "NEXT_TOKEN_SCORES")
        print()

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        scores[:, self.pad_token_id] = self.ctc_prefix_scorer.logzero
        if self.ctc_states is not None:
            self.ctc_states = self.ctc_prefix_scorer.index_select_state(
                self.ctc_states, input_ids[:, -1].reshape(-1, self.num_beams)
            )
        ctc_scores, ctc_states = self.ctc_prefix_scorer(input_ids, self.ctc_states)
        self.ctc_states = ctc_states
        next_token_scores = (1 - self.ctc_weight) * scores + self.ctc_weight * ctc_scores
        if self.apply_eos_space_trick:
            space_eos_conflict = torch.logical_and(
                scores.argmax(dim=1) == self.eos_token_id, ctc_scores.argmax(dim=1) == self.space_token_id
            )
            if space_eos_conflict.any():
                apply_trick_on = torch.logical_and(
                    torch.logical_and(
                        space_eos_conflict,
                        next_token_scores[:, self.eos_token_id] < next_token_scores[:, self.space_token_id],
                    ),
                    self.eos_space_trick_weight * next_token_scores[:, self.eos_token_id]
                    > next_token_scores[:, self.space_token_id],
                )
                if apply_trick_on.any():
                    next_token_scores[apply_trick_on, self.eos_token_id] = (
                        next_token_scores[apply_trick_on, self.eos_token_id] * self.eos_space_trick_weight
                    )

        if self.debug:
            self.analyze_predictions(scores, ctc_scores, next_token_scores, input_ids)

        return next_token_scores


class LogSoftmaxProcessor(LogitsProcessor):
    def __init__(
        self,
    ):
        super().__init__()

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        scores = torch.nn.functional.log_softmax(scores, dim=-1)
        return scores
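How the rescorer is wired into decoding, as a hedged sketch: the encoder's CTC logits are normalized to log-posteriors inside the constructor, the processor is appended to a LogitsProcessorList, and generate() then mixes attention and CTC prefix scores with weight ctc_weight. The tensor shapes and token ids below are illustrative only, not values fixed by this repository.

import torch
from transformers import LogitsProcessorList

# Illustrative shapes/ids: batch of 2, 50 encoder frames, vocabulary of 5000.
encoder_logits = torch.randn(2, 50, 5000)
encoder_output_lens = torch.tensor([50, 42])

ctc_rescorer = CTCRescorerLogitsProcessor(
    encoder_logits=encoder_logits,
    encoder_output_lens=encoder_output_lens,
    pad_token_id=3,          # used as the CTC blank here
    eos_token_id=2,
    ctc_margin=0,
    ctc_weight=0.3,
    num_beams=4,
    space_token_id=-1,
    apply_eos_space_trick=False,
    eos_space_trick_weight=1.0,
)
logits_processor = LogitsProcessorList([LogSoftmaxProcessor(), ctc_rescorer])
# `logits_processor` would then be passed to model.generate(..., logits_processor=logits_processor),
# where each beam's scores become (1 - ctc_weight) * attention + ctc_weight * CTC prefix score.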
e_branchformer.py ADDED
@@ -0,0 +1,252 @@
""" PyTorch Wav2Vec2-Ebranchformer model."""

from typing import Optional

import torch
import torch.utils.checkpoint
from torch import nn
from transformers.activations import ACT2FN
from transformers.models.wav2vec2.modeling_wav2vec2 import (
    Wav2Vec2Config,
    Wav2Vec2ForCTC,
    Wav2Vec2ForPreTraining,
)
from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
    Wav2Vec2ConformerConfig,
    Wav2Vec2ConformerEncoder,
)
from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
    Wav2Vec2ConformerFeedForward as Wav2Vec2EBranchformerFeedForward,
)
from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
    Wav2Vec2ConformerModel,
)
from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
    Wav2Vec2ConformerSelfAttention as Wav2Vec2EBranchformerSelfAttention,
)
from transformers.utils import logging

logger = logging.get_logger(__name__)


class Wav2Vec2EBranchformerConfig(Wav2Vec2ConformerConfig, Wav2Vec2Config):
    """Config for the E-Branchformer model extending the Conformer config."""

    model_type = "wav2vec2-ebranchformer"

    def __init__(
        self,
        ebranchformer_conv_dropout=0.1,
        csgu_activation="identity",
        csgu_kernel_size=31,
        csgu_use_linear_after_conv=False,
        merge_conv_kernel=31,
        use_macaron_ff=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # EBranchformer related params
        self.csgu_kernel_size = csgu_kernel_size
        self.csgu_activation = csgu_activation
        self.csgu_conv_dropout = ebranchformer_conv_dropout
        self.csgu_use_linear_after_conv = csgu_use_linear_after_conv
        self.merge_conv_kernel = merge_conv_kernel
        self.use_macaron_ff = use_macaron_ff


class ConvolutionalSpatialGatingUnit(torch.nn.Module):
    """Convolutional Spatial Gating Unit (CSGU)."""

    def __init__(self, config: Wav2Vec2EBranchformerConfig):
        super().__init__()

        n_channels = config.intermediate_size // 2  # split input channels
        self.norm = torch.nn.LayerNorm(n_channels)
        self.conv = torch.nn.Conv1d(
            n_channels,
            n_channels,
            config.csgu_kernel_size,
            1,
            (config.csgu_kernel_size - 1) // 2,
            groups=n_channels,
        )
        if config.csgu_use_linear_after_conv:
            self.linear = torch.nn.Linear(n_channels, n_channels)
        else:
            self.linear = None

        if config.csgu_activation == "identity":
            self.act = torch.nn.Identity()
        else:
            self.act = ACT2FN[config.csgu_activation]

        self.dropout = torch.nn.Dropout(config.csgu_conv_dropout)

    def forward(self, hidden_states: torch.FloatTensor):
        """Forward method

        Args:
            hidden_states (torch.Tensor): (N, T, D)

        Returns:
            out (torch.Tensor): (N, T, D/2)
        """

        x_r, x_g = hidden_states.chunk(2, dim=-1)

        x_g = self.norm(x_g)  # (N, T, D/2)
        x_g = self.conv(x_g.transpose(1, 2)).transpose(1, 2)  # (N, T, D/2)
        if self.linear is not None:
            x_g = self.linear(x_g)

        x_g = self.act(x_g)
        hidden_states = x_r * x_g  # (N, T, D/2)
        hidden_states = self.dropout(hidden_states)
        return hidden_states


class ConvolutionalGatingMLP(torch.nn.Module):
    """Convolutional Gating MLP (cgMLP)."""

    def __init__(self, config: Wav2Vec2EBranchformerConfig):
        super().__init__()
        self.channel_proj1 = torch.nn.Sequential(
            torch.nn.Linear(config.hidden_size, config.intermediate_size), torch.nn.GELU()
        )
        self.csgu = ConvolutionalSpatialGatingUnit(config)
        self.channel_proj2 = torch.nn.Linear(config.intermediate_size // 2, config.hidden_size)

    def forward(self, hidden_states: torch.FloatTensor):
        hidden_states = self.channel_proj1(hidden_states)  # hidden_size -> intermediate_size
        hidden_states = self.csgu(hidden_states)  # intermediate_size -> intermediate_size/2
        hidden_states = self.channel_proj2(hidden_states)  # intermediate_size/2 -> hidden_size
        return hidden_states


class Wav2Vec2EBranchformerEncoderLayer(nn.Module):
    def __init__(self, config: Wav2Vec2EBranchformerConfig):
        super().__init__()
        embed_dim = config.hidden_size
        dropout = config.attention_dropout

        # Feed-forward 1
        if config.use_macaron_ff:
            self.ff1 = nn.Sequential(nn.LayerNorm(embed_dim), Wav2Vec2EBranchformerFeedForward(config))

        # Self-Attention
        self.self_attn_layer_norm = nn.LayerNorm(embed_dim)
        self.self_attn_dropout = torch.nn.Dropout(dropout)
        self.self_attn = Wav2Vec2EBranchformerSelfAttention(config)

        # cgMLP
        self.cgMLP = ConvolutionalGatingMLP(config)
        self.cgMLP_layer_norm = nn.LayerNorm(config.hidden_size)
        self.cgMLP_dropout = torch.nn.Dropout(dropout)

        # Merge
        self.final_dropout = torch.nn.Dropout(dropout)
        self.merge_proj = torch.nn.Linear(embed_dim + embed_dim, embed_dim)
        self.depthwise_conv_fusion = torch.nn.Conv1d(
            embed_dim + embed_dim,
            embed_dim + embed_dim,
            kernel_size=config.merge_conv_kernel,
            stride=1,
            padding=(config.merge_conv_kernel - 1) // 2,
            groups=embed_dim + embed_dim,
            bias=True,
        )
        self.final_layer_norm = nn.LayerNorm(embed_dim)

        # Feed-forward 2
        if config.use_macaron_ff:
            self.ff2 = nn.Sequential(nn.LayerNorm(embed_dim), Wav2Vec2EBranchformerFeedForward(config))

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.Tensor] = None,
        relative_position_embeddings: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ):
        # 1. Optional ff1
        if self.ff1:
            residual = hidden_states
            hidden_states = residual + 0.5 * self.ff1(hidden_states)

        # 2. Split input to three branches
        residual = hidden_states
        global_branch = hidden_states
        local_branch = hidden_states

        # 3. Self-Attention branch
        global_branch = self.self_attn_layer_norm(global_branch)
        global_branch, attn_weigts = self.self_attn(
            hidden_states=global_branch,
            attention_mask=attention_mask,
            relative_position_embeddings=relative_position_embeddings,
            output_attentions=output_attentions,
        )
        global_branch = self.self_attn_dropout(global_branch)

        # 4. cgMLP Branch
        local_branch = self.cgMLP_layer_norm(local_branch)
        local_branch = self.cgMLP(local_branch)

        # 5. Merge operator
        # a, concat
        hidden_states = torch.cat([global_branch, local_branch], dim=-1)
        merge_residual = hidden_states
        # b, depth-wise conv mixing
        hidden_states = merge_residual + self.depthwise_conv_fusion(hidden_states.transpose(1, 2)).transpose(1, 2)
        # c, project back to original size and final dropout
        hidden_states = self.final_dropout(self.merge_proj(hidden_states))

        # 6. Add residual
        hidden_states = residual + hidden_states

        # 7. Optional ff2
        if self.ff2:
            residual = hidden_states
            hidden_states = residual + 0.5 * self.ff2(hidden_states)

        # 8. Final layer norm
        hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states, attn_weigts


class Wav2Vec2EBranchformerEncoder(Wav2Vec2ConformerEncoder):
    def __init__(self, config: Wav2Vec2EBranchformerConfig):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [Wav2Vec2EBranchformerEncoderLayer(config) for _ in range(config.num_hidden_layers)]
        )
        self.pos_conv_embed = None


class Wav2Vec2EBranchformerModel(Wav2Vec2ConformerModel):
    def __init__(self, config: Wav2Vec2EBranchformerConfig):
        super().__init__(config)
        self.encoder = Wav2Vec2EBranchformerEncoder(config)

        # Initialize weights and apply final processing
        self.post_init()


class Wav2Vec2EBranchformerForPreTraining(Wav2Vec2ForPreTraining):
    config_class = Wav2Vec2EBranchformerConfig
    base_model_prefix = "wav2vec2"

    def __init__(self, config: Wav2Vec2EBranchformerConfig):
        super().__init__(config)
        self.wav2vec2 = Wav2Vec2EBranchformerModel(config)
        self.post_init()


class Wav2Vec2EBranchformerForCTC(Wav2Vec2ForCTC):
    config_class = Wav2Vec2EBranchformerConfig
    base_model_prefix = "wav2vec2"

    def __init__(self, config: Wav2Vec2EBranchformerConfig):
        super().__init__(config)
        self.wav2vec2 = Wav2Vec2EBranchformerModel(config)
        self.post_init()
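A small sanity-check sketch of the encoder layer defined above. The hyperparameters are deliberately tiny so it runs quickly; they are not the values used by the released checkpoint, and absolute position handling is disabled so no relative embeddings need to be supplied.

import torch

config = Wav2Vec2EBranchformerConfig(
    hidden_size=64,
    intermediate_size=256,
    num_hidden_layers=2,
    num_attention_heads=4,
    position_embeddings_type=None,  # skip relative position embeddings for this toy check
)
layer = Wav2Vec2EBranchformerEncoderLayer(config)
x = torch.randn(1, 20, 64)
out, _ = layer(x)
print(out.shape)  # torch.Size([1, 20, 64]): the merge projection maps 2*hidden back to hidden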
embeddings.py ADDED
@@ -0,0 +1,86 @@
import torch
from torch import nn


class AdaptiveEmbedding(nn.Module):
    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, sample_softmax=False):
        super().__init__()

        self.n_token = n_token
        self.d_embed = d_embed

        self.cutoffs = cutoffs + [n_token]
        self.div_val = div_val
        self.d_proj = d_proj

        self.emb_scale = d_proj**0.5

        self.cutoff_ends = [0] + self.cutoffs

        self.emb_layers = nn.ModuleList()
        self.emb_projs = nn.ParameterList()
        if div_val == 1:
            self.emb_layers.append(nn.Embedding(n_token, d_embed, sparse=sample_softmax > 0))
            if d_proj != d_embed:
                self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_embed)))
        else:
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
                d_emb_i = d_embed // (div_val**i)
                self.emb_layers.append(nn.Embedding(r_idx - l_idx, d_emb_i))
                self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_emb_i)))

    def forward(self, inp):
        if self.div_val == 1:
            embed = self.emb_layers[0](inp)
            if self.d_proj != self.d_embed:
                embed = nn.functional.linear(embed, self.emb_projs[0])
        else:
            param = next(self.parameters())
            inp_flat = inp.view(-1)
            emb_flat = torch.zeros([inp_flat.size(0), self.d_proj], dtype=param.dtype, device=param.device)
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]

                mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx)
                indices_i = mask_i.nonzero().squeeze()

                if indices_i.numel() == 0:
                    continue

                inp_i = inp_flat.index_select(0, indices_i) - l_idx
                emb_i = self.emb_layers[i](inp_i)
                emb_i = nn.functional.linear(emb_i, self.emb_projs[i])

                emb_flat.index_copy_(0, indices_i, emb_i)

            embed_shape = inp.size() + (self.d_proj,)
            embed = emb_flat.view(embed_shape)

        embed.mul_(self.emb_scale)

        return embed


class PositionalEmbeddingAux(nn.Module):
    def __init__(self, demb):
        super().__init__()

        self.demb = demb

        inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, pos_seq, bsz=None):
        sinusoid_inp = torch.outer(pos_seq, self.inv_freq)
        pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)

        if bsz is not None:
            return pos_emb[:, None, :].expand(-1, bsz, -1)
        else:
            return pos_emb[:, None, :]


class PositionalEmbedding(PositionalEmbeddingAux):
    def forward(self, pos_seq, bsz=None):
        return super().forward(pos_seq.squeeze(0), bsz=bsz).squeeze(1)
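A hedged shape sketch for the two modules above (sizes are illustrative): with div_val=1 and matching d_embed/d_proj, AdaptiveEmbedding reduces to a plain embedding scaled by sqrt(d_proj), and PositionalEmbedding returns one fixed sinusoidal vector per position.

import torch

emb = AdaptiveEmbedding(n_token=5000, d_embed=512, d_proj=512, cutoffs=[])
tokens = torch.randint(0, 5000, (2, 7))
print(emb(tokens).shape)  # torch.Size([2, 7, 512])

pos = PositionalEmbedding(demb=512)
positions = torch.arange(7, dtype=torch.float32).unsqueeze(0)  # (1, 7); squeezed inside forward
print(pos(positions).shape)  # torch.Size([7, 512])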
extractors.py ADDED
@@ -0,0 +1,32 @@
import torch
from torch import nn
from transformers.activations import ACT2FN


class Conv2dFeatureExtractor(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.conv = torch.nn.Sequential(
            *[
                nn.Sequential(
                    nn.Conv2d(
                        conv_in,
                        out_channels=conv_out,
                        kernel_size=(conv_kernel, conv_kernel),
                        stride=(conv_stride, conv_stride),
                    ),
                    ACT2FN[config.feat_extract_activation],
                )
                for conv_in, conv_out, conv_kernel, conv_stride in zip(
                    [1, *config.conv_dim], config.conv_dim, config.conv_kernel, config.conv_stride
                )
            ],
        )

        linear_in_dim = config.conv_dim[-1] * (((config.second_dim_input_size - 1) // 2 - 1) // 2)
        self.out = torch.nn.Linear(linear_in_dim, config.hidden_size, bias=True)

    def forward(self, input_values: torch.Tensor) -> torch.Tensor:
        hidden_states = self.conv(input_values[:, None, ...])
        hidden_states = self.out(hidden_states.transpose(1, 2).flatten(2, 3))
        return hidden_states.transpose(1, 2)
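A hedged shape sketch: with the checkpoint's settings (two Conv2d blocks with kernel 3 and stride 2 over 80 mel bins, hidden_size 512), an input of shape (batch, frames, 80) is downsampled roughly 4x in time and projected to hidden_size. The SimpleNamespace below is a stand-in for the real config object, used only to exercise the module.

import torch
from types import SimpleNamespace

config = SimpleNamespace(
    conv_dim=[512, 512],
    conv_kernel=[3, 3],
    conv_stride=[2, 2],
    feat_extract_activation="gelu",
    second_dim_input_size=80,
    hidden_size=512,
)
fe = Conv2dFeatureExtractor(config)
fbanks = torch.randn(2, 100, 80)  # (batch, frames, mel bins)
print(fe(fbanks).shape)  # torch.Size([2, 512, 24]): 100 frames -> 24 after two stride-2 convs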
generation.py ADDED
@@ -0,0 +1,61 @@
from transformers import GenerationConfig


class GenerationConfigCustom(GenerationConfig):
    def __init__(
        self,
        ctc_weight=0.0,
        ctc_margin=0,
        lm_weight=0,
        lm_model=None,
        space_token_id=-1,
        eos_space_trick_weight=0,
        apply_eos_space_trick=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.ctc_weight = ctc_weight
        self.ctc_margin = ctc_margin
        self.lm_weight = lm_weight
        self.lm_model = lm_model
        self.space_token_id = space_token_id
        self.eos_space_trick_weight = eos_space_trick_weight
        self.apply_eos_space_trick = apply_eos_space_trick

    def update_from_string(self, update_str: str):
        """
        Updates attributes of this class with attributes from `update_str`.

        The expected format is semicolon-separated `key=value` pairs: ints, floats and strings as is,
        and for booleans use `true` or `false`. For example:
        "n_embd=10;resid_pdrop=0.2;scale_attn_weights=false;summary_type=cls_index"

        The keys to change have to already exist in the config object.

        Args:
            update_str (`str`): String with attributes that should be updated for this class.

        """

        d = dict(x.split("=") for x in update_str.split(";"))
        for k, v in d.items():
            if not hasattr(self, k):
                raise ValueError(f"key {k} isn't in the original config dict")

            old_v = getattr(self, k)
            if isinstance(old_v, bool):
                if v.lower() in ["true", "1", "y", "yes"]:
                    v = True
                elif v.lower() in ["false", "0", "n", "no"]:
                    v = False
                else:
                    raise ValueError(f"can't derive true or false from {v} (key {k})")
            elif isinstance(old_v, int):
                v = int(v)
            elif isinstance(old_v, float):
                v = float(v)
            elif not isinstance(old_v, str):
                raise ValueError(
                    f"You can only update int, float, bool or string values in the config, got {v} for key {k}"
                )

            setattr(self, k, v)
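The override changes the pair separator from HF's comma to a semicolon, so the joint-decoding knobs can be set in a single string; a brief hedged example:

gen_cfg = GenerationConfigCustom(max_length=512, num_beams=4)
gen_cfg.update_from_string("ctc_weight=0.3;ctc_margin=0;apply_eos_space_trick=true;num_beams=8")
print(gen_cfg.ctc_weight, gen_cfg.num_beams, gen_cfg.apply_eos_space_trick)  # 0.3 8 True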
model.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9889fb78a8079980e7275129a0794a69ae674fa6858cc07935fb1d9ae6dd28b8
size 687737968
modeling_decred.py ADDED
@@ -0,0 +1,563 @@
1
+ from dataclasses import dataclass
2
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import CrossEntropyLoss
7
+ from transformers import (
8
+ AutoConfig,
9
+ AutoModelForCausalLM,
10
+ AutoModelForSpeechSeq2Seq,
11
+ LogitsProcessor,
12
+ PretrainedConfig,
13
+ PreTrainedModel,
14
+ SpeechEncoderDecoderConfig,
15
+ SpeechEncoderDecoderModel,
16
+ StoppingCriteriaList,
17
+ )
18
+ from transformers.generation.logits_process import LogitsProcessorList
19
+ from transformers.generation.utils import GenerateOutput
20
+ from transformers.modeling_outputs import CausalLMOutput, Seq2SeqLMOutput
21
+ from transformers.models.speech_encoder_decoder.modeling_speech_encoder_decoder import (
22
+ shift_tokens_right,
23
+ )
24
+ from transformers.utils import logging
25
+
26
+ from .auto_wrappers import CustomAutoModelForCTC
27
+ from .configuration_decred import JointCTCAttentionEncoderDecoderConfig
28
+ from .ctc_scorer import CTCRescorerLogitsProcessor, LogSoftmaxProcessor
29
+ from .embeddings import AdaptiveEmbedding, PositionalEmbedding
30
+ from .generation import GenerationConfigCustom
31
+ from .multi_head_gpt2 import GPT2LMMultiHeadModel
32
+
33
+ logger = logging.get_logger("transformers")
34
+
35
+
36
+ class LMRescorerLogitsProcessor(LogitsProcessor):
37
+ """Logits Processor to rescore the next token scores with a language model."""
38
+
39
+ def __init__(self, lm_weight: float, lm_model: PreTrainedModel, device: torch.device):
40
+ super().__init__()
41
+ self.lm_model = lm_model.to(device)
42
+ self.lm_weight = lm_weight
43
+ # self.past_key_values = None
44
+
45
+ @staticmethod
46
+ def analyze_predictions(scores, lm_scores, next_token_scores, input_ids, k=10, tokenizer="Lakoc/ted_uni500"):
47
+ from transformers import AutoTokenizer
48
+
49
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer)
50
+ best_att_ids = scores.topk(k=k, dim=1)
51
+ best_ctc_ids = lm_scores.topk(k=k, dim=1)
52
+ best_ids = next_token_scores.topk(k=k, dim=1)
53
+
54
+ def print_prediction(best_ids, name):
55
+ new_tensor = torch.zeros((best_ids.indices.shape[0], best_ids.indices.shape[1] * 2), dtype=torch.long)
56
+ new_tensor[:, 0::2] = best_ids.indices
57
+ new_tensor[:, 1::2] = 1
58
+ print(f"{name}:")
59
+ for index, (next_ids, scores) in enumerate(zip(tokenizer.batch_decode(new_tensor), best_ids.values)):
60
+ print(f"HYP {index}:\n{next_ids} {scores}")
61
+
62
+ print(f"PREFIX:")
63
+ for index, prefix in enumerate(tokenizer.batch_decode(input_ids)):
64
+ print(f"HYP {index}:\n{prefix}")
65
+ print_prediction(best_att_ids, "ACCUSTIC_SCORES")
66
+ print()
67
+ print_prediction(best_ctc_ids, "LM_SCORES")
68
+ print()
69
+ print_prediction(best_ids, "NEXT_TOKEN_SCORES")
70
+ print()
71
+
72
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
73
+ # TODO: KarelB: Can you implement the past_key_values logic?
74
+ outputs = self.lm_model(
75
+ input_ids,
76
+ # input_ids[:, -1]
77
+ # past_key_values=self.past_key_values,
78
+ # use_cache=True
79
+ )
80
+ # self.past_key_values = outputs.past_key_values
81
+ lm_scores = torch.nn.functional.log_softmax(outputs.logits[:, -1, :], dim=-1)
82
+ next_token_scores = scores + self.lm_weight * lm_scores
83
+ # self.analyze_predictions(scores, lm_scores, next_token_scores, input_ids)
84
+ return next_token_scores
85
+
86
+
87
+ def wav2vec2_forward_hidden_return_hook(_: PreTrainedModel, __: Any, kwargs):
88
+ kwargs["output_hidden_states"] = True
89
+
90
+
91
+ @dataclass
92
+ class Seq2SeqLMOutputLosses(Seq2SeqLMOutput):
93
+ enc_loss: Optional[torch.FloatTensor] = None
94
+ dec_loss: Optional[torch.FloatTensor] = None
95
+ encoder_logits: Optional[torch.FloatTensor] = None
96
+
97
+
98
+ def wav2vec2_for_ctc_forward_hook(model: CustomAutoModelForCTC, input: Any, output: CausalLMOutput):
99
+ if "hidden_states" in output:
100
+ output.last_hidden_state = output.hidden_states[-1]
101
+
102
+
103
+ class JointCTCAttentionEncoderDecoder(SpeechEncoderDecoderModel):
104
+ """Custom model for CTC+Attention loss based on the ESPNet architecture"""
105
+
106
+ config_class = JointCTCAttentionEncoderDecoderConfig
107
+ base_model_prefix = "joint_aed_ctc_speech-encoder-decoder"
108
+
109
+ def __init__(
110
+ self,
111
+ config: Optional[PretrainedConfig] = None,
112
+ encoder: Optional[PreTrainedModel] = None,
113
+ decoder: Optional[PreTrainedModel] = None,
114
+ ):
115
+ if config is None and (encoder is None or decoder is None):
116
+ raise ValueError("Either a configuration or an encoder and a decoder has to be provided.")
117
+ if config is None:
118
+ config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
119
+ else:
120
+ if not isinstance(config, self.config_class):
121
+ raise ValueError(f"Config: {config} has to be of type {self.config_class}")
122
+
123
+ if config.decoder.cross_attention_hidden_size is not None:
124
+ if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
125
+ raise ValueError(
126
+ "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal"
127
+ f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
128
+ f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for"
129
+ " `config.encoder.hidden_size`."
130
+ )
131
+
132
+ # initialize with config
133
+ # make sure input & output embeddings is not tied
134
+ config.tie_word_embeddings = False
135
+ super(SpeechEncoderDecoderModel, self).__init__(config)
136
+
137
+ if encoder is None:
138
+ encoder = CustomAutoModelForCTC.from_config(config.encoder)
139
+ encoder.register_forward_hook(wav2vec2_for_ctc_forward_hook)
140
+ encoder.register_forward_pre_hook(wav2vec2_forward_hidden_return_hook, with_kwargs=True)
141
+ if decoder is None:
142
+ decoder = AutoModelForCausalLM.from_config(config.decoder)
143
+
144
+ self.encoder = encoder
145
+ self.decoder = decoder
146
+
147
+ if self.encoder.config.to_dict() != self.config.encoder.to_dict():
148
+ logger.warning(
149
+ f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config:"
150
+ f" {self.config.encoder}"
151
+ )
152
+ if self.decoder.config.to_dict() != self.config.decoder.to_dict():
153
+ logger.warning(
154
+ f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:"
155
+ f" {self.config.decoder}"
156
+ )
157
+
158
+ # make sure that the individual model's config refers to the shared config
159
+ # so that the updates to the config will be synced
160
+ self.encoder.config = self.config.encoder
161
+ self.decoder.config = self.config.decoder
162
+
163
+ # get encoder output hidden size
164
+ self.encoder_output_dim = getattr(config.encoder, "output_hidden_size", config.encoder.hidden_size)
165
+ if (
166
+ self.encoder_output_dim != self.decoder.config.hidden_size
167
+ and self.decoder.config.cross_attention_hidden_size is None
168
+ ):
169
+ # encoder outputs might need to be projected to different dimension for decoder
170
+ self.enc_to_dec_proj = nn.Linear(self.encoder.config.hidden_size, self.decoder.config.hidden_size)
171
+
172
+ if self.encoder.get_output_embeddings() is not None:
173
+ raise ValueError(
174
+ f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head"
175
+ )
176
+ self.enc_loss_weight = config.ctc_weight
177
+ self.dec_loss_weight = 1 - config.ctc_weight
178
+ self.lsm_factor = config.lsm_factor
179
+
180
+ if config.shared_lm_head:
181
+ self.encoder.lm_head.weight = self.decoder.lm_head.weight
182
+
183
+ if (hasattr(config, "decoder_pos_emb_fixed") and config.decoder_pos_emb_fixed) or (
184
+ hasattr(config.decoder, "pos_emb_fixed") and config.decoder.pos_emb_fixed
185
+ ):
186
+ self.decoder.transformer.wte = AdaptiveEmbedding(
187
+ n_token=config.decoder.vocab_size,
188
+ d_embed=config.decoder.hidden_size,
189
+ d_proj=config.decoder.hidden_size,
190
+ cutoffs=[],
191
+ )
192
+ self.decoder.transformer.wpe = PositionalEmbedding(demb=config.decoder.hidden_size)
193
+ self.decoder.post_init()
194
+
195
+ self.encoder_logits = None
196
+ self.encoder_output_lens = None
197
+
198
+ @classmethod
199
+ def from_encoder_decoder_pretrained(
200
+ cls,
201
+ encoder_pretrained_model_name_or_path: str = None,
202
+ decoder_pretrained_model_name_or_path: str = None,
203
+ *model_args,
204
+ **kwargs,
205
+ ) -> PreTrainedModel:
206
+ kwargs_encoder = {
207
+ argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
208
+ }
209
+
210
+ kwargs_decoder = {
211
+ argument[len("decoder_") :]: value
212
+ for argument, value in kwargs.items()
213
+ if argument.startswith("decoder_") and argument != "decoder_start_token_id"
214
+ }
215
+
216
+ # remove encoder, decoder kwargs from kwargs
217
+ for key in kwargs_encoder.keys():
218
+ del kwargs["encoder_" + key]
219
+ for key in kwargs_decoder.keys():
220
+ del kwargs["decoder_" + key]
221
+
222
+ # Load and initialize the encoder and decoder
223
+ # The distinction between encoder and decoder at the model level is made
224
+ # by the value of the flag `is_decoder` that we need to set correctly.
225
+ encoder = kwargs_encoder.pop("model", None)
226
+ if encoder is None:
227
+ if encoder_pretrained_model_name_or_path is None:
228
+ raise ValueError(
229
+ "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
230
+ "to be defined."
231
+ )
232
+
233
+ if "config" not in kwargs_encoder:
234
+ encoder_config, kwargs_encoder = AutoConfig.from_pretrained(
235
+ encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True
236
+ )
237
+
238
+ if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
239
+ logger.info(
240
+ f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
241
+ "from a decoder model. Cross-attention and casual mask are disabled."
242
+ )
243
+ encoder_config.is_decoder = False
244
+ encoder_config.add_cross_attention = False
245
+
246
+ kwargs_encoder["config"] = encoder_config
247
+
248
+ encoder = CustomAutoModelForCTC.from_pretrained(
249
+ encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder
250
+ )
251
+ encoder.register_forward_hook(wav2vec2_for_ctc_forward_hook)
252
+
253
+ decoder = kwargs_decoder.pop("model", None)
254
+ if decoder is None:
255
+ if decoder_pretrained_model_name_or_path is None:
256
+ raise ValueError(
257
+ "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
258
+ "to be defined."
259
+ )
260
+
261
+ if "config" not in kwargs_decoder:
262
+ decoder_config, kwargs_decoder = AutoConfig.from_pretrained(
263
+ decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
264
+ )
265
+
266
+ if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
267
+ logger.info(
268
+ f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention"
269
+ f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if"
270
+ f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
271
+ )
272
+ decoder_config.is_decoder = True
273
+ decoder_config.add_cross_attention = True
274
+
275
+ kwargs_decoder["config"] = decoder_config
276
+
277
+ if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
278
+ logger.warning(
279
+ f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
280
+ f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
281
+ "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
282
+ "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
283
+ "`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
284
+ )
285
+
286
+ decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
287
+
288
+ # instantiate config with corresponding kwargs
289
+ config = JointCTCAttentionEncoderDecoderConfig.from_encoder_decoder_configs(
290
+ encoder.config, decoder.config, **kwargs
291
+ )
292
+
293
+ # make sure input & output embeddings are not tied
294
+ config.tie_word_embeddings = False
295
+ return cls(encoder=encoder, decoder=decoder, config=config)
296
+
297
+ def forward(
298
+ self,
299
+ inputs: Optional[torch.FloatTensor] = None,
300
+ attention_mask: Optional[torch.FloatTensor] = None,
301
+ decoder_input_ids: Optional[torch.LongTensor] = None,
302
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
303
+ encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
304
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
305
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
306
+ labels: Optional[torch.LongTensor] = None,
307
+ use_cache: Optional[bool] = None,
308
+ output_attentions: Optional[bool] = None,
309
+ output_hidden_states: Optional[bool] = None,
310
+ input_values: Optional[torch.FloatTensor] = None,
311
+ input_features: Optional[torch.FloatTensor] = None,
312
+ return_dict: Optional[bool] = None,
313
+ **kwargs,
314
+ ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutputLosses]:
315
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
316
+
317
+ kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
318
+
319
+ kwargs_decoder = {
320
+ argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
321
+ }
322
+
323
+ if encoder_outputs is None:
324
+ if inputs is None:
325
+ if input_values is not None and input_features is not None:
326
+ raise ValueError("You cannot specify both input_values and input_features at the same time")
327
+ elif input_values is not None:
328
+ inputs = input_values
329
+ elif input_features is not None:
330
+ inputs = input_features
331
+ else:
332
+ raise ValueError("You have to specify either input_values or input_features")
333
+
334
+ encoder_outputs = self.encoder(
335
+ inputs,
336
+ attention_mask=attention_mask,
337
+ output_attentions=output_attentions,
338
+ output_hidden_states=output_hidden_states,
339
+ return_dict=return_dict,
340
+ labels=labels,
341
+ **kwargs_encoder,
342
+ )
343
+ elif isinstance(encoder_outputs, tuple):
344
+ encoder_outputs = CausalLMOutput(*encoder_outputs)
345
+
346
+ encoder_hidden_states = encoder_outputs.last_hidden_state
347
+
348
+ # optionally project encoder_hidden_states
349
+ if (
350
+ self.encoder_output_dim != self.decoder.config.hidden_size
351
+ and self.decoder.config.cross_attention_hidden_size is None
352
+ ):
353
+ encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
354
+
355
+ # compute correct encoder attention mask
356
+ if attention_mask is not None:
357
+ encoder_attention_mask = self.encoder._get_feature_vector_attention_mask(
358
+ encoder_hidden_states.shape[1], attention_mask
359
+ )
360
+ else:
361
+ encoder_attention_mask = None
362
+
363
+ if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
364
+ decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
365
+
366
+ # Decode
367
+ decoder_outputs = self.decoder(
368
+ input_ids=decoder_input_ids,
369
+ attention_mask=decoder_attention_mask,
370
+ encoder_hidden_states=encoder_hidden_states,
371
+ encoder_attention_mask=encoder_attention_mask,
372
+ inputs_embeds=decoder_inputs_embeds,
373
+ output_attentions=output_attentions,
374
+ output_hidden_states=True
375
+ if hasattr(self.decoder, "head_weights") and len(self.decoder.head_weights) > 1
376
+ else output_hidden_states,
377
+ use_cache=use_cache,
378
+ past_key_values=past_key_values,
379
+ return_dict=return_dict,
380
+ **kwargs_decoder,
381
+ )
382
+
383
+ # Compute loss independent from decoder (as some shift the logits inside them)
384
+ loss = enc_loss = dec_loss = None
385
+
386
+ if labels is not None:
387
+ loss_fct = CrossEntropyLoss(label_smoothing=self.lsm_factor)
388
+ enc_loss = encoder_outputs.loss if return_dict else encoder_outputs[0]
389
+ if isinstance(self.decoder, GPT2LMMultiHeadModel) and len(self.decoder.head_weights) > 1:
390
+ dec_loss = torch.zeros_like(enc_loss)
391
+ lm_logits_per_layer = []
392
+ for index, lm_head, lm_weight in zip(
393
+ [*self.decoder.head_locations, -1],
394
+ [*self.decoder.additional_lm_heads, self.decoder.lm_head],
395
+ self.decoder.head_weights,
396
+ ):
397
+ lm_logits = lm_head(decoder_outputs.hidden_states[index])
398
+ dec_loss += lm_weight * loss_fct(
399
+ lm_logits.reshape(-1, self.decoder.config.vocab_size), labels.reshape(-1)
400
+ )
401
+ lm_logits_per_layer.append(lm_logits)
402
+ if self.decoder.config.average_logits:
403
+ decoder_outputs.logits = torch.matmul(
404
+ torch.stack(lm_logits_per_layer).T,
405
+ torch.tensor(self.decoder.head_weights, device=lm_logits_per_layer[-1].device),
406
+ ).T
407
+
408
+ else:
409
+ dec_logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
410
+ dec_loss = loss_fct(dec_logits.reshape(-1, self.decoder.config.vocab_size), labels.reshape(-1))
411
+ loss = self.enc_loss_weight * enc_loss + self.dec_loss_weight * dec_loss
412
+
413
+ if not return_dict:
414
+ if loss is not None:
415
+ return (loss,) + decoder_outputs + encoder_outputs
416
+ else:
417
+ return decoder_outputs + encoder_outputs
418
+
419
+ return Seq2SeqLMOutputLosses(
420
+ loss=loss,
421
+ enc_loss=enc_loss,
422
+ dec_loss=dec_loss,
423
+ logits=decoder_outputs.logits,
424
+ past_key_values=decoder_outputs.past_key_values,
425
+ decoder_hidden_states=decoder_outputs.hidden_states,
426
+ decoder_attentions=decoder_outputs.attentions,
427
+ cross_attentions=decoder_outputs.cross_attentions,
428
+ encoder_last_hidden_state=encoder_hidden_states,
429
+ encoder_hidden_states=encoder_outputs.hidden_states,
430
+ encoder_attentions=encoder_outputs.attentions,
431
+ encoder_logits=encoder_outputs.logits,
432
+ )
433
+
434
+ def _get_logits_processor(
435
+ self,
436
+ generation_config: GenerationConfigCustom,
437
+ input_ids_seq_length: int,
438
+ encoder_input_ids: torch.LongTensor,
439
+ prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
440
+ logits_processor: Optional[LogitsProcessorList],
441
+ model_kwargs: Optional[Dict[str, Any]] = None,
442
+ negative_prompt_ids: Optional[torch.Tensor] = None,
443
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
444
+ ) -> LogitsProcessorList:
445
+ # pylint: disable=no-member
446
+ processors = super()._get_logits_processor(
447
+ generation_config,
448
+ input_ids_seq_length,
449
+ encoder_input_ids,
450
+ prefix_allowed_tokens_fn,
451
+ logits_processor,
452
+ model_kwargs,
453
+ negative_prompt_ids,
454
+ negative_prompt_attention_mask,
455
+ )
456
+ if hasattr(generation_config, "ctc_weight") and generation_config.ctc_weight > 0:
457
+ if generation_config.num_beams <= 1:
458
+ processors.append(LogSoftmaxProcessor())
459
+ self.ctc_rescorer = CTCRescorerLogitsProcessor(
460
+ self.encoder_logits,
461
+ self.encoder_output_lens,
462
+ self.generation_config.pad_token_id,
463
+ self.generation_config.eos_token_id,
464
+ self.generation_config.ctc_margin,
465
+ self.generation_config.ctc_weight,
466
+ self.generation_config.num_beams,
467
+ self.generation_config.space_token_id if hasattr(self.generation_config, "space_token_id") else None,
468
+ self.generation_config.apply_eos_space_trick
469
+ if hasattr(self.generation_config, "apply_eos_space_trick")
470
+ else False,
471
+ self.generation_config.eos_space_trick_weight
472
+ if hasattr(self.generation_config, "eos_space_trick_weight")
473
+ else 0.0,
474
+ )
475
+ processors.append(self.ctc_rescorer)
476
+ if hasattr(generation_config, "lm_weight") and generation_config.lm_weight > 0:
477
+ if not hasattr(generation_config, "lm_model"):
478
+ raise ValueError("If `lm_weight` is specified, make sure that `lm_model` is defined.")
479
+ processors.append(
480
+ LMRescorerLogitsProcessor(generation_config.lm_weight, generation_config.lm_model, device=self.device)
481
+ )
482
+ return processors
483
+
484
+ def _prepare_encoder_decoder_kwargs_for_generation(
485
+ self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
486
+ ) -> Dict[str, Any]:
487
+ self.encoder_output_lens = self.encoder._get_feat_extract_output_lengths(
488
+ model_kwargs["attention_mask"].sum(dim=1)
489
+ )
490
+ # pylint: disable=E1101
491
+ model_kwargs = super()._prepare_encoder_decoder_kwargs_for_generation(
492
+ inputs_tensor, model_kwargs, model_input_name
493
+ )
494
+ self.encoder_logits = model_kwargs["encoder_outputs"].logits
495
+ return model_kwargs
496
+
497
+ @staticmethod
498
+ def _expand_inputs_for_generation(
499
+ expand_size: int = 1,
500
+ is_encoder_decoder: bool = False,
501
+ input_ids: Optional[torch.LongTensor] = None,
502
+ **model_kwargs,
503
+ ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
504
+ """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""
505
+
506
+ def _expand_dict_for_generation(dict_to_expand):
507
+ for key in dict_to_expand:
508
+ if dict_to_expand[key] is not None and isinstance(dict_to_expand[key], torch.Tensor) and key != "loss":
509
+ dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
510
+ return dict_to_expand
511
+
512
+ if input_ids is not None:
513
+ input_ids = input_ids.repeat_interleave(expand_size, dim=0)
514
+
515
+ model_kwargs = _expand_dict_for_generation(model_kwargs)
516
+
517
+ if is_encoder_decoder:
518
+ if model_kwargs.get("encoder_outputs") is None:
519
+ raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
520
+ model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
521
+ model_kwargs["encoder_outputs"].last_hidden_state = model_kwargs[
522
+ "encoder_outputs"
523
+ ].last_hidden_state.repeat_interleave(expand_size, dim=0)
524
+
525
+ return input_ids, model_kwargs
526
+
527
+ @torch.no_grad()
528
+ def generate(
529
+ self,
530
+ inputs: Optional[torch.Tensor] = None,
531
+ generation_config: Optional[GenerationConfigCustom] = None,
532
+ logits_processor: Optional[LogitsProcessorList] = None,
533
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
534
+ prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
535
+ synced_gpus: Optional[bool] = None,
536
+ assistant_model: Optional["PreTrainedModel"] = None,
537
+ streamer: Optional["BaseStreamer"] = None,
538
+ **kwargs,
539
+ ) -> Union[GenerateOutput, torch.LongTensor]:
540
+ if "encoder_outputs" in kwargs:
541
+ self.encoder_logits = kwargs["encoder_outputs"].logits
542
+ self.encoder_output_lens = self.encoder._get_feat_extract_output_lengths(
543
+ kwargs["attention_mask"].sum(dim=1)
544
+ )
545
+ # pylint: disable=E1101
546
+ output = super().generate(
547
+ inputs,
548
+ generation_config,
549
+ logits_processor,
550
+ stopping_criteria,
551
+ prefix_allowed_tokens_fn,
552
+ synced_gpus,
553
+ assistant_model,
554
+ streamer,
555
+ **kwargs,
556
+ )
557
+ self.encoder_logits = None
558
+ self.encoder_output_lens = None
559
+ return output
560
+
561
+
562
+ AutoConfig.register("joint_aed_ctc_speech-encoder-decoder", JointCTCAttentionEncoderDecoderConfig)
563
+ AutoModelForSpeechSeq2Seq.register(JointCTCAttentionEncoderDecoderConfig, JointCTCAttentionEncoderDecoder)
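Once this upload sits in a model repo, the two registrations above together with the `auto_map` entries in config.json let the checkpoint be loaded through the standard Auto classes. The following is a minimal, hypothetical loading sketch, not part of this commit: the repo id and the dummy audio are placeholders, and it assumes the configured feature extractor returns `input_values` (a 2-D front end would return `input_features` instead).

import torch
from transformers import AutoFeatureExtractor, AutoModelForSpeechSeq2Seq, AutoTokenizer

repo_id = "user/joint-ctc-attention-asr"  # placeholder repo id, not from this commit

feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForSpeechSeq2Seq.from_pretrained(repo_id, trust_remote_code=True)

# Two seconds of dummy 16 kHz audio stand in for a real utterance.
audio = torch.randn(32000).numpy()
inputs = feature_extractor(
    audio, sampling_rate=16000, padding=True, return_tensors="pt", return_attention_mask=True
)

# generate() above caches the encoder's CTC logits and output lengths so that
# CTCRescorerLogitsProcessor can rescore decoder hypotheses when
# generation_config.ctc_weight > 0; the attention_mask is required for that.
predicted_ids = model.generate(inputs=inputs.input_values, attention_mask=inputs.attention_mask)
print(tokenizer.batch_decode(predicted_ids, skip_special_tokens=True))

During training, forward() above combines both objectives as loss = ctc_weight * enc_loss + (1 - ctc_weight) * dec_loss, with label smoothing (lsm_factor) applied only to the attention branch.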
multi_head_gpt2.py ADDED
@@ -0,0 +1,160 @@
 
 
1
+ from typing import Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.utils.checkpoint
5
+ from torch import nn
6
+ from torch.nn import CrossEntropyLoss
7
+ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
8
+ from transformers.models.gpt2.configuration_gpt2 import GPT2Config
9
+ from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
10
+
11
+
12
+ class GPT2MultiHeadConfig(GPT2Config):
13
+ model_type = "gpt2-multi-head"
14
+
15
+ def __init__(
16
+ self,
17
+ head_locations=None,
18
+ head_weights=None,
19
+ tie_additional_weights=False,
20
+ average_logits=False,
21
+ *args,
22
+ **kwargs,
23
+ ):
24
+ super().__init__(*args, **kwargs)
25
+ self.head_locations = head_locations
26
+ self.head_weights = head_weights
27
+ self.tie_additional_weights = tie_additional_weights
28
+ self.average_logits = average_logits
29
+
30
+
31
+ class GPT2LMMultiHeadModel(GPT2LMHeadModel):
32
+ config_class = GPT2MultiHeadConfig
33
+
34
+ def __init__(self, config: GPT2MultiHeadConfig):
35
+ super().__init__(config)
36
+ if config.head_locations is not None:
37
+ if not len(config.head_locations) + 1 == len(config.head_weights):
38
+ raise ValueError("The number of head locations should be equal to the number of head weights minus 1")
39
+ self.head_locations = config.head_locations
40
+ self.additional_lm_heads = nn.ModuleList(
41
+ [nn.Linear(config.n_embd, config.vocab_size, bias=False) for _ in config.head_locations]
42
+ )
43
+ self.head_weights = config.head_weights
44
+ else:
45
+ self.head_locations = []
46
+ self.additional_lm_heads = nn.ModuleList([])
47
+ self.head_weights = [1.0]
48
+ self.post_init()
49
+
50
+ def tie_weights(self):
51
+ """
52
+ Tie the weights between the input embeddings and the output embeddings.
53
+
54
+ If the `torchscript` flag is set in the configuration, parameter sharing cannot be handled, so the
55
+ weights are cloned instead.
56
+ """
57
+ super().tie_weights()
58
+ if hasattr(self, "additional_lm_heads") and getattr(self.config, "tie_additional_weights", False):
59
+ input_embeddings = self.get_input_embeddings()
60
+ for classifier in self.additional_lm_heads:
61
+ if self.config.torchscript:
62
+ classifier.weight = nn.Parameter(input_embeddings.weight.clone())
63
+ else:
64
+ classifier.weight = input_embeddings.weight
65
+
66
+ if getattr(classifier, "bias", None) is not None:
67
+ classifier.bias.data = nn.functional.pad(
68
+ classifier.bias.data,
69
+ (
70
+ 0,
71
+ classifier.weight.shape[0] - classifier.bias.shape[0],
72
+ ),
73
+ "constant",
74
+ 0,
75
+ )
76
+ if hasattr(classifier, "out_features") and hasattr(input_embeddings, "num_embeddings"):
77
+ classifier.out_features = input_embeddings.num_embeddings
78
+
79
+ def forward(
80
+ self,
81
+ input_ids: Optional[torch.LongTensor] = None,
82
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
83
+ attention_mask: Optional[torch.FloatTensor] = None,
84
+ token_type_ids: Optional[torch.LongTensor] = None,
85
+ position_ids: Optional[torch.LongTensor] = None,
86
+ head_mask: Optional[torch.FloatTensor] = None,
87
+ inputs_embeds: Optional[torch.FloatTensor] = None,
88
+ encoder_hidden_states: Optional[torch.Tensor] = None,
89
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
90
+ labels: Optional[torch.LongTensor] = None,
91
+ use_cache: Optional[bool] = None,
92
+ output_attentions: Optional[bool] = None,
93
+ output_hidden_states: Optional[bool] = None,
94
+ return_dict: Optional[bool] = None,
95
+ ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
96
+ r"""
97
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
98
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
99
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
100
+ are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
101
+ """
102
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
103
+
104
+ transformer_outputs = self.transformer(
105
+ input_ids,
106
+ past_key_values=past_key_values,
107
+ attention_mask=attention_mask,
108
+ token_type_ids=token_type_ids,
109
+ position_ids=position_ids,
110
+ head_mask=head_mask,
111
+ inputs_embeds=inputs_embeds,
112
+ encoder_hidden_states=encoder_hidden_states,
113
+ encoder_attention_mask=encoder_attention_mask,
114
+ use_cache=use_cache,
115
+ output_attentions=output_attentions,
116
+ output_hidden_states=True,
117
+ return_dict=return_dict,
118
+ )
119
+ hidden_states = transformer_outputs[2]
120
+
121
+ # Set device for model parallelism
122
+ if self.model_parallel:
123
+ torch.cuda.set_device(self.transformer.first_device)
124
+ hidden_states = hidden_states.to(self.lm_head.weight.device)
125
+
126
+ lm_logits = self.lm_head(hidden_states[-1])
127
+ loss = None
128
+ if labels is not None:
129
+ loss = torch.tensor(0.0, device=hidden_states[-1].device)
130
+ lm_logits = []
131
+ loss_fct = CrossEntropyLoss()
132
+
133
+ for index, lm_head, lm_weight in zip(
134
+ [*self.head_locations, -1],
135
+ [*self.additional_lm_heads, self.lm_head],
136
+ self.head_weights,
137
+ ):
138
+ lm_logits.append(lm_head(hidden_states[index]))
139
+ # Shift so that tokens < n predict n
140
+ shift_logits = lm_logits[-1][..., :-1, :].contiguous()
141
+ shift_labels = labels[..., 1:].contiguous()
142
+ # Flatten the tokens
143
+ loss += lm_weight * loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
144
+
145
+ if self.config.average_logits:
146
+ lm_logits = (torch.stack(lm_logits) * torch.tensor(self.head_weights, device=hidden_states[-1].device)[:, None, None, None]).mean(dim=0)
147
+ else:
148
+ lm_logits = lm_logits[-1]
149
+ if not return_dict:
150
+ output = (lm_logits,) + transformer_outputs[1:]
151
+ return ((loss,) + output) if loss is not None else output
152
+
153
+ return CausalLMOutputWithCrossAttentions(
154
+ loss=loss,
155
+ logits=lm_logits,
156
+ past_key_values=transformer_outputs.past_key_values,
157
+ hidden_states=transformer_outputs.hidden_states,
158
+ attentions=transformer_outputs.attentions,
159
+ cross_attentions=transformer_outputs.cross_attentions,
160
+ )
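The decoder variant above attaches auxiliary LM heads to intermediate transformer layers and sums their label-shifted cross-entropy losses, weighted by head_weights; the final lm_head is always included, so head_weights needs one more entry than head_locations. A small illustrative instantiation follows; all sizes are made up and the local module path is assumed, not prescribed by this repository.

from multi_head_gpt2 import GPT2MultiHeadConfig, GPT2LMMultiHeadModel  # assumed local import

config = GPT2MultiHeadConfig(
    vocab_size=5000,
    n_positions=512,
    n_embd=512,
    n_layer=12,
    n_head=8,
    head_locations=[3, 7],         # auxiliary heads read the hidden states of layers 3 and 7
    head_weights=[0.2, 0.2, 0.6],  # two auxiliary heads plus the final lm_head
    tie_additional_weights=True,   # tie the auxiliary heads to the input embeddings
    average_logits=False,          # return only the final head's logits
)
model = GPT2LMMultiHeadModel(config)

# With labels, the loss is sum_i head_weights[i] * CE(shift(head_i(hidden_states[loc_i])), shift(labels)).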
residual_clasiffier_gpt2.py ADDED
@@ -0,0 +1,99 @@
 
 
1
+ from typing import Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.utils.checkpoint
5
+ from torch import nn
6
+ from torch.nn import CrossEntropyLoss
7
+ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
8
+ from transformers.models.gpt2.configuration_gpt2 import GPT2Config
9
+ from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
10
+
11
+
12
+ class GPT2ResidualsLMHeadConfig(GPT2Config):
13
+ model_type = "gpt2-residuals-head"
14
+
15
+ def __init__(self, connected_residuals=None, *args, **kwargs):
16
+ super().__init__(*args, **kwargs)
17
+ self.connected_residuals = connected_residuals
18
+
19
+
20
+ class GPT2ResidualsLMHeadModel(GPT2LMHeadModel):
21
+ config_class = GPT2ResidualsLMHeadConfig
22
+
23
+ def __init__(self, config: GPT2ResidualsLMHeadConfig):
24
+ super().__init__(config)
25
+ self.connected_residuals = config.connected_residuals
26
+ self.lm_head = nn.Linear(config.n_embd * len(self.connected_residuals), config.vocab_size, bias=False)
27
+ self.post_init()
28
+
29
+ def forward(
30
+ self,
31
+ input_ids: Optional[torch.LongTensor] = None,
32
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
33
+ attention_mask: Optional[torch.FloatTensor] = None,
34
+ token_type_ids: Optional[torch.LongTensor] = None,
35
+ position_ids: Optional[torch.LongTensor] = None,
36
+ head_mask: Optional[torch.FloatTensor] = None,
37
+ inputs_embeds: Optional[torch.FloatTensor] = None,
38
+ encoder_hidden_states: Optional[torch.Tensor] = None,
39
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
40
+ labels: Optional[torch.LongTensor] = None,
41
+ use_cache: Optional[bool] = None,
42
+ output_attentions: Optional[bool] = None,
43
+ output_hidden_states: Optional[bool] = None,
44
+ return_dict: Optional[bool] = None,
45
+ ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
46
+ r"""
47
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
48
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
49
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
50
+ are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
51
+ """
52
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
53
+
54
+ transformer_outputs = self.transformer(
55
+ input_ids,
56
+ past_key_values=past_key_values,
57
+ attention_mask=attention_mask,
58
+ token_type_ids=token_type_ids,
59
+ position_ids=position_ids,
60
+ head_mask=head_mask,
61
+ inputs_embeds=inputs_embeds,
62
+ encoder_hidden_states=encoder_hidden_states,
63
+ encoder_attention_mask=encoder_attention_mask,
64
+ use_cache=use_cache,
65
+ output_attentions=output_attentions,
66
+ output_hidden_states=True,
67
+ return_dict=return_dict,
68
+ )
69
+ hidden_states = transformer_outputs[2]
70
+
71
+ # Set device for model parallelism
72
+ if self.model_parallel:
73
+ torch.cuda.set_device(self.transformer.first_device)
74
+ hidden_states = hidden_states.to(self.lm_head.weight.device)
75
+
76
+ hidden_states = torch.concat([hidden_states[index] for index in self.connected_residuals], dim=-1)
77
+ lm_logits = self.lm_head(hidden_states)
78
+
79
+ loss = None
80
+ if labels is not None:
81
+ # Shift so that tokens < n predict n
82
+ shift_logits = lm_logits[..., :-1, :].contiguous()
83
+ shift_labels = labels[..., 1:].contiguous()
84
+ # Flatten the tokens
85
+ loss_fct = CrossEntropyLoss()
86
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
87
+
88
+ if not return_dict:
89
+ output = (lm_logits,) + transformer_outputs[1:]
90
+ return ((loss,) + output) if loss is not None else output
91
+
92
+ return CausalLMOutputWithCrossAttentions(
93
+ loss=loss,
94
+ logits=lm_logits,
95
+ past_key_values=transformer_outputs.past_key_values,
96
+ hidden_states=transformer_outputs.hidden_states,
97
+ attentions=transformer_outputs.attentions,
98
+ cross_attentions=transformer_outputs.cross_attentions,
99
+ )
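This last variant feeds the LM head with the concatenation of the hidden states listed in connected_residuals (index 0 is the embedding output, index n_layer is the final layer), so the head's input width is n_embd * len(connected_residuals). An illustrative instantiation with placeholder sizes and an assumed module path:

from residual_clasiffier_gpt2 import GPT2ResidualsLMHeadConfig, GPT2ResidualsLMHeadModel  # assumed local import

config = GPT2ResidualsLMHeadConfig(
    vocab_size=5000,
    n_positions=512,
    n_embd=512,
    n_layer=12,
    n_head=8,
    connected_residuals=[0, 6, 12],  # embeddings, a middle layer, and the final layer
)
model = GPT2ResidualsLMHeadModel(config)
# lm_head is nn.Linear(512 * 3, 5000): it scores the concatenated residual streams.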