|
"""LoRDCoder model class, based on GPT model. |
|
|
|
License: Apache-2.0 |
|
""" |
|
import math |
|
from typing import List, Optional, Tuple, Union |
|
|
|
import torch |
|
import torch.utils.checkpoint |
|
from torch import nn |
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss |
|
|
|
from transformers.activations import ACT2FN |
|
from transformers.modeling_outputs import ( |
|
BaseModelOutputWithPastAndCrossAttentions, |
|
CausalLMOutputWithCrossAttentions, |
|
SequenceClassifierOutputWithPast, |
|
TokenClassifierOutput, |
|
) |
|
from transformers.modeling_utils import PreTrainedModel |
|
from .configuration_lordcoder_v0 import LoRDCoderConfig |
|
|
|
|
|
|
|
|
|
@torch.jit.script
def upcast_masked_softmax(
    x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float, softmax_dtype: torch.dtype
):
    """Masked softmax over the last dimension, computed in ``softmax_dtype``.

    ``x`` is cast to ``softmax_dtype`` and multiplied by ``scale``; positions where ``mask`` is False are then
    replaced by ``mask_value`` (which is *not* scaled). The probabilities are cast back to ``x``'s dtype.
    """
    out_dtype = x.dtype
    logits = torch.where(mask, x.to(softmax_dtype) * scale, mask_value)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs.to(out_dtype)
|
|
|
|
|
@torch.jit.script
def upcast_softmax(x: torch.Tensor, scale: float, softmax_dtype: torch.dtype):
    """Softmax of ``scale * x`` over the last dimension, computed in ``softmax_dtype``.

    The result is cast back to the dtype ``x`` arrived with.
    """
    out_dtype = x.dtype
    probs = torch.nn.functional.softmax(x.to(softmax_dtype) * scale, dim=-1)
    return probs.to(out_dtype)
|
|
|
|
|
@torch.jit.script
def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor):
    """Softmax over the last dimension after replacing positions where ``mask`` is False by ``mask_value``.

    No dtype conversion or scaling is applied.
    """
    return torch.nn.functional.softmax(torch.where(mask, x, mask_value), dim=-1)
