sdadas committed
Commit ce4368a · verified · 1 Parent(s): f44155a

Update modeling_bert.py

Files changed (1)
  1. modeling_bert.py +174 -171
modeling_bert.py CHANGED
@@ -1,171 +1,174 @@
-from typing import Unpack
-import torch
-from transformers import (
-    Cache,
-    EncoderDecoderCache,
-    DynamicCache,
-    DataCollatorWithFlattening,
-    BertModel, BertForMaskedLM, BertForSequenceClassification, BertForTokenClassification, RobertaForMultipleChoice,
-    BertForMultipleChoice, BertForQuestionAnswering
-)
-from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
-from transformers.utils import TransformersKwargs
-
-
-def _unpad_input(input_ids: torch.Tensor, attention_mask: torch.Tensor):
-    collator = DataCollatorWithFlattening(return_flash_attn_kwargs=True)
-    features = collator([{"input_ids": i[a.bool()].tolist()} for i, a in zip(input_ids, attention_mask)])
-    return features
-
-
-def _pad_output(inputs: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int,) -> torch.Tensor:
-    if inputs.dim() == 3:
-        inputs = inputs.squeeze()
-    if inputs.dim() == 1:
-        output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device)
-        output[indices] = inputs
-        padded_inputs = output.view(batch, seqlen)
-    else:
-        _, *rest = inputs.shape
-        output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device)
-        output[indices] = inputs
-        padded_inputs = output.view(batch, seqlen, *rest)
-    return padded_inputs
-
-
-class UnpadBertModel(BertModel):
-    _no_split_modules = ["BertEmbeddings", "BertLayer"]
-
-    def __init__(self, config, add_pooling_layer=True):
-        super().__init__(config, add_pooling_layer)
-
-    def forward(
-        self,
-        input_ids: torch.Tensor | None = None,
-        attention_mask: torch.Tensor | None = None,
-        token_type_ids: torch.Tensor | None = None,
-        position_ids: torch.Tensor | None = None,
-        inputs_embeds: torch.Tensor | None = None,
-        encoder_hidden_states: torch.Tensor | None = None,
-        encoder_attention_mask: torch.Tensor | None = None,
-        past_key_values: Cache | None = None,
-        use_cache: bool | None = None,
-        **kwargs: Unpack[TransformersKwargs],
-    ) -> tuple[torch.Tensor] | BaseModelOutputWithPoolingAndCrossAttentions:
-        if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
-        if self.config.is_decoder:
-            use_cache = use_cache if use_cache is not None else self.config.use_cache
-        else:
-            use_cache = False
-
-        if use_cache and past_key_values is None:
-            past_key_values = (
-                EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
-                if encoder_hidden_states is not None or self.config.is_encoder_decoder
-                else DynamicCache(config=self.config)
-            )
-
-        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
-
-        device = input_ids.device
-        batch_size = input_ids.shape[0]
-        seq_length = input_ids.shape[1]
-        indices = None
-        if self.config._attn_implementation.startswith("flash_attention"):
-            if input_ids is None or attention_mask is None:
-                raise ValueError("Unpadding requires both input_ids and attention_mask")
-            with torch.no_grad():
-                indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
-                features = _unpad_input(input_ids, attention_mask)
-                input_ids = features["input_ids"].to(device=device)
-                position_ids = features["position_ids"].to(device=device)
-                attention_mask = None
-                kwargs["cu_seq_lens_k"] = features["cu_seq_lens_k"].to(device=device)
-                kwargs["cu_seq_lens_q"] = features["cu_seq_lens_q"].to(device=device)
-                kwargs["max_length_k"] = features["max_length_k"]
-                kwargs["max_length_q"] = features["max_length_q"]
-
-        embedding_output = self.embeddings(
-            input_ids=input_ids,
-            position_ids=position_ids,
-            token_type_ids=token_type_ids,
-            inputs_embeds=inputs_embeds,
-            past_key_values_length=past_key_values_length,
-        )
-
-        attention_mask, encoder_attention_mask = self._create_attention_masks(
-            attention_mask=attention_mask,
-            encoder_attention_mask=encoder_attention_mask,
-            embedding_output=embedding_output,
-            encoder_hidden_states=encoder_hidden_states,
-            past_key_values=past_key_values,
-        )
-
-        encoder_outputs = self.encoder(
-            embedding_output,
-            attention_mask=attention_mask,
-            encoder_hidden_states=encoder_hidden_states,
-            encoder_attention_mask=encoder_attention_mask,
-            past_key_values=past_key_values,
-            use_cache=use_cache,
-            position_ids=position_ids,
-            **kwargs,
-        )
-        sequence_output = encoder_outputs.last_hidden_state
-        if self.config._attn_implementation.startswith("flash_attention"):
-            sequence_output = _pad_output(
-                inputs=sequence_output, indices=indices, batch=batch_size, seqlen=seq_length
-            )
-
-        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
-        return BaseModelOutputWithPoolingAndCrossAttentions(
-            last_hidden_state=sequence_output,
-            pooler_output=pooled_output,
-            past_key_values=encoder_outputs.past_key_values,
-        )
-
-
-class UnpadBertForMaskedLM(BertForMaskedLM):
-
-    def __init__(self, config):
-        super().__init__(config)
-        self.roberta = UnpadBertModel(config, add_pooling_layer=False)
-        self.post_init()
-
-
-class UnpadBertForSequenceClassification(BertForSequenceClassification):
-
-    def __init__(self, config):
-        super().__init__(config)
-        self.roberta = UnpadBertModel(config)
-        self.post_init()
-
-
-class UnpadBertForTokenClassification(BertForTokenClassification):
-
-    def __init__(self, config):
-        super().__init__(config)
-        self.roberta = UnpadBertModel(config)
-        self.post_init()
-
-
-class UnpadBertForMultipleChoice(BertForMultipleChoice):
-
-    def __init__(self, config):
-        super().__init__(config)
-        self.roberta = UnpadBertModel(config)
-        self.post_init()
-
-
-class UnpadBertForQuestionAnswering(BertForQuestionAnswering):
-
-    def __init__(self, config):
-        super().__init__(config)
-        self.roberta = UnpadBertModel(config, add_pooling_layer=False)
-        self.post_init()
-
-
-def enable_bert_unpadding():
-    BertModel.forward = UnpadBertModel.forward
+from typing import Unpack
+import torch
+from transformers import (
+    Cache,
+    EncoderDecoderCache,
+    DynamicCache,
+    DataCollatorWithFlattening,
+    BertModel, BertForMaskedLM,
+    BertForSequenceClassification,
+    BertForTokenClassification,
+    BertForMultipleChoice,
+    BertForQuestionAnswering
+)
+from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
+from transformers.utils import TransformersKwargs
+
+
+def _unpad_input(input_ids: torch.Tensor, attention_mask: torch.Tensor):
+    collator = DataCollatorWithFlattening(return_flash_attn_kwargs=True)
+    features = collator([{"input_ids": i[a.bool()].tolist()} for i, a in zip(input_ids, attention_mask)])
+    return features
+
+
+def _pad_output(inputs: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int,) -> torch.Tensor:
+    if inputs.dim() == 3:
+        inputs = inputs.squeeze()
+    if inputs.dim() == 1:
+        output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device)
+        output[indices] = inputs
+        padded_inputs = output.view(batch, seqlen)
+    else:
+        _, *rest = inputs.shape
+        output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device)
+        output[indices] = inputs
+        padded_inputs = output.view(batch, seqlen, *rest)
+    return padded_inputs
+
+
+class UnpadBertModel(BertModel):
+    _no_split_modules = ["BertEmbeddings", "BertLayer"]
+
+    def __init__(self, config, add_pooling_layer=True):
+        super().__init__(config, add_pooling_layer)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor | None = None,
+        attention_mask: torch.Tensor | None = None,
+        token_type_ids: torch.Tensor | None = None,
+        position_ids: torch.Tensor | None = None,
+        inputs_embeds: torch.Tensor | None = None,
+        encoder_hidden_states: torch.Tensor | None = None,
+        encoder_attention_mask: torch.Tensor | None = None,
+        past_key_values: Cache | None = None,
+        use_cache: bool | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple[torch.Tensor] | BaseModelOutputWithPoolingAndCrossAttentions:
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+        if self.config.is_decoder:
+            use_cache = use_cache if use_cache is not None else self.config.use_cache
+        else:
+            use_cache = False
+
+        if use_cache and past_key_values is None:
+            past_key_values = (
+                EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+                if encoder_hidden_states is not None or self.config.is_encoder_decoder
+                else DynamicCache(config=self.config)
+            )
+
+        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+
+        device = input_ids.device
+        batch_size = input_ids.shape[0]
+        seq_length = input_ids.shape[1]
+        indices = None
+        if self.config._attn_implementation.startswith("flash_attention"):
+            if input_ids is None or attention_mask is None:
+                raise ValueError("Unpadding requires both input_ids and attention_mask")
+            with torch.no_grad():
+                indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+                features = _unpad_input(input_ids, attention_mask)
+                input_ids = features["input_ids"].to(device=device)
+                position_ids = features["position_ids"].to(device=device)
+                attention_mask = None
+                kwargs["cu_seq_lens_k"] = features["cu_seq_lens_k"].to(device=device)
+                kwargs["cu_seq_lens_q"] = features["cu_seq_lens_q"].to(device=device)
+                kwargs["max_length_k"] = features["max_length_k"]
+                kwargs["max_length_q"] = features["max_length_q"]
+
+        embedding_output = self.embeddings(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            token_type_ids=token_type_ids,
+            inputs_embeds=inputs_embeds,
+            past_key_values_length=past_key_values_length,
+        )
+
+        attention_mask, encoder_attention_mask = self._create_attention_masks(
+            attention_mask=attention_mask,
+            encoder_attention_mask=encoder_attention_mask,
+            embedding_output=embedding_output,
+            encoder_hidden_states=encoder_hidden_states,
+            past_key_values=past_key_values,
+        )
+
+        encoder_outputs = self.encoder(
+            embedding_output,
+            attention_mask=attention_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            position_ids=position_ids,
+            **kwargs,
+        )
+        sequence_output = encoder_outputs.last_hidden_state
+        if self.config._attn_implementation.startswith("flash_attention"):
+            sequence_output = _pad_output(
+                inputs=sequence_output, indices=indices, batch=batch_size, seqlen=seq_length
+            )
+
+        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+        return BaseModelOutputWithPoolingAndCrossAttentions(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            past_key_values=encoder_outputs.past_key_values,
+        )
+
+
+class UnpadBertForMaskedLM(BertForMaskedLM):
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.bert = UnpadBertModel(config, add_pooling_layer=False)
+        self.post_init()
+
+
+class UnpadBertForSequenceClassification(BertForSequenceClassification):
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.bert = UnpadBertModel(config)
+        self.post_init()
+
+
+class UnpadBertForTokenClassification(BertForTokenClassification):
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.bert = UnpadBertModel(config)
+        self.post_init()
+
+
+class UnpadBertForMultipleChoice(BertForMultipleChoice):
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.bert = UnpadBertModel(config)
+        self.post_init()
+
+
+class UnpadBertForQuestionAnswering(BertForQuestionAnswering):
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.bert = UnpadBertModel(config, add_pooling_layer=False)
+        self.post_init()
+
+
+def enable_bert_unpadding():
+    BertModel.forward = UnpadBertModel.forward
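
For context, a minimal usage sketch (not part of the commit): it assumes this modeling_bert.py ships as custom code in a Hub repository with flash-attn installed and a CUDA device available. The repository id and input texts below are placeholders; the unpadding behavior itself comes from the file above.

    import torch
    from transformers import AutoModel, AutoTokenizer

    model_id = "your-org/bert-with-unpadding"  # placeholder: any repo bundling this modeling_bert.py

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # trust_remote_code=True lets transformers load the custom UnpadBertModel from the repo;
    # attn_implementation="flash_attention_2" activates the unpadding branch in forward().
    model = AutoModel.from_pretrained(
        model_id,
        trust_remote_code=True,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
    ).to("cuda").eval()

    batch = tokenizer(
        ["short query", "a noticeably longer document that forces padding in the batch"],
        padding=True,
        return_tensors="pt",
    ).to("cuda")

    with torch.no_grad():
        outputs = model(**batch)

    # _pad_output() restores the flattened hidden states to the usual
    # (batch_size, seq_len, hidden_size) shape, so downstream code is unchanged.
    print(outputs.last_hidden_state.shape)

The point of the unpadding path is that padding tokens are removed before the encoder runs (via DataCollatorWithFlattening and flash attention's variable-length kernels) and only re-inserted at the output, so no compute is spent on padding.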