|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Tuple |
|
|
|
import numpy as np |
|
|
|
import torch |
|
import torch.utils.checkpoint |
|
from torch import nn |
|
from torch.nn import CrossEntropyLoss |
|
import torch.nn.functional as F |
|
|
|
from transformers.activations import ACT2FN |
|
from transformers.modeling_outputs import ( |
|
BaseModelOutputWithPast, |
|
CausalLMOutputWithPast, |
|
) |
|
from transformers.modeling_utils import PreTrainedModel |
|
from transformers.utils import logging |
|
from .configuration_progen import ProGenConfig |
|
|
|
|
|
# Module-level logger from transformers' logging utilities (used, e.g., for the
# gradient-checkpointing / use_cache warning in ProGenModel.forward).
logger = logging.get_logger(__name__)
|
|
|
|
|
def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
    """Build the sine/cosine tables used by rotary position embeddings.

    Args:
        x: tensor whose last dimension is the rotary feature size ``dim``; only
            its shape and device are read.
        seq_dim: axis of ``x`` that holds the sequence length; consulted only
            when ``seq_len`` is None.
        seq_len: number of positions to generate (defaults to
            ``x.shape[seq_dim]``).

    Returns:
        ``(sin, cos)``: float32 tensors of shape ``(seq_len, ceil(dim / 2))`` on
        ``x.device``, where entry ``[p, j]`` is ``sin``/``cos`` of
        ``p / 10000 ** (2 * j / dim)``.
    """
    dim = x.shape[-1]
    if seq_len is None:
        seq_len = x.shape[seq_dim]

    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim))
    # Create positions directly in inv_freq's floating dtype instead of relying
    # on einsum's handling of an integer operand mixed with a float one; an
    # explicit outer product expresses the (position x frequency) table.
    positions = torch.arange(seq_len, dtype=inv_freq.dtype)
    sinusoid_inp = torch.outer(positions, inv_freq).to(x.device).float()

    return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)
|
|
|
|
|
def rotate_every_two(x: torch.Tensor):
    """Rotate interleaved feature pairs by 90 degrees: ``(a, b) -> (-b, a)``.

    Pairs are consecutive entries of the last dimension, i.e.
    ``(x[..., 0], x[..., 1])``, ``(x[..., 2], x[..., 3])``, ... . Any tensor with
    at least one dimension is accepted (leading dimensions are untouched), and
    the result has the same shape as ``x``.
    """
    evens = x[..., ::2]
    odds = x[..., 1::2]
    # Stack into (..., dim/2, 2) pairs of (-odd, even), then flatten the pairs
    # back into the interleaved layout.
    return torch.stack((-odds, evens), dim=-1).flatten(-2)
|
|
|
def apply_rotary_pos_emb(x, sincos, offset=0):
    """Apply rotary position embeddings to ``x``.

    Args:
        x: features to rotate; the sinusoid window built below has shape
            ``(1, x.shape[1], 1, dim)``, so ``x`` is expected to be
            ``(batch, seq, heads, dim)`` with the sequence on axis 1.
        sincos: ``(sin, cos)`` tables of shape ``(positions, dim / 2)``.
        offset: first table row to use (number of already-cached positions).

    Returns:
        ``x * cos + rotate_every_two(x) * sin`` over rows
        ``offset : offset + x.shape[1]``.
    """
    window_len = x.shape[1]
    sin_table, cos_table = sincos

    def _window(table):
        # Select this chunk's rows, add broadcast axes for batch and heads, and
        # duplicate every frequency so it lines up with the interleaved pairs.
        rows = table[None, offset : window_len + offset, None, :]
        return rows.repeat_interleave(2, 3)

    sin = _window(sin_table)
    cos = _window(cos_table)

    return (x * cos) + (rotate_every_two(x) * sin)
