sdadas committed
Commit eabfded · verified · 1 parent: 9d5c42c

Upload modeling_roberta.py

Files changed (1):
  1. modeling_roberta.py +197 -0
modeling_roberta.py ADDED
from typing import Unpack

import torch
from transformers import (
    RobertaModel,
    Cache,
    EncoderDecoderCache,
    DynamicCache,
    DataCollatorWithFlattening,
    RobertaForMaskedLM,
    RobertaForSequenceClassification,
    RobertaForTokenClassification,
    RobertaForQuestionAnswering,
    RobertaForMultipleChoice,
    RobertaForCausalLM,
)
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
from transformers.utils import TransformersKwargs
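
# Note: this module assumes a recent transformers release; TransformersKwargs and
# the collator's return_flash_attn_kwargs flag are not available in older versions.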

def _unpad_input(input_ids: torch.Tensor, attention_mask: torch.Tensor):
    # Flatten the padded batch into a single packed sequence; with
    # return_flash_attn_kwargs=True the collator also produces the varlen
    # FlashAttention metadata (cu_seq_lens_q/k, max_length_q/k) and
    # position_ids that restart at 0 for each sequence.
    collator = DataCollatorWithFlattening(return_flash_attn_kwargs=True)
    features = collator([{"input_ids": i[a.bool()].tolist()} for i, a in zip(input_ids, attention_mask)])
    return features
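
# Illustration (hypothetical values): two sequences of lengths 3 and 2, padded to
# length 4, are packed into a single row of 5 tokens with cu_seq_lens = [0, 3, 5]
# and max_length = 3.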

def _pad_output(inputs: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int) -> torch.Tensor:
    if inputs.dim() == 3:
        # The encoder returns the packed states as (1, total_tokens, hidden).
        inputs = inputs.squeeze()
    if inputs.dim() == 1:
        # Per-token scalars: scatter into a zeroed (batch * seqlen,) buffer.
        output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device)
        output[indices] = inputs
        padded_inputs = output.view(batch, seqlen)
    else:
        # Per-token vectors: scatter rows back to their original flat positions.
        _, *rest = inputs.shape
        output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device)
        output[indices] = inputs
        padded_inputs = output.view(batch, seqlen, *rest)
    return padded_inputs
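
# Illustration (hypothetical shapes): with batch=2, seqlen=4 and packed outputs of
# shape (5, hidden), the five rows land at the positions listed in `indices` inside
# a zeroed (2, 4, hidden) tensor; padding positions stay zero.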

class UnpadRobertaModel(RobertaModel):
    _no_split_modules = ["RobertaEmbeddings", "RobertaLayer"]

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config, add_pooling_layer=add_pooling_layer)

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        encoder_hidden_states: torch.Tensor | None = None,
        encoder_attention_mask: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool | None = None,
        cache_position: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor] | BaseModelOutputWithPoolingAndCrossAttentions:
        if self.config.is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        if use_cache and past_key_values is None:
            past_key_values = (
                EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
                if encoder_hidden_states is not None or self.config.is_encoder_decoder
                else DynamicCache(config=self.config)
            )

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if input_ids is not None:
            device = input_ids.device
            seq_length = input_ids.shape[1]
            batch_size = input_ids.size(0)
        else:
            device = inputs_embeds.device
            seq_length = inputs_embeds.shape[1]
            batch_size = inputs_embeds.size(0)

        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
        if cache_position is None:
            cache_position = torch.arange(past_key_values_length, past_key_values_length + seq_length, device=device)

        indices = None
        if self.config._attn_implementation.startswith("flash_attention"):
            if input_ids is None or attention_mask is None:
                raise ValueError("Unpadding requires both input_ids and attention_mask")
            with torch.no_grad():
                # Flat positions of all non-padding tokens, used to re-pad the outputs.
                indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
                features = _unpad_input(input_ids, attention_mask)
            input_ids = features["input_ids"].to(device=device)
            # RoBERTa offsets positions by padding_idx + 1, i.e. by 2.
            position_ids = (features["position_ids"] + 2).to(device=device)
            # The packed sequence contains no padding, so no attention mask is needed;
            # the varlen kwargs below keep the sequences from attending to each other.
            attention_mask = None
            # token_type_ids (if any) would no longer line up with the packed layout;
            # RoBERTa uses a single token type, so drop them and fall back to zeros.
            token_type_ids = None
            kwargs["cu_seq_lens_k"] = features["cu_seq_lens_k"].to(device=device)
            kwargs["cu_seq_lens_q"] = features["cu_seq_lens_q"].to(device=device)
            kwargs["max_length_k"] = features["max_length_k"]
            kwargs["max_length_q"] = features["max_length_q"]

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )

        attention_mask, encoder_attention_mask = self._create_attention_masks(
            attention_mask=attention_mask,
            encoder_attention_mask=encoder_attention_mask,
            embedding_output=embedding_output,
            encoder_hidden_states=encoder_hidden_states,
            cache_position=cache_position,
            past_key_values=past_key_values,
        )

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_ids=position_ids,
            **kwargs,
        )

        sequence_output = encoder_outputs.last_hidden_state
        if self.config._attn_implementation.startswith("flash_attention"):
            # Scatter the packed hidden states back into a padded (batch, seqlen, hidden) tensor.
            sequence_output = _pad_output(
                inputs=sequence_output, indices=indices, batch=batch_size, seqlen=seq_length
            )

        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
        )


class UnpadRobertaForCausalLM(RobertaForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        self.roberta = UnpadRobertaModel(config, add_pooling_layer=False)
        self.post_init()


class UnpadRobertaForMaskedLM(RobertaForMaskedLM):
    def __init__(self, config):
        super().__init__(config)
        self.roberta = UnpadRobertaModel(config, add_pooling_layer=False)
        self.post_init()


class UnpadRobertaForSequenceClassification(RobertaForSequenceClassification):
    def __init__(self, config):
        super().__init__(config)
        self.roberta = UnpadRobertaModel(config, add_pooling_layer=False)
        self.post_init()


class UnpadRobertaForTokenClassification(RobertaForTokenClassification):
    def __init__(self, config):
        super().__init__(config)
        self.roberta = UnpadRobertaModel(config, add_pooling_layer=False)
        self.post_init()


class UnpadRobertaForMultipleChoice(RobertaForMultipleChoice):
    def __init__(self, config):
        super().__init__(config)
        # The multiple-choice head consumes the pooled output, so keep the pooler.
        self.roberta = UnpadRobertaModel(config)
        self.post_init()


class UnpadRobertaForQuestionAnswering(RobertaForQuestionAnswering):
    def __init__(self, config):
        super().__init__(config)
        self.roberta = UnpadRobertaModel(config, add_pooling_layer=False)
        self.post_init()


def enable_roberta_unpadding():
    # Monkey-patch the stock RobertaModel so that RoBERTa instances created
    # elsewhere also pick up the unpadded forward pass.
    RobertaModel.forward = UnpadRobertaModel.forward
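
For context, a minimal usage sketch (illustrative only: the checkpoint name, the
input texts, and the local import path are placeholders, and flash-attn plus a
CUDA device are assumed):

import torch
from transformers import AutoTokenizer
from modeling_roberta import UnpadRobertaModel

tokenizer = AutoTokenizer.from_pretrained("roberta-base")
model = UnpadRobertaModel.from_pretrained(
    "roberta-base",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
).to("cuda")
model.eval()

batch = tokenizer(
    ["a short example", "a noticeably longer example with several more tokens"],
    padding=True,
    return_tensors="pt",
).to("cuda")

with torch.no_grad():
    out = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])

# Hidden states come back re-padded to the usual (batch, seq_len, hidden_size) shape.
print(out.last_hidden_state.shape)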