|
|
|
|
|
class LoRDCoderAttention(nn.Module):
    """Multi-head causal self-attention with a fused QKV projection and an optional multi-query mode.

    With ``config.multi_query`` a single key/value head is shared by all query heads. Constructing the module
    with ``is_cross_attention=True`` raises ``NotImplementedError``, so the cross-attention code paths in this
    class cannot be reached as written.
    """

    def __init__(self, config, is_cross_attention=False, layer_idx=None):
        super().__init__()
        # Cached 0-d tensor used to fill masked logits (see _get_mask_value).
        self.mask_value = None

        self.multi_query = config.multi_query
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        # Multi-query attention: one shared key/value head instead of num_heads of them.
        self.kv_heads = 1 if self.multi_query else self.num_heads
        self.kv_dim = self.kv_heads * self.head_dim
        self.split_size = self.embed_dim
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )

        self.scale_attn_weights = config.scale_attn_weights
        self.is_cross_attention = is_cross_attention

        # Used by _attn to derive the per-layer unscale factor (layer_idx + 1).
        self.layer_idx = layer_idx
        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
        # Per-layer logit scaling is only enabled together with an fp32 softmax.
        self.scale_attention_softmax_in_fp32 = (
            config.scale_attention_softmax_in_fp32 and config.attention_softmax_in_fp32
        )

        if self.is_cross_attention:
            raise NotImplementedError("Cross Attention not supported.")
            # NOTE(review): the rest of this branch is unreachable because of the raise above.
            if self.multi_query:
                raise NotImplementedError("Multi-Query Attention not supported for cross_attention")

            self.c_attn = nn.Linear(self.embed_dim, 2 * self.embed_dim)
            self.q_attn = nn.Linear(self.embed_dim, self.embed_dim)
        else:
            # Fused QKV projection of width embed_dim + 2 * kv_dim; forward() splits it differently
            # in multi-query and multi-head modes.
            self.c_attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.kv_dim)

        self.c_proj = nn.Linear(self.embed_dim, self.embed_dim)

        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)

    def _get_mask_value(self, device, dtype):
        """Return a 0-d tensor holding the most negative finite value of ``dtype`` on ``device``.

        The tensor is cached on the module and rebuilt only when the requested dtype or device changes.
        """
        # torch.where needs a tensor operand; caching avoids allocating a new one on every call.
        if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device:
            self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device)
        return self.mask_value

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        """Compute scaled dot-product attention.

        ``key`` arrives already transposed (``forward`` passes ``key.transpose(-1, -2)``), so its last
        dimension is the key length. Layouts implied by the reshapes below:

        * multi-query: ``query`` is reshaped to ``(batch, query_length * num_heads, head_dim)`` and the
          weights are viewed as ``(batch, query_length, num_heads, key_length)``;
        * multi-head: ``query`` is read as ``(batch, num_heads, query_length, head_dim)`` and the weights as
          ``(batch, num_heads, query_length, key_length)``.

        ``attention_mask`` is used as the condition of ``torch.where``: ``True`` keeps a logit and ``False``
        replaces it with the dtype's minimum value.

        Returns:
            ``(attn_output, attn_weights)``.
        """
        dtype = query.dtype
        softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else dtype
        upcast = dtype != softmax_dtype

        # With an upcast softmax and per-layer scaling, logits are divided by (layer_idx + 1) in the input
        # dtype and multiplied back inside the fp32 softmax (presumably to keep the low-precision matmul
        # result in range -- confirm).
        unscale = self.layer_idx + 1 if self.scale_attention_softmax_in_fp32 and upcast else 1
        scale_factor = unscale**-1
        if self.scale_attn_weights:
            scale_factor /= self.head_dim**0.5

        query_shape = query.shape
        batch_size = query_shape[0]
        key_length = key.size(-1)
        if self.multi_query:
            # Fold the head axis into the query-length axis so one bmm against the shared key suffices.
            query_length = query_shape[1]
            attn_shape = (batch_size, query_length, self.num_heads, key_length)
            attn_view = (batch_size, query_length * self.num_heads, key_length)

            query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim)
        else:
            # Fold the head axis into the batch axis for a batched matmul.
            query_length = query_shape[2]
            attn_shape = (batch_size, self.num_heads, query_length, key_length)
            attn_view = (batch_size * self.num_heads, query_length, key_length)

            query = query.reshape(batch_size * self.num_heads, query_length, self.head_dim)

            key = key.reshape(batch_size * self.num_heads, self.head_dim, key_length)

        # baddbmm computes beta * attn_weights + scale_factor * (query @ key); with beta=0 the
        # uninitialised buffer is meant to be ignored.
        attn_weights = torch.empty(attn_view, device=query.device, dtype=query.dtype)
        if query.device.type == "cpu":
            # NOTE(review): on CPU the buffer is zero-filled and beta=1 is used instead, presumably to work
            # around baddbmm propagating NaN/inf from uninitialised memory on CPU in some PyTorch versions.
            attn_weights = torch.zeros_like(attn_weights)
            beta = 1
        else:
            beta = 0
        attn_weights = torch.baddbmm(attn_weights, query, key, beta=beta, alpha=scale_factor).view(attn_shape)

        if upcast:
            # The fp32 helpers also re-apply `unscale`, undoing the division folded into scale_factor.
            if attention_mask is None:
                attn_weights = upcast_softmax(attn_weights, unscale, softmax_dtype)
            else:
                mask_value = self._get_mask_value(attn_weights.device, softmax_dtype)
                attn_weights = upcast_masked_softmax(attn_weights, attention_mask, mask_value, unscale, softmax_dtype)
        else:
            if attention_mask is not None:
                mask_value = self._get_mask_value(attn_weights.device, softmax_dtype)

                attn_weights = torch.where(attention_mask, attn_weights, mask_value)

            attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)

        attn_weights = self.attn_dropout(attn_weights)

        if head_mask is not None:
            # In the multi-query layout the head axis is dim 2 rather than dim 1, so head_mask's
            # dims 1 and 2 are swapped to line up.
            if self.multi_query:
                head_mask = head_mask.transpose(1, 2)
            attn_weights = attn_weights * head_mask

        if self.multi_query:
            # Unfold back to the query's original shape; value is the single shared head.
            attn_output = torch.bmm(attn_weights.view(attn_view), value).view(query_shape)
        else:
            attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights

    def forward(
        self,
        hidden_states: torch.Tensor,
        layer_past: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Union[
        Tuple[torch.Tensor, Optional[torch.Tensor]],
        Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
    ]:
        """Project ``hidden_states`` to queries/keys/values, attend, and project back.

        ``layer_past`` is a previously returned ``present`` tensor; the new fused key/value tensor is
        concatenated to it along dim -2. Returns ``(attn_output, present)``, plus ``attn_weights`` when
        ``output_attentions`` is set. ``present`` is ``None`` unless ``use_cache`` is true.
        """
        if encoder_hidden_states is not None:
            # Unreachable in practice: __init__ rejects is_cross_attention=True, so q_attn never exists.
            if not hasattr(self, "q_attn") or not self.is_cross_attention:
                raise ValueError(
                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
                    "Please make sure to instantiate class with `LoRDCoderAttention(..., is_cross_attention=True)`."
                )

            query = self.q_attn(hidden_states)
            key_value = self.c_attn(encoder_hidden_states)
            attention_mask = encoder_attention_mask
        elif self.multi_query:
            # query: embed_dim features; key_value: the single shared head's keys and values (2 * kv_dim).
            query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2)
        else:
            # Per head the fused projection holds [query | key | value] of head_dim each; after the
            # transpose the layout is (batch, num_heads, seq, ...), assuming hidden_states is
            # (batch, seq, embed_dim).
            query, key_value = (
                self.c_attn(hidden_states)
                .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim)
                .transpose(1, 2)
                .split((self.head_dim, 2 * self.head_dim), dim=3)
            )

        # The cache stores the fused key/value tensor, extended along the sequence dim (-2).
        if layer_past is not None:
            key_value = torch.cat((layer_past, key_value), dim=-2)
        present = key_value if use_cache else None

        key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)

        attn_output, attn_weights = self._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask)

        if not self.multi_query:
            # Merge heads back: (batch, heads, seq, head_dim) -> hidden_states' shape.
            attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            if self.multi_query:
                # Reorder (batch, query_length, num_heads, key_length) to the multi-head
                # (batch, num_heads, query_length, key_length) layout.
                attn_weights = attn_weights.transpose(1, 2)
            outputs += (attn_weights,)

        return outputs