|
|
|
|
|
class ProGenAttention(nn.Module):
    """Causal multi-head self-attention with rotary position embeddings.

    Query, value and key come from one fused ``qkv_proj``. Its output is viewed
    as ``mp_num = 8`` equal shards, and each shard is split into three parts in
    the order (query, value, key). NOTE(review): the fixed shard count and that
    ordering presumably mirror the layout of the pretrained checkpoints; confirm
    before changing either.
    """

    def __init__(self, config):
        super().__init__()

        max_positions = config.n_positions
        # Lower-triangular boolean causal mask of shape
        # (1, 1, n_positions, n_positions): entry [..., q, k] is True when
        # position q may attend to position k. Registered non-persistent, so it
        # is not saved in the state dict.
        self.register_buffer(
            "bias",
            torch.tril(
                torch.ones((max_positions, max_positions), dtype=torch.bool)
            ).view(1, 1, max_positions, max_positions),
            persistent=False
        )
        # Score substituted for masked (future) positions before the softmax.
        self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)

        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)

        self.embed_dim = config.embed_dim
        self.num_attention_heads = config.n_head
        self.head_dim = self.embed_dim // self.num_attention_heads
        if self.head_dim * self.num_attention_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and `num_attention_heads`: {self.num_attention_heads})."
            )
        # sqrt(head_dim) as a 0-dim tensor in the default dtype; raw attention
        # scores are divided by it. It is a plain attribute, not a registered buffer.
        self.scale_attn = torch.sqrt(
            torch.tensor(self.head_dim, dtype=torch.float32)
        ).to(torch.get_default_dtype())
        # Fused, bias-free projection producing 3 * embed_dim features.
        self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3, bias=False)

        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        # Number of leading per-head features that receive rotary embeddings;
        # None means the whole head dimension is rotated (see forward).
        self.rotary_dim = None
        if config.rotary_dim is not None:
            self.rotary_dim = config.rotary_dim

    def _split_heads(self, x: torch.Tensor, n_head, dim_head) -> torch.Tensor:
        """Reshape a sharded projection into per-head form.

        The last two axes (shard, per-shard features) are flattened into one
        feature axis, which is then reshaped to ``(n_head, dim_head)``. For the
        shard tensors built in ``forward`` this yields
        ``(batch, seq, n_head, dim_head)``, assuming 3-D hidden states.
        """
        x = x.reshape(x.shape[:-2] + (-1,))
        x = x.reshape(x.shape[:-1] + (n_head, dim_head))
        return x

    def _merge_heads(self, tensor, num_attention_heads, attn_head_size) -> torch.Tensor:
        """
        Merge the head axis and the per-head feature axis into one feature axis.

        A 4-D tensor is permuted with (0, 2, 1, 3), e.g.
        (batch, heads, seq, head_dim) -> (batch, seq, heads, head_dim). A 5-D
        tensor is permuted with (0, 1, 3, 2, 4). In both cases the last two axes
        are then flattened into ``num_attention_heads * attn_head_size``.

        Raises:
            ValueError: if the tensor rank is not 4 or 5.
        """
        if len(tensor.shape) == 5:
            tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
        elif len(tensor.shape) == 4:
            tensor = tensor.permute(0, 2, 1, 3).contiguous()
        else:
            raise ValueError(
                f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}"
            )
        new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
        return tensor.view(new_shape)

    def _attn(
        self,
        query,
        key,
        value,
        attention_mask=None,
        head_mask=None,
    ):
        """Scaled dot-product attention with a causal mask.

        Returns ``(attn_output, attn_weights)``, where ``attn_weights`` are the
        post-softmax, post-dropout (and head-masked) probabilities.
        """

        query_length, key_length = query.size(-2), key.size(-2)
        # Queries are the last `query_length` positions of the `key_length`
        # positions (cached + new), so take those rows of the causal mask.
        causal_mask = self.bias[
            :, :, key_length - query_length : key_length, :key_length
        ]

        # Compute scores in float32 regardless of the model dtype.
        query = query.to(torch.float32)
        key = key.to(torch.float32)

        attn_weights = query @ key.transpose(-1, -2)

        attn_weights = attn_weights / self.scale_attn

        # Replace future positions with a large negative score.
        attn_weights = torch.where(
            causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype)
        )

        if attention_mask is not None:
            # Additive padding mask (0 for kept positions, large negative for masked).
            attn_weights = attn_weights + attention_mask

        attn_weights = F.softmax(attn_weights, dim=-1)
        # Cast probabilities back to the value dtype before the weighted sum.
        attn_weights = attn_weights.to(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = attn_weights @ value

        return attn_output, attn_weights

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        layer_past=None,
        head_mask=None,
        use_cache=False,
        output_attentions=False,
    ):
        """Apply self-attention to ``hidden_states``.

        Args:
            hidden_states: block input. NOTE(review): assumed to be
                (batch, seq, embed_dim), which the 4-D indexing and permutes
                below require.
            attention_mask: optional additive mask added to the scores in ``_attn``.
            layer_past: optional ``(past_key, past_value)`` from a previous call;
                its length (``shape[-2]``) is used as the rotary position offset.
            head_mask: optional multiplier applied to the attention probabilities.
            use_cache: ``present = (key, value)`` is returned only when this is
                exactly ``True``; otherwise ``present`` is None.
            output_attentions: also return the attention probabilities.

        Returns:
            ``(attn_output, present)`` or ``(attn_output, present, attn_weights)``.
        """
        qkv = self.qkv_proj(hidden_states)

        # NOTE(review): hard-coded shard count of the fused projection.
        mp_num = 8
        # View the fused projection as mp_num shards along a new axis.
        qkv_split = qkv.reshape(qkv.shape[:-1] + (mp_num, -1))

        # Inside each shard the features are ordered (query, value, key), each
        # piece having embed_dim // mp_num features.
        query, value, key = torch.split(qkv_split, self.embed_dim // mp_num, dim=-1)

        # Flatten the shards and regroup into heads: (..., n_head, head_dim).
        query = self._split_heads(query, self.num_attention_heads, self.head_dim)
        key = self._split_heads(key, self.num_attention_heads, self.head_dim)
        value = self._split_heads(value, self.num_attention_heads, self.head_dim)
        # Value goes straight to the head-major layout. Query/key are permuted
        # later, because the rotary code below indexes the sequence on axis 1.
        value = value.permute(0, 2, 1, 3)

        seq_len = key.shape[1]
        offset = 0

        if layer_past is not None:
            # Cached keys (as produced in `present` below) hold the sequence on
            # axis -2; new tokens start right after them.
            offset = layer_past[0].shape[-2]
            seq_len += offset

        if self.rotary_dim is not None:
            # Partial rotary: rotate only the first rotary_dim features per head.
            k_rot = key[:, :, :, : self.rotary_dim]
            k_pass = key[:, :, :, self.rotary_dim :]

            q_rot = query[:, :, :, : self.rotary_dim]
            q_pass = query[:, :, :, self.rotary_dim :]

            sincos = fixed_pos_embedding(k_rot, 1, seq_len=seq_len)
            k_rot = apply_rotary_pos_emb(k_rot, sincos, offset=offset)
            q_rot = apply_rotary_pos_emb(q_rot, sincos, offset=offset)

            key = torch.cat([k_rot, k_pass], dim=-1)
            query = torch.cat([q_rot, q_pass], dim=-1)
        else:
            # Rotary over the full head dimension.
            sincos = fixed_pos_embedding(key, 1, seq_len=seq_len)
            key = apply_rotary_pos_emb(key, sincos, offset=offset)
            query = apply_rotary_pos_emb(query, sincos, offset=offset)

        # Move heads in front of the sequence axis, matching `value`.
        key = key.permute(0, 2, 1, 3)
        query = query.permute(0, 2, 1, 3)

        if layer_past is not None:
            # Prepend the cached keys/values along the sequence axis.
            past_key = layer_past[0]
            past_value = layer_past[1]
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        if use_cache is True:
            present = (key, value)
        else:
            present = None

        # Scaled dot-product attention (softmax(QK^T / sqrt(d)) V) with causal masking.
        attn_output, attn_weights = self._attn(
            query, key, value, attention_mask, head_mask
        )

        # Merge heads back into a single embed_dim feature axis.
        attn_output = self._merge_heads(
            attn_output, self.num_attention_heads, self.head_dim
        )

        attn_output = self.out_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs
|
|
|
|
|
class ProGenMLP(nn.Module):
    """Feed-forward sub-layer: Linear -> activation -> Linear -> dropout.

    Args (constructor):
        intermediate_size: width of the hidden (expanded) layer.
        config: provides ``embed_dim``, ``activation_function`` (key into
            ``ACT2FN``) and ``resid_pdrop`` (output dropout probability).
    """

    def __init__(self, intermediate_size, config):
        super().__init__()
        model_width = config.embed_dim

        # Registration order (fc_in, fc_out, act, dropout) is kept so parameter
        # creation and module ordering match existing checkpoints.
        self.fc_in = nn.Linear(model_width, intermediate_size)
        self.fc_out = nn.Linear(intermediate_size, model_width)
        self.act = ACT2FN[config.activation_function]
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, hidden_states):
        """Project up, apply the activation, project back down, then dropout."""
        expanded = self.act(self.fc_in(hidden_states))
        return self.dropout(self.fc_out(expanded))
