|
import copy |
|
from typing import Optional, List, Union, Tuple |
|
|
|
from transformers import MBartForCausalLM, MBartConfig |
|
from torch import nn |
|
from transformers.activations import ACT2FN |
|
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_attention_mask |
|
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, BaseModelOutputWithPastAndCrossAttentions |
|
from transformers.models.mbart.modeling_mbart import MBartPreTrainedModel, MBartDecoder, MBartLearnedPositionalEmbedding, MBartDecoderLayer |
|
from surya.model.ordering.config import MBartOrderConfig |
|
import torch |
|
import math |
|
|
|
|
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head ``n_rep`` times along the head axis (adapted from Llama).

    Equivalent to ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``: an input of
    shape ``(batch, num_key_value_heads, seqlen, head_dim)`` becomes
    ``(batch, num_key_value_heads * n_rep, seqlen, head_dim)``. When ``n_rep == 1`` the input
    tensor object itself is returned unchanged.
    """
    # Unpacking happens before the early return, so a non-4D input is rejected even for n_rep == 1.
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Add a repeat axis right after the head axis, broadcast it to n_rep with expand(),
    # then merge (kv_heads, n_rep) into one head axis so copies of a head are adjacent.
    repeated = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return repeated.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
|
|
|
|
|
class MBartGQAttention(nn.Module):
    """Multi-head attention with grouped key/value heads (grouped-query attention).

    Queries use ``num_heads`` heads, while keys and values are projected to only
    ``num_kv_heads`` heads. Each key/value head is shared by
    ``num_heads // num_kv_heads`` consecutive query heads (expanded with ``repeat_kv``).
    The remaining computation mirrors MBart's eager attention: pre-scaled queries,
    an additive attention mask, softmax, an optional multiplicative per-head mask,
    dropout on the probabilities, and an output projection.

    Args:
        embed_dim: Input/output feature size. Must be divisible by ``num_heads``.
        num_heads: Number of query heads.
        num_kv_heads: Number of key/value heads. Must be positive and divide ``num_heads``.
        dropout: Dropout probability applied to the attention probabilities in training.
        is_decoder: When true, ``forward`` returns the (key, value) states for caching.
        bias: Whether the q/k/v/out projections have a bias term.
        is_causal: Stored on the instance; this class does not read it itself.
        config: Optional model config; stored on the instance only.

    Raises:
        ValueError: If the head counts are inconsistent with each other or with ``embed_dim``.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        num_kv_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional["MBartConfig"] = None,
    ):
        super().__init__()
        # Validate with explicit exceptions rather than `assert` (which `python -O` strips),
        # and before the integer division below so num_kv_heads == 0 is reported clearly.
        if num_kv_heads <= 0:
            raise ValueError(f"num_kv_heads must be a positive integer (got {num_kv_heads})")
        if num_heads % num_kv_heads != 0:
            raise ValueError(f"num_heads ({num_heads}) must be divisible by num_kv_heads ({num_kv_heads})")
        if embed_dim % num_kv_heads != 0:
            raise ValueError(f"embed_dim ({embed_dim}) must be divisible by num_kv_heads ({num_kv_heads})")

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        # Number of query heads that share each key/value head.
        self.num_kv_groups = self.num_heads // self.num_kv_heads

        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal

        # Keys/values are projected to num_kv_heads heads only; queries and output keep full width.
        self.k_proj = nn.Linear(embed_dim, self.num_kv_heads * self.head_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, self.num_kv_heads * self.head_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Split query features into heads: (bsz, seq, embed_dim) -> (bsz, num_heads, seq, head_dim)."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def _shape_key_value(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Split key/value features into heads: -> (bsz, num_kv_heads, seq, head_dim)."""
        return tensor.view(bsz, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel

        Args:
            hidden_states: Query-side input of shape ``(bsz, tgt_len, embed_dim)``.
            key_value_states: Source states for cross-attention, ``(bsz, src_len, embed_dim)``;
                ``None`` means self-attention.
            past_key_value: Cached ``(key, value)``, each ``(bsz, num_kv_heads, cached_len, head_dim)``.
            attention_mask: Additive mask of shape ``(bsz, 1, tgt_len, src_len)``, added to the
                scores before softmax.
            layer_head_mask: Multiplicative per-head mask of shape ``(num_heads,)``.
            output_attentions: Also return the attention probabilities.

        Returns:
            ``(attn_output, attn_weights, past_key_value)``. ``attn_output`` is
            ``(bsz, tgt_len, embed_dim)``. ``attn_weights`` is ``(bsz, num_heads, tgt_len, src_len)``
            when ``output_attentions``, else ``None``. ``past_key_value`` is the new
            ``(key, value)`` cache when ``is_decoder``; otherwise it is the argument passed in.

        Raises:
            ValueError: If an intermediate tensor or a mask has an unexpected shape.
        """
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # Queries are pre-scaled by 1/sqrt(head_dim) instead of scaling the scores.
        query_states = self.q_proj(hidden_states) * self.scaling

        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # Cross-attention with a cache matching the source length: reuse the cached keys/values.
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # Cross-attention without a usable cache: project the source states.
            key_states = self._shape_key_value(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape_key_value(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # Incremental self-attention: project the new positions and append them to the cache.
            key_states = self._shape_key_value(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape_key_value(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # Self-attention without a cache.
            key_states = self._shape_key_value(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape_key_value(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # The cache is taken before repeat_kv, so it holds only num_kv_heads heads.
            past_key_value = (key_states, value_states)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)

        # Expand each KV head to the query heads sharing it, then fold heads into the batch dim.
        key_states = repeat_kv(key_states, self.num_kv_groups)
        value_states = repeat_kv(value_states, self.num_kv_groups)
        key_states = key_states.reshape(*proj_shape)
        value_states = value_states.reshape(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            # Additive mask, broadcast over the head dimension.
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # Keep a per-head view for the caller and continue with the flattened one.
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # Merge heads back: (bsz*heads, tgt, head_dim) -> (bsz, tgt, embed_dim).
        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value
|
|
|
|
|
# Lookup from ``config._attn_implementation`` to the attention module class.
# NOTE(review): "flash_attention_2" maps to None, so selecting it would make the decoder
# layer try to call None; only "eager" is usable as written.
MBART_ATTENTION_CLASSES = dict(
    eager=MBartGQAttention,
    flash_attention_2=None,
)
|
|
|
|
|
class MBartOrderDecoderLayer(MBartDecoderLayer):
    """MBart decoder layer whose self- and cross-attention come from ``MBART_ATTENTION_CLASSES``.

    Only ``__init__`` is overridden here; ``forward`` is inherited from ``MBartDecoderLayer``.
    The attention modules are built with ``config.kv_heads`` key/value heads.
    """

    def __init__(self, config: MBartConfig):
        # Call nn.Module.__init__ directly so the parent's own attention modules are not built.
        nn.Module.__init__(self)
        width = config.d_model
        self.embed_dim = width
        attention_cls = MBART_ATTENTION_CLASSES[config._attn_implementation]

        # Submodules are created in the original order so weight initialisation draws
        # from the RNG in the same sequence.
        self.self_attn = attention_cls(
            embed_dim=width,
            num_heads=config.decoder_attention_heads,
            num_kv_heads=config.kv_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            is_causal=True,
            config=config,
        )
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout

        self.self_attn_layer_norm = nn.LayerNorm(width)
        # Cross-attention over encoder states (not flagged causal).
        self.encoder_attn = attention_cls(
            width,
            config.decoder_attention_heads,
            num_kv_heads=config.kv_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            config=config,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(width)

        # Position-wise feed-forward block.
        self.fc1 = nn.Linear(width, config.decoder_ffn_dim)
        self.fc2 = nn.Linear(config.decoder_ffn_dim, width)
        self.final_layer_norm = nn.LayerNorm(width)
|
|
|
|
|
class BboxEmbedding(nn.Module):
    """Embed integer boxes ``(x1, y1, x2, y2)`` into ``config.d_model``-dimensional vectors.

    Each box is the sum of eight learned lookups: the four coordinates, the width and
    height, and the centre point (midpoint truncated to an integer). On the first,
    uncached pass (``past_key_values_length == 0``) a learned per-box position embedding
    is also added over each sequence's box span.
    """

    def __init__(self, config):
        super().__init__()
        width, height, dim = config.max_width, config.max_height, config.d_model
        # Horizontal quantities index tables of size max_width, vertical ones max_height.
        # Creation order is kept fixed so random initialisation is reproducible.
        self.x1_embed = nn.Embedding(width, dim)
        self.y1_embed = nn.Embedding(height, dim)
        self.x2_embed = nn.Embedding(width, dim)
        self.y2_embed = nn.Embedding(height, dim)
        self.w_embed = nn.Embedding(width, dim)
        self.h_embed = nn.Embedding(height, dim)
        self.cx_embed = nn.Embedding(width, dim)
        self.cy_embed = nn.Embedding(height, dim)
        self.box_pos_embed = nn.Embedding(config.max_position_embeddings, dim)

    def forward(self, boxes: torch.LongTensor, input_box_counts: torch.LongTensor, past_key_values_length: int):
        """Return the summed embeddings for ``boxes``.

        Args:
            boxes: Integer box tensor whose last dimension holds ``(x1, y1, x2, y2)``;
                presumably ``(bsz, seq, 4)`` — TODO confirm with callers.
            input_box_counts: Per-sample ``[start, end]`` pair; box position embeddings are
                added to positions ``start .. end - 2`` (the last slot before ``end`` is skipped).
            past_key_values_length: Cached length; position embeddings are only added when 0.

        Returns:
            Tensor of shape ``boxes.shape[:-1] + (d_model,)``.
        """
        left, top, right, bottom = boxes.unbind(dim=-1)
        box_w = right - left
        box_h = bottom - top
        # True division followed by .long() truncates the midpoint toward zero.
        center_x = ((left + right) / 2).long()
        center_y = ((top + bottom) / 2).long()

        corner_sum = self.x1_embed(left) + self.y1_embed(top) + self.x2_embed(right) + self.y2_embed(bottom)
        embedded = corner_sum + self.w_embed(box_w) + self.h_embed(box_h) + self.cx_embed(center_x) + self.cy_embed(center_y)

        if past_key_values_length != 0:
            return embedded

        for row in range(embedded.shape[0]):
            first = input_box_counts[row, 0]
            # Exclusive stop one before `end`: the final slot of the span gets no box position
            # (presumably a separator token — NOTE(review): confirm against the processor).
            stop = input_box_counts[row, 1] - 1
            n_boxes = stop - first
            span = slice(first, stop)
            embedded[row, span] = embedded[row, span] + self.box_pos_embed.weight[:n_boxes]

        return embedded
|
|
|
|
|
class MBartOrderDecoder(MBartDecoder):
    """MBart decoder whose input tokens are bounding boxes instead of token ids.

    Boxes are embedded with ``BboxEmbedding`` (unless ``embed_tokens`` is supplied), combined
    with learned positional embeddings, and run through ``config.decoder_layers`` instances of
    ``MBartOrderDecoderLayer``. On the first, uncached pass the causal self-attention mask is
    opened so every query position may attend to all key positions inside each sample's box
    span ``[input_boxes_counts[:, 0], input_boxes_counts[:, 1])``.

    Args:
        config: Decoder configuration.
        embed_tokens: Optional module to use instead of the default ``BboxEmbedding``.
    """

    def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = None):
        # Deliberately bypasses MBartDecoder.__init__; all submodules are built below.
        MBartPreTrainedModel.__init__(self, config)
        self.dropout = config.dropout
        self.layerdrop = config.decoder_layerdrop
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0

        self.embed_tokens = BboxEmbedding(config) if embed_tokens is None else embed_tokens

        # NOTE(review): in this branch self.embed_tokens *is* embed_tokens, so this re-binds
        # the same parameter and is effectively a no-op.
        if embed_tokens is not None:
            self.embed_tokens.weight = embed_tokens.weight

        self.embed_positions = MBartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            config.d_model,
        )

        self.layers = nn.ModuleList([MBartOrderDecoderLayer(config) for _ in range(config.decoder_layers)])
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        self.layernorm_embedding = nn.LayerNorm(config.d_model)
        self.layer_norm = nn.LayerNorm(config.d_model)

        self.gradient_checkpointing = False
        # Initialise weights and apply final processing.
        self.post_init()

    def forward(
        self,
        input_boxes: torch.LongTensor = None,
        input_boxes_mask: Optional[torch.Tensor] = None,
        input_boxes_counts: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
        """Decode a batch of boxes.

        Args:
            input_boxes: Integer boxes; the last dimension is the box (``(x1, y1, x2, y2)`` for
                the default ``BboxEmbedding``), the leading dims are presumably ``(bsz, seq)``.
            input_boxes_mask: Padding mask passed to ``_prepare_4d_causal_attention_mask``;
                presumably ``(bsz, seq)`` with 1 for real positions — TODO confirm.
            input_boxes_counts: Per-sample ``[start, end]`` of the box span, indexed as
                ``[:, 0]`` / ``[:, 1]``. Required on the first (uncached) pass.
            encoder_hidden_states / encoder_attention_mask: Source states and their 2-D mask
                for cross-attention.
            head_mask / cross_attn_head_mask: Per-layer head masks; first dim must equal the
                number of layers.
            past_key_values: Per-layer caches; the cached length is read from
                ``past_key_values[0][0].shape[2]``.
            inputs_embeds: Pre-computed input embeddings used instead of ``input_boxes``.
            use_cache / output_attentions / output_hidden_states / return_dict: Default to
                the matching config values when ``None``.

        Returns:
            ``BaseModelOutputWithPastAndCrossAttentions`` when ``return_dict``, otherwise a
            tuple of the non-``None`` items among (hidden states, cache, all hidden states,
            self-attentions, cross-attentions).

        Raises:
            ValueError: If both or neither of ``input_boxes``/``inputs_embeds`` are given, or
                a head mask does not cover every layer.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Exactly one of input_boxes / inputs_embeds must be provided.
        # NOTE: `input` shadows the builtin; it is only used for the positional embedding below.
        if input_boxes is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_boxes is not None:
            input = input_boxes
            input_shape = input_boxes.size()[:-1]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            input = inputs_embeds[:, :, -1]
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        # Length already held in the self-attention cache (0 on the first pass).
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_boxes, input_boxes_counts, past_key_values_length) * self.embed_scale

        if self._use_flash_attention_2:
            # 2-D mask is kept only if it actually masks something.
            attention_mask = input_boxes_mask if (input_boxes_mask is not None and 0 in input_boxes_mask) else None
        else:
            # Additive 4-D causal mask combined with the padding mask.
            attention_mask = _prepare_4d_causal_attention_mask(
                input_boxes_mask, input_shape, inputs_embeds, past_key_values_length
            )

        if past_key_values_length == 0:
            # Un-mask (set to 0) every key column inside [start, end) of each sample, for all
            # query rows, so the boxes attend to each other bidirectionally.
            # NOTE(review): assumes a tensor mask; the flash-attention branch above can yield None.
            box_ends = input_boxes_counts[:, 1]
            box_starts = input_boxes_counts[:, 0]
            input_shape_arranged = torch.arange(input_shape[1], device=attention_mask.device)[None, :]
            boxes_end_mask = input_shape_arranged < box_ends[:, None]
            boxes_start_mask = input_shape_arranged >= box_starts[:, None]
            boxes_mask = boxes_end_mask & boxes_start_mask
            # (bsz, seq) -> (bsz, 1, 1, seq): broadcast over heads and query positions.
            boxes_mask = boxes_mask.unsqueeze(1).unsqueeze(1)
            attention_mask = attention_mask.masked_fill(boxes_mask, 0)

        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            if self._use_flash_attention_2:
                encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
            else:
                # Expand the 2-D encoder mask to an additive (bsz, 1, tgt_len, src_len) mask.
                encoder_attention_mask = _prepare_4d_attention_mask(
                    encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
                )

        # Learned positions offset by the cached length.
        positions = self.embed_positions(input, past_key_values_length)

        hidden_states = inputs_embeds + positions.to(inputs_embeds.device)
        hidden_states = self.layernorm_embedding(hidden_states)

        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        # Caching is disabled while training with gradient checkpointing.
        if self.gradient_checkpointing and self.training:
            if use_cache:
                use_cache = False

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
        next_decoder_cache = () if use_cache else None

        # Head masks must provide one entry per layer.
        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
            if attn_mask is not None:
                if attn_mask.size()[0] != len(self.layers):
                    raise ValueError(
                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
                        f" {attn_mask.size()[0]}."
                    )
        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            # LayerDrop: randomly skip whole layers during training.
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:
                    continue

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                # Positional args follow the layer's forward order; no cache is passed.
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    head_mask[idx] if head_mask is not None else None,
                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
                    None,
                    output_attentions,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    cross_attn_layer_head_mask=(
                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
                    ),
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )
            hidden_states = layer_outputs[0]

            # Layer outputs are (hidden, [self_attn, cross_attn,] [cache]); the cache index
            # shifts by two when attentions are returned.
            if use_cache:
                next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        hidden_states = self.layer_norm(hidden_states)

        # Add the hidden state from after the final layer norm.
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )
|
|
|
|
|
class MBartOrderDecoderWrapper(MBartPreTrainedModel):
    """
    Thin holder that exposes an ``MBartOrderDecoder`` as ``self.decoder``.

    This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
    used in combination with the [`EncoderDecoderModel`] framework.
    """

    def __init__(self, config):
        super().__init__(config)
        order_decoder = MBartOrderDecoder(config)
        self.decoder = order_decoder

    def forward(self, *args, **kwargs):
        """Forward every positional and keyword argument unchanged to the wrapped decoder."""
        order_decoder = self.decoder
        return order_decoder(*args, **kwargs)