|
|
|
|
|
class LoRDCoderMLP(nn.Module):
    """Position-wise feed-forward network with a factorised output projection.

    Computes ``dropout(c_proj(c_gate(act(c_fc(x)))))``:

    * ``c_fc``   maps ``hidden_size -> intermediate_size``,
    * ``c_gate`` maps ``intermediate_size -> config.gate_dim``,
    * ``c_proj`` maps ``gate_dim -> hidden_size``.

    ``c_gate`` and ``c_proj`` are applied back to back with no nonlinearity between them, so together they
    form one affine map whose linear part has rank at most ``gate_dim``. This is presumably the low-rank
    decomposition the model is named after -- confirm.
    """

    def __init__(self, intermediate_size, config):
        super().__init__()
        embed_dim = config.hidden_size
        self.gate_dim = config.gate_dim

        self.c_fc = torch.nn.Linear(in_features=embed_dim, out_features=intermediate_size, bias=True)
        self.c_gate = torch.nn.Linear(in_features=intermediate_size, out_features=self.gate_dim, bias=True)
        self.c_proj = torch.nn.Linear(in_features=self.gate_dim, out_features=embed_dim, bias=True)

        self.act = ACT2FN[config.activation_function]
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the feed-forward network to the last dimension of ``hidden_states``.

        Args:
            hidden_states: tensor whose last dimension is ``hidden_size``. (The previous annotation,
                ``Optional[Tuple[FloatTensor]]``, did not match how the argument is used.)

        Returns:
            Tensor with the same shape as ``hidden_states``.
        """
        hidden_states = self.act(self.c_fc(hidden_states))
        # Two linear maps with no activation in between: the factorised output projection.
        hidden_states = self.c_proj(self.c_gate(hidden_states))
        return self.dropout(hidden_states)
|
|
|
|
|
class LoRDCoderBlock(nn.Module):
    """One pre-LayerNorm decoder layer: self-attention, optional cross-attention, then the MLP.

    Each sub-layer is applied residually as ``x = x + sublayer(layer_norm(x))``.
    """

    def __init__(self, config, layer_idx=None):
        super().__init__()
        hidden_size = config.hidden_size
        # Width of the MLP expansion layer; defaults to 4 * hidden_size.
        self.inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = LoRDCoderAttention(config, layer_idx=layer_idx)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

        if config.add_cross_attention:
            if config.multi_query:
                raise NotImplementedError("Cross-attention not implemented for MQA")
            # NOTE(review): LoRDCoderAttention raises NotImplementedError for is_cross_attention=True,
            # so this construction currently cannot succeed.
            self.crossattention = LoRDCoderAttention(config, is_cross_attention=True, layer_idx=layer_idx)
            self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

        self.mlp = LoRDCoderMLP(self.inner_dim, config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        layer_past: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[Optional[torch.Tensor], ...]:
        """Run the layer on ``hidden_states``.

        The positional parameter order is relied on by gradient checkpointing in ``LoRDCoderModel``.

        Returns:
            ``(hidden_states, present?, self_attn_weights?, cross_attn_weights?)``. ``present`` is included
            only when ``use_cache`` is true; attention weights only when ``output_attentions`` is true.

        Raises:
            ValueError: if ``encoder_hidden_states`` is given but the layer has no cross-attention.
        """
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_outputs = self.attn(
            hidden_states,
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        attn_output = attn_outputs[0]
        # Always (present, [attn_weights]); present is None when use_cache is false.
        outputs = attn_outputs[1:]
        hidden_states = attn_output + residual

        if encoder_hidden_states is not None:
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
                    "cross-attention layers by setting `config.add_cross_attention=True`"
                )
            residual = hidden_states
            hidden_states = self.ln_cross_attn(hidden_states)
            cross_attn_outputs = self.crossattention(
                hidden_states,
                attention_mask=attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
            )
            hidden_states = residual + cross_attn_outputs[0]
            # Keep only the cross-attention weights (if any); its `present` slot is dropped.
            outputs = outputs + cross_attn_outputs[2:]

        residual = hidden_states
        feed_forward_hidden_states = self.mlp(self.ln_2(hidden_states))
        hidden_states = residual + feed_forward_hidden_states

        if use_cache:
            return (hidden_states,) + outputs
        # Drop the (None) present slot when caching is off.
        return (hidden_states,) + outputs[1:]
|
|
|
|
|
class LoRDCoderPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = LoRDCoderConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    # Each decoder block must stay on a single device when the model is sharded.
    _no_split_modules = ["LoRDCoderBlock"]
    _skip_keys_device_placement = "past_key_values"

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """Initialize the weights of a single submodule in place.

        * ``LoRDCoderMLP`` / ``LoRDCoderAttention``: re-draw the ``c_proj`` weight with a depth-scaled std of
          ``initializer_range / sqrt(2 * n_layer)`` and flag it as initialised. This relies on the child
          ``c_proj`` Linear being visited before its parent, as ``nn.Module.apply`` does; otherwise the
          generic Linear branch below would overwrite it.
        * ``nn.Linear``: normal(0, initializer_range) weights, zero bias.
        * ``nn.Embedding``: normal(0, initializer_range) weights, zeroed padding row.
        * ``nn.LayerNorm``: weight 1, bias 0.

        Previously the attention branch drew ``c_proj.weight`` twice in a row. The first draw was
        immediately overwritten, so it only consumed RNG state; the two branches are now merged with a
        single draw.
        """
        if isinstance(module, (LoRDCoderMLP, LoRDCoderAttention)):
            # Residual-path output projections are scaled down with depth (GPT-2 style scheme).
            module.c_proj.weight.data.normal_(
                mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))
            )
            module.c_proj._is_hf_initialized = True
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def _set_gradient_checkpointing(self, module, value=False):
        """Toggle gradient checkpointing on ``LoRDCoderModel`` instances (other modules are ignored)."""
        if isinstance(module, LoRDCoderModel):
            module.gradient_checkpointing = value
|
|
|
|
|
class LoRDCoderModel(LoRDCoderPreTrainedModel):
    """Bare LoRDCoder decoder stack that returns hidden states (no LM head).

    Token embeddings (``wte``) plus learned absolute position embeddings (``wpe``) feed
    ``config.num_hidden_layers`` ``LoRDCoderBlock`` layers, followed by a final LayerNorm (``ln_f``).
    """

    def __init__(self, config):
        super().__init__(config)

        self.multi_query = config.multi_query
        self.embed_dim = config.hidden_size

        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)

        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([LoRDCoderBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

        # Lower-triangular boolean causal mask, sliced per forward call; excluded from the state dict.
        max_positions = config.max_position_embeddings
        self.register_buffer(
            "bias", torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)), persistent=False
        )

        self.gradient_checkpointing = False

        # Initialize weights and apply final processing.
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding table."""
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        """Replace the token embedding table."""
        self.wte = new_embeddings

    def warn_if_padding_and_no_attention_mask(self, input_ids, attention_mask):
        """
        Shows a one-time warning if the input_ids appear to contain padding and no attention mask was given.

        The warning is printed at most once per model instance. Previously it was printed on every forward
        call, which flooded the output during token-by-token generation.
        """
        if getattr(self, "_padding_warning_shown", False):
            return
        if (attention_mask is not None) or (self.config.pad_token_id is None):
            return

        # Checking only the last and first columns catches both right- and left-padding cheaply.
        if self.config.pad_token_id in input_ids[:, [-1, 0]]:
            warn_string = (
                "We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See "
                "https://huggingface.co/docs/transformers/troubleshooting"
                "#incorrect-output-when-padding-tokens-arent-masked."
            )

            # When pad doubles as bos/eos/sep, a pad id at the edge need not mean the input is padded.
            if (
                (self.config.bos_token_id is not None and self.config.bos_token_id == self.config.pad_token_id)
                or (self.config.eos_token_id is not None and self.config.eos_token_id == self.config.pad_token_id)
                or (self.config.sep_token_id is not None and self.config.sep_token_id == self.config.pad_token_id)
            ):
                warn_string += (
                    f"\nYou may ignore this warning if your `pad_token_id` ({self.config.pad_token_id}) is identical "
                    f"to the `bos_token_id` ({self.config.bos_token_id}), `eos_token_id` ({self.config.eos_token_id}), "
                    f"or the `sep_token_id` ({self.config.sep_token_id}), and your input is not padded."
                )

            print("Warning:", warn_string)
            self._padding_warning_shown = True

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
        """Run the decoder stack.

        ``past_key_values`` (and the returned ``presents``) hold one fused key/value tensor per layer, with
        the sequence dimension at -2. Exactly one of ``input_ids`` / ``inputs_embeds`` must be given.

        Returns:
            ``BaseModelOutputWithPastAndCrossAttentions``. When ``return_dict`` is falsy, returns a tuple of
            its non-None fields in the order: last hidden state, presents, all hidden states,
            self-attentions, cross-attentions.

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds`` are given, or the batch is
                empty.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if batch_size <= 0:
            raise ValueError("batch_size has to be defined and > 0")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])
        if position_ids is not None:
            position_ids = position_ids.view(-1, input_shape[-1])

        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * len(self.h))
        else:
            past_length = past_key_values[0].size(-2)

        if attention_mask is not None and len(attention_mask.shape) == 2 and position_ids is None:
            # Derive positions from the padding mask so left-padded rows count only real tokens;
            # padded slots get a dummy position of 1.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_length > 0:
                position_ids = position_ids[:, past_length : input_shape[-1] + past_length :]
        elif position_ids is None:
            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

        # Causal-mask rows for the current queries against all (past + current) keys.
        query_length = input_shape[-1]
        key_length = past_length + query_length
        self_attention_mask = self.bias[None, key_length - query_length : key_length, :key_length]

        if attention_mask is not None:
            # Combine the causal mask with the per-key padding mask.
            self_attention_mask = self_attention_mask * attention_mask.view(batch_size, 1, -1).to(
                dtype=torch.bool, device=self_attention_mask.device
            )

        # Insert the head axis where LoRDCoderAttention._attn's weight layout expects it:
        # dim 2 for multi-query (batch, q, heads, k), dim 1 otherwise (batch, heads, q, k).
        attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1)

        if (
            self.config.add_cross_attention
            and encoder_hidden_states is not None
            and encoder_attention_mask is not None
        ):
            if encoder_attention_mask.dim() == 2:
                # Tensor.unsqueeze is not in-place: the result must be assigned. Previously it was
                # discarded, so every 2-D mask failed the assertion below.
                encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
            assert encoder_attention_mask.dim() == 3
            encoder_attention_mask = encoder_attention_mask.bool().unsqueeze(2 if self.multi_query else 1)
        else:
            encoder_attention_mask = None

        # One entry per layer (entries are None when no head mask is given).
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        hidden_states = inputs_embeds + position_embeds

        if token_type_ids is not None:
            # Token types share the token embedding table.
            token_type_embeds = self.wte(token_type_ids)
            hidden_states = hidden_states + token_type_embeds

        hidden_states = self.drop(hidden_states)

        output_shape = input_shape + (hidden_states.size(-1),)

        presents = [] if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
        all_hidden_states = () if output_hidden_states else None
        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # Flags are appended positionally after the tensor inputs.
                        return module(*inputs, use_cache, output_attentions)

                    return custom_forward

                # Arguments follow LoRDCoderBlock.forward's positional order; layer_past is passed as None,
                # so any cached keys/values are not used under checkpointing.
                outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    None,
                    attention_mask,
                    head_mask[i],
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=attention_mask,
                    head_mask=head_mask[i],
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            hidden_states = outputs[0]
            if use_cache:
                presents.append(outputs[1])

            # Attention weights follow the optional `present` slot in the block's output tuple.
            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)

        hidden_states = self.ln_f(hidden_states)

        hidden_states = hidden_states.view(output_shape)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
                if v is not None
            )

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
|
|
|
|
|
class LoRDCoderForCausalLM(LoRDCoderPreTrainedModel):
    """Causal language model: ``LoRDCoderModel`` plus an output head tied to the input embeddings.

    The head owns no parameters; logits are ``hidden_states @ transformer.wte.weight.T``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.transformer = LoRDCoderModel(config)
        # Tied LM head. It reads ``self.transformer.wte`` at call time, so a replaced input embedding is
        # picked up automatically.
        # NOTE(review): copy.deepcopy returns plain functions unchanged, so a deep-copied model's lm_head
        # would still close over the original ``self``; the lambda also blocks plain pickling.
        self.lm_head = lambda x: torch.matmul(x, self.transformer.wte.weight.T)

        # Initialize weights and apply final processing.
        self.post_init()

    def get_output_embeddings(self):
        """Return the parameter-free tied LM head callable."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Unsupported: the head is tied to ``transformer.wte`` and has no weights of its own."""
        raise NotImplementedError("Cannot resize the embeddings of LoRDCoderForCausalLM.")

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        """Assemble the keyword arguments for one ``forward`` call during generation.

        When a cache is present only the newest token (and its token type / position) is fed to the model.
        ``inputs_embeds`` is used only on the first step, before any cache exists.
        """
        token_type_ids = kwargs.get("token_type_ids", None)
        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)

        if past_key_values:
            input_ids = input_ids[:, -1].unsqueeze(-1)
            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

        if attention_mask is None or position_ids is not None:
            # NOTE(review): caller-supplied position_ids are discarded here; the model then recomputes
            # positions from attention_mask or the cache length. Confirm this is intended.
            position_ids = None
        else:
            # Positions count only non-padding tokens; padded slots get a dummy position of 1.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)

        if inputs_embeds is not None and past_key_values is None:
            first_input = {"inputs_embeds": inputs_embeds}
        else:
            first_input = {"input_ids": input_ids}

        return {
            **first_input,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
        }

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
        r"""
        Run the decoder and project its final hidden states onto the vocabulary.

        labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Targets for language modelling. They are shifted inside the model, so `labels = input_ids` is
            valid. Entries equal to `-100` (CrossEntropyLoss's default `ignore_index`) are excluded from the loss.

        Returns a `CausalLMOutputWithCrossAttentions`, or (when `return_dict` is falsy) a tuple of logits
        followed by the transformer's remaining outputs, prefixed by the loss when `labels` is given.
        """
        return_dict = self.config.use_return_dict if return_dict is None else return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        lm_logits = self.lm_head(transformer_outputs[0])

        loss = None
        if labels is not None:
            # Position t predicts token t + 1: drop the last logit and the first label.
            logits_for_loss = lm_logits[..., :-1, :].contiguous()
            targets = labels[..., 1:].contiguous().to(logits_for_loss.device)
            vocab_size = logits_for_loss.size(-1)
            criterion = CrossEntropyLoss()
            loss = criterion(logits_for_loss.view(-1, vocab_size), targets.view(-1))

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return output if loss is None else (loss,) + output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )

    @staticmethod
    def _reorder_cache(
        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
    ) -> Tuple[Tuple[torch.Tensor]]:
        """
        Reorder the ``past_key_values`` cache along the batch dimension (dim 0) to follow ``beam_idx``, as
        required by [`~PreTrainedModel.beam_search`] / [`~PreTrainedModel.beam_sample`] at every step.
        """
        reordered = []
        for layer_past in past_key_values:
            reordered.append(layer_past.index_select(0, beam_idx.to(layer_past.device)))
        return tuple(reordered)
|
|