prompteus committed
Commit d715b2e
1 Parent(s): 8b93dee

Create model_class.py

Files changed (1)
model_class.py +158 -0
model_class.py ADDED
@@ -0,0 +1,158 @@
+import transformers
+import torch
+from typing import Optional, Tuple, Union
+from transformers.modeling_outputs import Seq2SeqLMOutput
+from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor
+from transformers.models.whisper.tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE
+
+
+class WhisperForAudioCaptioning(transformers.WhisperForConditionalGeneration):
+
+    def forward(
+        self,
+        input_features: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        forced_ac_decoder_ids: Optional[torch.LongTensor] = None,  # accepted (and ignored) so the trainer can pass it
+    ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
+        # delegate to the stock Whisper forward; forced_ac_decoder_ids is only consumed by generate()
+        return super().forward(
+            input_features=input_features,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            head_mask=head_mask,
+            decoder_head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            encoder_outputs=encoder_outputs,
+            past_key_values=past_key_values,
+            decoder_inputs_embeds=decoder_inputs_embeds,
+            labels=labels,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+    # copy-pasted and adapted from transformers.WhisperForConditionalGeneration.generate
+    def generate(
+        self,
+        inputs: Optional[torch.Tensor] = None,
+        forced_ac_decoder_ids: Optional[torch.Tensor] = None,
+        generation_config=None,
+        logits_processor=None,
+        stopping_criteria=None,
+        prefix_allowed_tokens_fn=None,
+        synced_gpus=False,
+        return_timestamps=None,
+        task="transcribe",
+        language="english",
+        **kwargs,
+    ):
+        if generation_config is None:
+            generation_config = self.generation_config
+
+        if return_timestamps is not None:
+            if not hasattr(generation_config, "no_timestamps_token_id"):
+                raise ValueError(
+                    "You are trying to return timestamps, but the generation config is not properly set. "
+                    "Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. "
+                    "For more details on how to generate the appropriate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363"
+                )
+
+            generation_config.return_timestamps = return_timestamps
+        else:
+            generation_config.return_timestamps = False
+
+        if language is not None:
+            generation_config.language = language
+        if task is not None:
+            generation_config.task = task
+
+        forced_decoder_ids = []
+        if task is not None or language is not None:
+            if hasattr(generation_config, "language"):
+                if generation_config.language in generation_config.lang_to_id.keys():
+                    language_token = generation_config.language
+                elif generation_config.language in TO_LANGUAGE_CODE.keys():
+                    language_token = f"<|{TO_LANGUAGE_CODE[generation_config.language]}|>"
+                else:
+                    raise ValueError(
+                        f"Unsupported language: {language}. Language should be one of: {list(TO_LANGUAGE_CODE.keys())}."
+                    )
+                forced_decoder_ids.append((1, generation_config.lang_to_id[language_token]))
+            else:
+                forced_decoder_ids.append((1, None))  # automatically detect the language
+
+            if hasattr(generation_config, "task"):
+                if generation_config.task in TASK_IDS:
+                    forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task]))
+                else:
+                    raise ValueError(
+                        f"The `{generation_config.task}` task is not supported. The task should be one of `{TASK_IDS}`."
+                    )
+            else:
+                forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"]))  # defaults to transcribe
+            if hasattr(generation_config, "no_timestamps_token_id") and not generation_config.return_timestamps:
+                idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
+                forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))
+
+        # Legacy code for backward compatibility
+        elif hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids is not None:
+            forced_decoder_ids = self.config.forced_decoder_ids
+        elif (
+            hasattr(self.generation_config, "forced_decoder_ids")
+            and self.generation_config.forced_decoder_ids is not None
+        ):
+            forced_decoder_ids = self.generation_config.forced_decoder_ids
+
+        if generation_config.return_timestamps:
+            logits_processor = [WhisperTimeStampLogitsProcessor(generation_config)]
+
+        decoder_input_ids = None
+
+        if len(forced_decoder_ids) > 0:
+            # recover the token sequence encoded in forced_decoder_ids
+            forced_decoder_ids.sort()
+            if min(forced_decoder_ids)[0] != 0:
+                forced_decoder_ids = [(0, self.config.decoder_start_token_id)] + forced_decoder_ids
+
+            position_indices, decoder_input_ids = zip(*forced_decoder_ids)
+            assert tuple(position_indices) == tuple(range(len(position_indices))), \
+                "forced_decoder_ids is not a (contiguous) prefix, we can't handle that"
+
+            device = self.get_decoder().device
+
+            if forced_ac_decoder_ids is None:
+                forced_ac_decoder_ids = torch.tensor([[]], device=device, dtype=torch.long)
+
+            # enrich every sample's forced_ac_decoder_ids with Whisper's forced_decoder_ids prefix
+            batch_size = forced_ac_decoder_ids.shape[0]
+            fluff_len = len(decoder_input_ids)  # length of the forced Whisper prefix
+            decoder_input_ids = torch.tensor(decoder_input_ids, device=device, dtype=torch.long)
+            decoder_input_ids = decoder_input_ids.expand((batch_size, fluff_len))
+            decoder_input_ids = torch.cat([decoder_input_ids, forced_ac_decoder_ids], dim=1)
+
+        generation_config.forced_decoder_ids = forced_decoder_ids
+
+        # changed by adam (calling grandparent): skip WhisperForConditionalGeneration.generate
+        # and call the generic GenerationMixin.generate, passing the prefix as decoder_input_ids
+        return super(transformers.WhisperPreTrainedModel, self).generate(
+            inputs,
+            generation_config,
+            logits_processor,
+            stopping_criteria,
+            prefix_allowed_tokens_fn,
+            synced_gpus,
+            decoder_input_ids=decoder_input_ids,
+            **kwargs,
+        )
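
For reference, a minimal usage sketch of the class added in this commit (not part of the diff). The checkpoint path, the example file name, the librosa loading step, and the "caption: " prefix are illustrative assumptions; the model only needs 16 kHz log-mel features from a WhisperFeatureExtractor and, optionally, a tokenized caption-style prefix passed as forced_ac_decoder_ids, which generate() appends after Whisper's own forced tokens:

from model_class import WhisperForAudioCaptioning  # the class created by this commit

import torch
import librosa  # assumed audio loader; any 16 kHz mono source works
from transformers import WhisperFeatureExtractor, WhisperTokenizer

checkpoint = "your-org/whisper-audio-captioning"  # hypothetical fine-tuned checkpoint

model = WhisperForAudioCaptioning.from_pretrained(checkpoint)
tokenizer = WhisperTokenizer.from_pretrained(checkpoint, language="english", task="transcribe")
feature_extractor = WhisperFeatureExtractor.from_pretrained(checkpoint)

# Whisper consumes 16 kHz audio converted to log-mel features
audio, sr = librosa.load("example.wav", sr=16_000)
features = feature_extractor(audio, sampling_rate=sr, return_tensors="pt").input_features

# optional caption-style prefix, tokenized without special tokens
forced_ac_decoder_ids = tokenizer("caption: ", return_tensors="pt", add_special_tokens=False).input_ids

with torch.no_grad():
    output_ids = model.generate(
        inputs=features,
        forced_ac_decoder_ids=forced_ac_decoder_ids,
        max_length=100,
    )

print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])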