ketanmore's picture
Upload folder using huggingface_hub
2720487 verified
import copy
from typing import Optional, List, Union, Tuple
from transformers import MBartForCausalLM, MBartConfig
from torch import nn
from transformers.activations import ACT2FN
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_attention_mask
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, BaseModelOutputWithPastAndCrossAttentions
from transformers.models.mbart.modeling_mbart import MBartPreTrainedModel, MBartDecoder, MBartLearnedPositionalEmbedding, MBartDecoderLayer
from surya.model.ordering.config import MBartOrderConfig
import torch
import math
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
From llama
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
class MBartGQAttention(nn.Module):
def __init__(
self,
embed_dim: int,
num_heads: int,
num_kv_heads: int,
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
config: Optional[MBartConfig] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.num_kv_groups = self.num_heads // self.num_kv_heads
assert self.num_heads % self.num_kv_heads == 0, f"num_heads ({self.num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads})"
assert embed_dim % self.num_kv_heads == 0, f"embed_dim ({self.embed_dim}) must be divisible by num_kv_heads ({self.num_kv_heads})"
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.config = config
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {num_heads})."
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal
self.k_proj = nn.Linear(embed_dim, self.num_kv_heads * self.head_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, self.num_kv_heads * self.head_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def _shape_key_value(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()
# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self._shape_key_value(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape_key_value(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape_key_value(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape_key_value(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self._shape_key_value(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape_key_value(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
# Expand kv heads, then match query shape
key_states = repeat_kv(key_states, self.num_kv_groups)
value_states = repeat_kv(value_states, self.num_kv_groups)
key_states = key_states.reshape(*proj_shape)
value_states = value_states.reshape(*proj_shape)
src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if output_attentions:
# this operation is a bit awkward, but it's required to
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights have to be reshaped
# twice and have to be reused in the following
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
else:
attn_weights_reshaped = None
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.bmm(attn_probs, value_states)
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights_reshaped, past_key_value
MBART_ATTENTION_CLASSES = {
"eager": MBartGQAttention,
"flash_attention_2": None
}
class MBartOrderDecoderLayer(MBartDecoderLayer):
def __init__(self, config: MBartConfig):
nn.Module.__init__(self)
self.embed_dim = config.d_model
self.self_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
num_kv_heads=config.kv_heads,
dropout=config.attention_dropout,
is_decoder=True,
is_causal=True,
config=config,
)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
self.embed_dim,
config.decoder_attention_heads,
num_kv_heads=config.kv_heads,
dropout=config.attention_dropout,
is_decoder=True,
config=config,
)
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
class BboxEmbedding(nn.Module):
def __init__(self, config):
super().__init__()
self.x1_embed = nn.Embedding(config.max_width, config.d_model)
self.y1_embed = nn.Embedding(config.max_height, config.d_model)
self.x2_embed = nn.Embedding(config.max_width, config.d_model)
self.y2_embed = nn.Embedding(config.max_height, config.d_model)
self.w_embed = nn.Embedding(config.max_width, config.d_model)
self.h_embed = nn.Embedding(config.max_height, config.d_model)
self.cx_embed = nn.Embedding(config.max_width, config.d_model)
self.cy_embed = nn.Embedding(config.max_height, config.d_model)
self.box_pos_embed = nn.Embedding(config.max_position_embeddings, config.d_model)
def forward(self, boxes: torch.LongTensor, input_box_counts: torch.LongTensor, past_key_values_length: int):
x1, y1, x2, y2 = boxes.unbind(dim=-1)
# Shape is (batch_size, num_boxes/seq len, d_model)
w = x2 - x1
h = y2 - y1
# Center x and y in torch long tensors
cx = (x1 + x2) / 2
cy = (y1 + y2) / 2
cx = cx.long()
cy = cy.long()
coord_embeds = self.x1_embed(x1) + self.y1_embed(y1) + self.x2_embed(x2) + self.y2_embed(y2)
embedded = coord_embeds + self.w_embed(w) + self.h_embed(h) + self.cx_embed(cx) + self.cy_embed(cy)
# Add in positional embeddings for the boxes
if past_key_values_length == 0:
for j in range(embedded.shape[0]):
box_start = input_box_counts[j, 0]
box_end = input_box_counts[j, 1] - 1 # Skip the sep token
box_count = box_end - box_start
embedded[j, box_start:box_end] = embedded[j, box_start:box_end] + self.box_pos_embed.weight[:box_count]
return embedded
class MBartOrderDecoder(MBartDecoder):
def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = None):
MBartPreTrainedModel.__init__(self, config)
self.dropout = config.dropout
self.layerdrop = config.decoder_layerdrop
self.padding_idx = config.pad_token_id
self.max_target_positions = config.max_position_embeddings
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
self.embed_tokens = BboxEmbedding(config) if embed_tokens is None else embed_tokens
if embed_tokens is not None:
self.embed_tokens.weight = embed_tokens.weight
self.embed_positions = MBartLearnedPositionalEmbedding(
config.max_position_embeddings,
config.d_model,
)
# Language-specific MoE goes at second and second-to-last layer
self.layers = nn.ModuleList([MBartOrderDecoderLayer(config) for _ in range(config.decoder_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.layernorm_embedding = nn.LayerNorm(config.d_model)
self.layer_norm = nn.LayerNorm(config.d_model)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def forward(
self,
input_boxes: torch.LongTensor = None,
input_boxes_mask: Optional[torch.Tensor] = None,
input_boxes_counts: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_boxes is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_boxes is not None:
input = input_boxes
input_shape = input_boxes.size()[:-1] # Shape (batch_size, num_boxes)
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
input = inputs_embeds[:, :, -1]
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_boxes, input_boxes_counts, past_key_values_length) * self.embed_scale
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = input_boxes_mask if (input_boxes_mask is not None and 0 in input_boxes_mask) else None
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
input_boxes_mask, input_shape, inputs_embeds, past_key_values_length
)
if past_key_values_length == 0:
box_ends = input_boxes_counts[:, 1]
box_starts = input_boxes_counts[:, 0]
input_shape_arranged = torch.arange(input_shape[1], device=attention_mask.device)[None, :]
# Enable all boxes to attend to each other (before the sep token)
# Ensure that the boxes are not attending to the padding tokens
boxes_end_mask = input_shape_arranged < box_ends[:, None]
boxes_start_mask = input_shape_arranged >= box_starts[:, None]
boxes_mask = boxes_end_mask & boxes_start_mask
boxes_mask = boxes_mask.unsqueeze(1).unsqueeze(1) # Enable proper broadcasting
attention_mask = attention_mask.masked_fill(boxes_mask, 0)
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
if self._use_flash_attention_2:
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask(
encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
)
# embed positions
positions = self.embed_positions(input, past_key_values_length)
hidden_states = inputs_embeds + positions.to(inputs_embeds.device)
hidden_states = self.layernorm_embedding(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
if self.gradient_checkpointing and self.training:
if use_cache:
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
next_decoder_cache = () if use_cache else None
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
if attn_mask is not None:
if attn_mask.size()[0] != len(self.layers):
raise ValueError(
f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
f" {attn_mask.size()[0]}."
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.training:
dropout_probability = torch.rand([])
if dropout_probability < self.layerdrop:
continue
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
if encoder_hidden_states is not None:
all_cross_attentions += (layer_outputs[2],)
hidden_states = self.layer_norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
cross_attentions=all_cross_attentions,
)
class MBartOrderDecoderWrapper(MBartPreTrainedModel):
"""
This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
used in combination with the [`EncoderDecoderModel`] framework.
"""
def __init__(self, config):
super().__init__(config)
self.decoder = MBartOrderDecoder(config)
def forward(self, *args, **kwargs):
return self.decoder(*args, **kwargs)
class MBartOrder(MBartForCausalLM):
config_class = MBartOrderConfig
_tied_weights_keys = []
def __init__(self, config, **kwargs):
config = copy.deepcopy(config)
config.is_decoder = True
config.is_encoder_decoder = False
MBartPreTrainedModel.__init__(self, config)
self.model = MBartOrderDecoderWrapper(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def forward(
self,
input_boxes: torch.LongTensor = None,
input_boxes_mask: Optional[torch.Tensor] = None,
input_boxes_counts: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model.decoder(
input_boxes=input_boxes,
input_boxes_mask=input_boxes_mask,
input_boxes_counts=input_boxes_counts,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
logits = self.lm_head(outputs[0])
loss = None
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithCrossAttentions(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)