# unpad-impl / modeling_modernbert.py
from typing import Unpack
import torch
from transformers import (
DataCollatorWithFlattening,
ModernBertModel,
ModernBertConfig,
ModernBertForMaskedLM,
ModernBertForSequenceClassification,
ModernBertForTokenClassification,
ModernBertForQuestionAnswering,
ModernBertForMultipleChoice
)
from transformers.masking_utils import create_bidirectional_mask, create_bidirectional_sliding_window_mask
from transformers.modeling_outputs import BaseModelOutput
from transformers.utils import TransformersKwargs


def _unpad_input(input_ids: torch.Tensor, attention_mask: torch.Tensor):
    # Flatten a padded batch into one packed sequence. With
    # return_flash_attn_kwargs=True, the collator also emits the cu_seq_lens
    # and max_length kwargs that FlashAttention's varlen kernels expect.
    collator = DataCollatorWithFlattening(return_flash_attn_kwargs=True)
    features = collator([{"input_ids": i[a.bool()].tolist()} for i, a in zip(input_ids, attention_mask)])
    return features
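
# Illustrative (not part of the original file): for a padded batch like
#   input_ids      = [[5, 6, 7, PAD], [8, 9, PAD, PAD]]
#   attention_mask = [[1, 1, 1, 0],   [1, 1, 0, 0]]
# the collator packs the real tokens into a single row and returns the
# FlashAttention varlen metadata, roughly:
#   features["input_ids"]     -> [5, 6, 7, 8, 9]   (one packed sequence)
#   features["position_ids"]  -> [0, 1, 2, 0, 1]   (restart per sequence)
#   features["cu_seq_lens_q"] -> [0, 3, 5]         (same as cu_seq_lens_k)
#   features["max_length_q"]  -> 3                 (same as max_length_k)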


def _pad_output(inputs: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int) -> torch.Tensor:
    # Inverse of _unpad_input: scatter the packed tokens back into a zero
    # tensor of shape (batch, seqlen, ...) at the positions given by `indices`.
    if inputs.dim() == 3:
        # Drop the dummy batch dimension of the packed (1, total_tokens, hidden)
        # tensor; squeeze(0) rather than squeeze() so hidden dims of size 1 are
        # never collapsed by accident.
        inputs = inputs.squeeze(0)
    if inputs.dim() == 1:
        output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device)
        output[indices] = inputs
        padded_inputs = output.view(batch, seqlen)
    else:
        _, *rest = inputs.shape
        output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device)
        output[indices] = inputs
        padded_inputs = output.view(batch, seqlen, *rest)
    return padded_inputs
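
# Illustrative round trip (not part of the original file): with the
# attention_mask [[1, 1, 1, 0], [1, 1, 0, 0]] above, indices = [0, 1, 2, 4, 5];
# a packed (5, hidden) tensor is scattered into a zero-initialized
# (2 * 4, hidden) buffer and reshaped to (2, 4, hidden), leaving zeros at the
# padding positions. This is the inverse of _unpad_input.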


class UnpadModernBertModel(ModernBertModel):
    def __init__(self, config: ModernBertConfig):
        super().__init__(config)

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
        seq_len = inputs_embeds.shape[1] if inputs_embeds is not None else input_ids.shape[1]
        batch_size = inputs_embeds.shape[0] if inputs_embeds is not None else input_ids.shape[0]
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        indices = None
        if self.config._attn_implementation.startswith("flash_attention"):
            if input_ids is None or attention_mask is None:
                raise ValueError("Unpadding requires both input_ids and attention_mask")
            with torch.no_grad():
                # Flat indices of the non-padding tokens; used to re-pad the output.
                indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
            features = _unpad_input(input_ids, attention_mask)
            input_ids = features["input_ids"].to(device=device)
            position_ids = features["position_ids"].to(device=device)
            # The packed sequence contains no padding, so no mask is needed;
            # the varlen kwargs below delimit the individual sequences instead.
            attention_mask = None
            kwargs["cu_seq_lens_k"] = features["cu_seq_lens_k"].to(device=device)
            kwargs["cu_seq_lens_q"] = features["cu_seq_lens_q"].to(device=device)
            kwargs["max_length_k"] = features["max_length_k"]
            kwargs["max_length_q"] = features["max_length_q"]
        if position_ids is None:
            position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
        hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds)
        if not isinstance(attention_mask_mapping := attention_mask, dict):
            mask_kwargs = {
                "config": self.config,
                "inputs_embeds": hidden_states,
                "attention_mask": attention_mask,
            }
            attention_mask_mapping = {
                "full_attention": create_bidirectional_mask(**mask_kwargs),
                "sliding_attention": create_bidirectional_sliding_window_mask(**mask_kwargs),
            }
        position_embeddings = {}
        for layer_type in self.config.layer_types:
            position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(
                hidden_states,
                attention_mask=attention_mask_mapping[encoder_layer.attention_type],
                position_embeddings=position_embeddings[encoder_layer.attention_type],
                **kwargs,
            )
        hidden_states = self.final_norm(hidden_states)
        if self.config._attn_implementation.startswith("flash_attention"):
            # Scatter the packed hidden states back to (batch, seq_len, hidden).
            hidden_states = _pad_output(
                inputs=hidden_states, indices=indices, batch=batch_size, seqlen=seq_len
            )
        return BaseModelOutput(last_hidden_state=hidden_states)


class UnpadModernBertForMaskedLM(ModernBertForMaskedLM):
    def __init__(self, config):
        super().__init__(config)
        self.model = UnpadModernBertModel(config)
        self.post_init()


class UnpadModernBertForSequenceClassification(ModernBertForSequenceClassification):
    def __init__(self, config):
        super().__init__(config)
        self.model = UnpadModernBertModel(config)
        self.post_init()


class UnpadModernBertForTokenClassification(ModernBertForTokenClassification):
    def __init__(self, config):
        super().__init__(config)
        self.model = UnpadModernBertModel(config)
        self.post_init()


class UnpadModernBertForQuestionAnswering(ModernBertForQuestionAnswering):
    def __init__(self, config):
        super().__init__(config)
        self.model = UnpadModernBertModel(config)
        self.post_init()


class UnpadModernBertForMultipleChoice(ModernBertForMultipleChoice):
    def __init__(self, config):
        super().__init__(config)
        self.model = UnpadModernBertModel(config)
        self.post_init()


def enable_modernbert_unpadding():
    # Monkey-patch the stock ModernBertModel so that every ModernBERT model,
    # including heads loaded through from_pretrained, uses the unpadded forward.
    ModernBertModel.forward = UnpadModernBertModel.forward
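

# Usage sketch (illustrative, not part of the original file). It assumes the
# flash-attn package is installed and uses the public
# "answerdotai/ModernBERT-base" checkpoint as an example; any ModernBERT
# checkpoint should work the same way.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    enable_modernbert_unpadding()
    tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
    model = ModernBertModel.from_pretrained(
        "answerdotai/ModernBERT-base",
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
    ).to("cuda")
    batch = tokenizer(
        ["short text", "a somewhat longer example text"],
        padding=True,
        return_tensors="pt",
    ).to("cuda")
    with torch.inference_mode():
        out = model(**batch)
    # The forward re-pads its output, so the shape matches the padded input:
    print(out.last_hidden_state.shape)  # (2, seq_len, hidden_size)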