|
|
|
|
|
class MBartOrder(MBartForCausalLM):
    """Box-ordering model: an ``MBartOrderDecoder`` followed by a bias-free linear head.

    ``lm_head`` projects the decoder's last hidden state to ``config.vocab_size`` logits.
    No loss is computed: ``labels`` (and any extra ``**kwargs``) are accepted but ignored.
    """

    config_class = MBartOrderConfig
    _tied_weights_keys = []

    def __init__(self, config, **kwargs):
        # Work on a private copy so the caller's config object is not mutated.
        own_config = copy.deepcopy(config)
        own_config.is_decoder = True
        own_config.is_encoder_decoder = False
        # Skip MBartForCausalLM.__init__ so only the custom decoder below is constructed.
        MBartPreTrainedModel.__init__(self, own_config)
        self.model = MBartOrderDecoderWrapper(own_config)

        self.lm_head = nn.Linear(own_config.hidden_size, own_config.vocab_size, bias=False)

        # Initialise weights and apply final processing.
        self.post_init()

    def forward(
        self,
        input_boxes: torch.LongTensor = None,
        input_boxes_mask: Optional[torch.Tensor] = None,
        input_boxes_counts: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
        """Run the box decoder and project its last hidden state to logits.

        ``output_attentions``, ``output_hidden_states`` and ``return_dict`` fall back to the
        config when ``None``; ``use_cache`` is passed through for the decoder to resolve.

        Returns:
            ``CausalLMOutputWithCrossAttentions`` (with ``loss=None``) when ``return_dict``
            resolves truthy, otherwise ``(logits, *decoder_outputs[1:])``.
        """
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if return_dict is None:
            return_dict = cfg.use_return_dict

        decoder_outputs = self.model.decoder(
            input_boxes=input_boxes,
            input_boxes_mask=input_boxes_mask,
            input_boxes_counts=input_boxes_counts,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            head_mask=head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = self.lm_head(decoder_outputs[0])
        # `labels` is not used, so there is never a loss to report.
        loss = None

        if return_dict:
            return CausalLMOutputWithCrossAttentions(
                loss=loss,
                logits=logits,
                past_key_values=decoder_outputs.past_key_values,
                hidden_states=decoder_outputs.hidden_states,
                attentions=decoder_outputs.attentions,
                cross_attentions=decoder_outputs.cross_attentions,
            )

        tuple_output = (logits,) + decoder_outputs[1:]
        if loss is None:
            return tuple_output
        return (loss,) + tuple_output