sdadas committed
Commit c03cbed · verified · 1 Parent(s): 8ef3829

Upload modeling_modernbert.py

Files changed (1): modeling_modernbert.py (new file, +152 lines)
from typing import Unpack
import torch
from transformers import (
    DataCollatorWithFlattening,
    ModernBertModel,
    ModernBertConfig,
    ModernBertForMaskedLM,
    ModernBertForSequenceClassification,
    ModernBertForTokenClassification,
    ModernBertForQuestionAnswering,
    ModernBertForMultipleChoice,
)
from transformers.masking_utils import create_bidirectional_mask, create_bidirectional_sliding_window_mask
from transformers.modeling_outputs import BaseModelOutput
from transformers.utils import TransformersKwargs

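# Descriptive comment (added): flattens a padded batch into a single packed
# sequence via DataCollatorWithFlattening, which also emits the FlashAttention
# kwargs (cu_seq_lens_q/k, max_length_q/k) consumed by the attention kernels.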
def _unpad_input(input_ids: torch.Tensor, attention_mask: torch.Tensor):
    collator = DataCollatorWithFlattening(return_flash_attn_kwargs=True)
    features = collator([{"input_ids": i[a.bool()].tolist()} for i, a in zip(input_ids, attention_mask)])
    return features

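# Descriptive comment (added): scatters the packed (unpadded) outputs back into
# a padded (batch, seqlen, ...) tensor, using the flat indices of the
# non-padding tokens.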
def _pad_output(inputs: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int) -> torch.Tensor:
    if inputs.dim() == 3:
        inputs = inputs.squeeze()
    if inputs.dim() == 1:
        output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device)
        output[indices] = inputs
        padded_inputs = output.view(batch, seqlen)
    else:
        _, *rest = inputs.shape
        output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device)
        output[indices] = inputs
        padded_inputs = output.view(batch, seqlen, *rest)
    return padded_inputs

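# Descriptive comment (added): ModernBertModel variant that unpads (packs) the
# batch before the encoder whenever a FlashAttention implementation is active,
# and re-pads the final hidden states before returning them.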
class UnpadModernBertModel(ModernBertModel):

    def __init__(self, config: ModernBertConfig):
        super().__init__(config)

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        seq_len = inputs_embeds.shape[1] if inputs_embeds is not None else input_ids.shape[1]
        batch_size = inputs_embeds.shape[0] if inputs_embeds is not None else input_ids.shape[0]
        device = input_ids.device if input_ids is not None else inputs_embeds.device

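        # Unpadding only applies under a FlashAttention implementation: record the
        # flat indices of non-padding tokens, pack all sequences into a single row,
        # and pass cumulative sequence lengths to the kernels in place of a mask.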
        indices = None
        if self.config._attn_implementation.startswith("flash_attention"):
            if input_ids is None or attention_mask is None:
                raise ValueError("Unpadding requires both input_ids and attention_mask")
            with torch.no_grad():
                indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
                features = _unpad_input(input_ids, attention_mask)
                input_ids = features["input_ids"].to(device=device)
                position_ids = features["position_ids"].to(device=device)
                attention_mask = None
                kwargs["cu_seq_lens_k"] = features["cu_seq_lens_k"].to(device=device)
                kwargs["cu_seq_lens_q"] = features["cu_seq_lens_q"].to(device=device)
                kwargs["max_length_k"] = features["max_length_k"]
                kwargs["max_length_q"] = features["max_length_q"]

        if position_ids is None:
            position_ids = torch.arange(seq_len, device=device).unsqueeze(0)

        hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds)

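        # Unless the caller already passed a prebuilt mask mapping, create one
        # bidirectional mask per attention type (ModernBERT alternates full and
        # sliding-window layers). After unpadding, attention_mask is None and the
        # cu_seq_lens kwargs stand in for it.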
        if not isinstance(attention_mask_mapping := attention_mask, dict):
            mask_kwargs = {
                "config": self.config,
                "inputs_embeds": hidden_states,
                "attention_mask": attention_mask,
            }
            attention_mask_mapping = {
                "full_attention": create_bidirectional_mask(**mask_kwargs),
                "sliding_attention": create_bidirectional_sliding_window_mask(**mask_kwargs),
            }

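        # Precompute rotary embeddings per layer type, then run the encoder stack;
        # each layer picks the mask and rotary cache matching its attention type.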
        position_embeddings = {}
        for layer_type in self.config.layer_types:
            position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)

        for encoder_layer in self.layers:
            hidden_states = encoder_layer(
                hidden_states,
                attention_mask=attention_mask_mapping[encoder_layer.attention_type],
                position_embeddings=position_embeddings[encoder_layer.attention_type],
                **kwargs,
            )

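        # Apply the final norm, then scatter the packed tokens back into the
        # padded (batch, seqlen, hidden) layout expected by callers.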
        hidden_states = self.final_norm(hidden_states)
        if self.config._attn_implementation.startswith("flash_attention"):
            hidden_states = _pad_output(
                inputs=hidden_states, indices=indices, batch=batch_size, seqlen=seq_len
            )

        return BaseModelOutput(last_hidden_state=hidden_states)

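# Descriptive comment (added): the task heads below reuse the stock ModernBERT
# heads and only swap the backbone for the unpadding variant; post_init()
# re-runs weight initialization for the freshly built model.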
class UnpadModernBertForMaskedLM(ModernBertForMaskedLM):

    def __init__(self, config):
        super().__init__(config)
        self.model = UnpadModernBertModel(config)
        self.post_init()


class UnpadModernBertForSequenceClassification(ModernBertForSequenceClassification):

    def __init__(self, config):
        super().__init__(config)
        self.model = UnpadModernBertModel(config)
        self.post_init()


class UnpadModernBertForTokenClassification(ModernBertForTokenClassification):

    def __init__(self, config):
        super().__init__(config)
        self.model = UnpadModernBertModel(config)
        self.post_init()


class UnpadModernBertForQuestionAnswering(ModernBertForQuestionAnswering):

    def __init__(self, config):
        super().__init__(config)
        self.model = UnpadModernBertModel(config)
        self.post_init()


class UnpadModernBertForMultipleChoice(ModernBertForMultipleChoice):

    def __init__(self, config):
        super().__init__(config)
        self.model = UnpadModernBertModel(config)
        self.post_init()

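# Descriptive comment (added): monkey-patch helper. After calling this, every
# ModernBertModel instance, including the backbone inside already-constructed
# task heads, uses the unpadded forward pass.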
def enable_modernbert_unpadding():
    ModernBertModel.forward = UnpadModernBertModel.forward
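
A minimal usage sketch, not part of the uploaded file: it assumes this module is importable as modeling_modernbert, that flash-attn and a CUDA device are available, and it uses answerdotai/ModernBERT-base purely as an example checkpoint.

# Illustrative only: checkpoint name, device, and dtype are assumptions,
# not requirements of the uploaded module.
import torch
from transformers import AutoTokenizer, ModernBertModel
from modeling_modernbert import enable_modernbert_unpadding

enable_modernbert_unpadding()  # patch ModernBertModel.forward globally

tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
model = ModernBertModel.from_pretrained(
    "answerdotai/ModernBERT-base",
    attn_implementation="flash_attention_2",  # unpadding engages only on flash_attention*
    torch_dtype=torch.float16,
).to("cuda").eval()

texts = ["short text", "a somewhat longer piece of text"]
batch = tokenizer(texts, padding=True, return_tensors="pt")
with torch.no_grad():
    out = model(
        input_ids=batch["input_ids"].to("cuda"),
        attention_mask=batch["attention_mask"].to("cuda"),
    )
print(out.last_hidden_state.shape)  # (2, padded_seq_len, hidden_size), re-padded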