|
|
|
|
|
class ProGenBlock(nn.Module):
    """One ProGen layer with parallel attention and MLP branches.

    A single LayerNorm (``ln_1``) feeds both the attention and the MLP. Their
    outputs are summed with the un-normalized block input.
    """

    def __init__(self, config):
        super().__init__()
        if config.n_inner is None:
            inner_dim = 4 * config.embed_dim
        else:
            inner_dim = config.n_inner
        self.ln_1 = nn.LayerNorm(config.embed_dim, eps=config.layer_norm_epsilon)
        self.attn = ProGenAttention(config)
        self.mlp = ProGenMLP(inner_dim, config)

    def forward(
        self,
        hidden_states,
        layer_past=None,
        attention_mask=None,
        head_mask=None,
        use_cache=False,
        output_attentions=False,
    ):
        """Run one layer.

        Returns:
            ``(hidden_states, present[, attn_weights])`` when ``use_cache`` is
            truthy, otherwise ``(hidden_states[, attn_weights])``.
        """
        normed = self.ln_1(hidden_states)

        # Attention runs before the MLP (this also fixes the dropout RNG order).
        attn_output, *extras = self.attn(
            normed,
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        mlp_output = self.mlp(normed)
        combined = attn_output + mlp_output + hidden_states

        # `extras` is [present] or [present, attn_weights]. Without caching the
        # present slot is dropped, so attention weights move to index 1.
        if not use_cache:
            extras = extras[1:]

        return (combined, *extras)
|
|
|
|
|
class ProGenPreTrainedModel(PreTrainedModel):
    """
    Abstract ProGen base class: ties models to ``ProGenConfig``, names the base
    model attribute ``transformer``, and supplies the per-module weight
    initializer (``_init_weights``).
    """

    config_class = ProGenConfig
    base_model_prefix = "transformer"
    is_parallelizable = False

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """Initialize ``module`` in place.

        - Linear: weight ~ N(0, config.initializer_range); bias zeroed if present.
        - Embedding: weight ~ N(0, config.initializer_range); padding row zeroed.
        - LayerNorm: bias = 0, weight = 1.
        Any other module type is left untouched.
        """
        if isinstance(module, (nn.Linear,)):
            module.weight.data.normal_(
                mean=0.0, std=self.config.initializer_range
            )
            linear_bias = module.bias
            if linear_bias is not None:
                linear_bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            table = module.weight.data
            table.normal_(mean=0.0, std=self.config.initializer_range)
            pad_row = module.padding_idx
            if pad_row is not None:
                table[pad_row].zero_()
            return

        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
|
|
|
|
|
class ProGenModel(ProGenPreTrainedModel):
    """ProGen decoder stack: token embedding, ``n_layer`` ProGenBlocks, final LayerNorm.

    There is no learned position embedding. Positions enter only through the
    rotary embeddings applied inside each attention layer, offset by the length
    of any cached keys.
    """

    def __init__(self, config):
        super().__init__(config)
        self.vocab_size_emb = config.vocab_size_emb
        self.embed_dim = config.embed_dim
        self.wte = nn.Embedding(config.vocab_size_emb, self.embed_dim)
        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([ProGenBlock(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        # Informational attribute; the attention layers read config.rotary_dim
        # themselves. ProGenAttention supports rotary_dim=None (rotate the full
        # head), so don't call min() on None here -- that raised TypeError.
        if config.rotary_dim is None:
            self.rotary_dim = None
        else:
            self.rotary_dim = min(
                config.rotary_dim, config.n_positions // config.n_head
            )
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run the decoder stack.

        Args:
            input_ids: token ids; all leading dims are flattened into the batch
                and restored on the output.
            past_key_values: per-layer ``(key, value)`` caches from a previous
                call, or None.
            attention_mask: 1 for positions to attend to, 0 for padding;
                converted into an additive mask of 0 / -10000.
            token_type_ids: optional ids embedded with the same ``wte`` table
                and added to the token embeddings.
            position_ids: accepted for API compatibility but not used. The
                rotary embeddings derive positions from the cache length.
            head_mask: optional per-layer head mask (see ``get_head_mask``).
            inputs_embeds: precomputed embeddings used instead of ``input_ids``.
            use_cache, output_attentions, output_hidden_states, return_dict:
                default to the corresponding config values.

        Returns:
            ``BaseModelOutputWithPast``; or, when ``return_dict`` is false, a
            tuple of the non-None values among (last hidden state, presents,
            all hidden states, attentions).

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds``
                are given.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])

        if past_key_values is None:
            past_key_values = tuple([None] * len(self.h))

        if attention_mask is not None:
            assert batch_size > 0, "batch_size has to be defined and > 0"
            attention_mask = attention_mask.view(batch_size, -1)
            # (batch, key_len) -> (batch, 1, 1, key_len) so the mask broadcasts
            # over heads and query positions.
            attention_mask = attention_mask[:, None, None, :]
            # 1 -> 0.0 (keep), 0 -> -10000.0 (mask); added to the raw scores.
            attention_mask = attention_mask.to(dtype=self.dtype)
            attention_mask = (1.0 - attention_mask) * -10000.0

        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)

        hidden_states = inputs_embeds

        if token_type_ids is not None:
            token_type_embeds = self.wte(token_type_ids)
            hidden_states = hidden_states + token_type_embeds

        hidden_states = self.drop(hidden_states)

        output_shape = input_shape + (hidden_states.size(-1),)

        use_checkpointing = (
            getattr(self.config, "gradient_checkpointing", False) and self.training
        )
        # Resolve the cache setting before `presents` is initialised, so a
        # disabled cache is reported as None rather than an empty tuple.
        if use_checkpointing and use_cache:
            logger.warning(
                "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
                "`use_cache=False`..."
            )
            use_cache = False

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        def create_custom_forward(module):
            # checkpoint() only forwards tensors positionally; append the flags
            # after (hidden_states, layer_past, attention_mask, head_mask).
            def custom_forward(*inputs):
                return module(*inputs, use_cache, output_attentions)

            return custom_forward

        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if use_checkpointing:
                outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    None,
                    attention_mask,
                    head_mask[i],
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=attention_mask,
                    head_mask=head_mask[i],
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)

            if output_attentions:
                # Attention weights follow the cache slot only when caching is on.
                all_self_attentions = all_self_attentions + (
                    outputs[2 if use_cache else 1],
                )

        hidden_states = self.ln_f(hidden_states)

        # Restore any leading dims that were flattened into the batch.
        hidden_states = hidden_states.view(*output_shape)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    presents,
                    all_hidden_states,
                    all_self_attentions,
                ]
                if v is not None
            )

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
|
|
|
|
|
class ProGenForCausalLM(ProGenPreTrainedModel):
    """ProGen decoder with a linear language-modeling head.

    The head maps ``config.embed_dim`` features to ``config.vocab_size_lm_head``
    logits. This vocabulary size is configured separately from the input
    embedding's ``config.vocab_size_emb``.
    """

    _keys_to_ignore_on_load_missing = [
        r"h\.\d+\.attn\.masked_bias",
        r"h\.\d+\.attn\.bias",
        r"lm_head\.weight",
    ]

    def __init__(self, config):
        super().__init__(config)
        self.transformer = ProGenModel(config)
        self.lm_head = nn.Linear(config.embed_dim, config.vocab_size_lm_head)
        self.init_weights()

    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
        """Build the model kwargs for one generation step.

        With a cache present, only the last token (and its token type /
        position) is fed, since earlier positions are already cached.

        Args:
            input_ids: all tokens generated so far.
            past: cached key/values. Newer ``generate`` implementations pass the
                cache as ``past_key_values`` instead, which is also accepted.
            **kwargs: may contain ``past_key_values``, ``token_type_ids``,
                ``attention_mask``, ``position_ids`` and ``use_cache``.
        """
        # Accept both spellings of the cache argument; without this, callers
        # that pass `past_key_values` silently lost cached decoding.
        if past is None:
            past = kwargs.get("past_key_values", None)

        token_type_ids = kwargs.get("token_type_ids", None)

        if past:
            input_ids = input_ids[:, -1].unsqueeze(-1)
            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)

        if attention_mask is not None and position_ids is None:
            # Derive positions from the mask so left padding does not consume
            # position indices; padded slots get a dummy value of 1.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past:
                position_ids = position_ids[:, -1].unsqueeze(-1)
        else:
            position_ids = None
        return {
            "input_ids": input_ids,
            "past_key_values": past,
            "use_cache": kwargs.get("use_cache"),
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
        }

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            ``labels = input_ids``. Indices are selected in ``[-100, 0, ..., config.vocab_size_lm_head - 1]``. All
            labels set to ``-100`` are ignored (masked); the loss is only computed for labels in
            ``[0, ..., config.vocab_size_lm_head - 1]``.

        Returns:
            ``CausalLMOutputWithPast`` or, when ``return_dict`` is false, a tuple
            ``(loss?, logits, *transformer_outputs[1:])``. Logits are float32.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        # Logits are upcast to float32 for a numerically stable softmax/loss.
        lm_logits = self.lm_head(hidden_states).to(torch.float32)

        loss = None
        if labels is not None:
            # Predict token t+1 from position t: drop the last logit and the first label.
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )

            loss = loss.to(hidden_states.dtype)

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    @staticmethod
    def _reorder_cache(
        past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
    ) -> Tuple[Tuple[torch.Tensor]]:
        """
        This function is used to re-order the :obj:`past_key_values` cache if
        :meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is
        called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
        """
        return tuple(
            tuple(
                past_state.index_select(0, beam_idx.to(past_state.device))
                for past_state in layer_past
            )
            for layer_past in past
        )
|
|