import copy
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from datetime import datetime
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig, GPT2Config
from transformers.modeling_outputs import MaskedLMOutput

from lmup.gpst.gpt2_flash_attn import GPT2Model
from lmup.gpst.r2d2_insideoutside import *


def load_model(model, model_path, strict=True):
    # Load a checkpoint onto CPU and strip any DataParallel/DDP "module." prefix
    # from the parameter names before loading them into the model.
    state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
    transferred_state_dict = {}
    for k, v in state_dict.items():
        new_k = k.replace('module.', '')
        transferred_state_dict[new_k] = v
    model.load_state_dict(transferred_state_dict, strict=strict)
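
# Illustrative usage (hypothetical names and paths): checkpoints saved under
# DataParallel/DDP carry a "module." prefix on every key, which load_model strips
# before loading.
#
#   model = GPST(GPSTConfig(r2d2=r2d2_cfg_dict, gpt=gpt_cfg_dict))
#   load_model(model, "checkpoints/gpst/model.bin", strict=False)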


def index_sanity_hook(module, input, output):
    # Debugging forward hook: warn when an integer-typed input tensor holds values
    # far outside a plausible index range.
    def check(tensor):
        if isinstance(tensor, torch.Tensor):
            if tensor.dtype == torch.long or "int" in str(tensor.dtype):
                if tensor.max() > 10000 or tensor.min() < -10000:
                    print(f"[!] Suspicious index in {module.__class__.__name__}: "
                          f"min={tensor.min().item()}, max={tensor.max().item()}, shape={tensor.shape}")

    for item in input:
        if isinstance(item, (tuple, list)):
            for sub in item:
                check(sub)
        else:
            check(item)
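
# Illustrative usage (hypothetical model/batch, debugging only): register the hook on
# every submodule so suspicious integer inputs are reported during a forward pass,
# then remove the handles.
#
#   handles = [m.register_forward_hook(index_sanity_hook) for m in model.modules()]
#   model(**batch)
#   for h in handles:
#       h.remove()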


@dataclass(kw_only=True)
class R2D2GenOutput:
    # Keyword-only container for the losses and intermediate outputs produced by
    # FastGenerativeR2D2 / GPST.
    struct_loss: Optional[torch.FloatTensor] = None
    non_struct_loss: Optional[torch.FloatTensor] = None
    non_struct_loss_fullscale: Optional[torch.FloatTensor] = None
    action_logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[torch.FloatTensor] = None
    cls_hidden_states: Optional[torch.FloatTensor] = None
    tgt_ids: Optional[torch.LongTensor] = None
    pred: Optional[torch.FloatTensor] = None
    splits: Optional[torch.LongTensor] = None
    gpt_loss: Optional[torch.FloatTensor] = None
    action_loss: Optional[torch.FloatTensor] = None
    inside_outside_loss: Optional[torch.FloatTensor] = None
    parser_loss: Optional[torch.FloatTensor] = None
    glue_finetune_loss: Optional[torch.FloatTensor] = None
    past_kv: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    loss: Optional[torch.FloatTensor] = None
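
# Because the dataclass is declared kw_only, instances are constructed with keyword
# arguments only (as in FastGenerativeR2D2.forward below), and callers read fields
# by name, e.g. output.loss, output.logits, output.splits.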


class GPSTConfig(PretrainedConfig):
    model_type = "gpst"

    def __init__(self, r2d2=None, gpt=None, **kwargs):
        self.gptconfig = gpt
        self.r2d2config = r2d2
        super().__init__(**kwargs)
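
# Illustrative construction (hypothetical file names): the two sub-configurations are
# passed in as plain dicts and rebuilt in GPST.__init__ via PretrainedConfig.from_dict
# and GPT2Config.from_dict.
#
#   import json
#   config = GPSTConfig(r2d2=json.load(open("r2d2_config.json")),
#                       gpt=json.load(open("gpt2_config.json")))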


class GPST(PreTrainedModel):
    config_class = GPSTConfig

    def __init__(self, config, gradient_checkpoint=False):
        super().__init__(config)
        self.config = config
        self.r2d2_config = PretrainedConfig.from_dict(config.r2d2config)
        self.gpt_config = GPT2Config.from_dict(config.gptconfig)

        self.vocab_size = self.gpt_config.vocab_size

        # Split the transformer stack: action_layer_num layers form the action model,
        # the remaining layers form the token-generation model.
        total_layer = self.gpt_config.n_layer
        action_transformers = GPT2Model(copy.deepcopy(self.gpt_config), no_embedding=True,
                                        no_layer_norm=True, n_layers_manual=self.gpt_config.action_layer_num)
        action_transformers.gradient_checkpointing = gradient_checkpoint
        self.gpt_config.n_layer = total_layer - self.gpt_config.action_layer_num
        self.gpt_config.num_hidden_layers = total_layer - self.gpt_config.action_layer_num
        gpt_transformers = GPT2Model(self.gpt_config, no_embedding=True, no_extra_embedding=True)
        gpt_transformers.gradient_checkpointing = gradient_checkpoint

        r2d2 = InsideOutsideModule(self.r2d2_config)
        self.model = FastGenerativeR2D2(
            r2d2=r2d2,
            action_layers=action_transformers,
            generation_layers=gpt_transformers,
            vocab_size=self.vocab_size,
            r2d2_input_dim=r2d2.input_dim,
            embedding_dim=self.gpt_config.n_embd,
            ext_vocab_size=self.r2d2_config.ext_vocab_size,
            dense_hidden_factor=self.gpt_config.dense_hidden_factor
        )

    def get_input_embeddings(self):
        return self.model.embeddings

    def forward(self, **kwargs):
        return self.model(**kwargs)
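
# Illustrative end-to-end usage (hypothetical tensors): the wrapper simply forwards
# its keyword arguments to FastGenerativeR2D2.forward and returns an R2D2GenOutput.
#
#   model = GPST(config, gradient_checkpoint=True)
#   outputs = model(chunk_input_ids=chunk_ids, chunk_masks=chunk_masks,
#                   input_ids=input_ids, masks=masks, group_ids=group_ids)
#   outputs.loss.backward()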


class FastGenerativeR2D2(nn.Module):
    def __init__(self, r2d2, action_layers, generation_layers, vocab_size,
                 r2d2_input_dim, embedding_dim, dropout_rate=0.2, ext_vocab_size=0,
                 fix_embeddings=False, dense_hidden_factor=4):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.r2d2_input_dim = r2d2_input_dim
        self.r2d2 = r2d2
        self.vocab_size = vocab_size
        self.dense_hidden_factor = dense_hidden_factor

        self.enable_gpt = False
        if action_layers is not None and generation_layers is not None:
            self.action_layers = action_layers
            self.generation_layers = generation_layers
            self.bos_embedding = nn.Parameter(torch.rand(self.embedding_dim))
            self.up_scale = nn.Linear(self.r2d2_input_dim, self.embedding_dim)
            self.dense = nn.Sequential(
                nn.Linear(self.embedding_dim, self.dense_hidden_factor * self.embedding_dim),
                nn.GELU(),
                nn.Dropout(dropout_rate),
                nn.Linear(self.dense_hidden_factor * self.embedding_dim, self.embedding_dim))
            self.action_mlp = nn.Sequential(
                nn.LayerNorm(self.embedding_dim),
                nn.Linear(self.embedding_dim, self.embedding_dim),
                nn.GELU(),
                nn.Dropout(dropout_rate),
                nn.Linear(self.embedding_dim, 2))
            self.enable_gpt = True

        self.classifier = nn.Linear(self.embedding_dim, vocab_size, bias=False)
        self.embeddings = nn.Embedding(vocab_size, self.embedding_dim)
        self.embeddings.weight.requires_grad = not fix_embeddings
        self.down_scale = nn.Linear(self.embedding_dim, self.r2d2_input_dim)

        self.insideoutside_dense = nn.Sequential(
            nn.Linear(r2d2_input_dim, self.dense_hidden_factor * r2d2_input_dim),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(self.dense_hidden_factor * r2d2_input_dim, self.embedding_dim)
        )

        self._init_weights()
        self._tie_weights()

    def _init_weights(self):
        if self.enable_gpt:
            self.bos_embedding.data.normal_(mean=0, std=0.02)
            self.embeddings.weight.data.normal_(mean=0, std=0.02)

    def _tie_weights(self):
        # Share the output projection with the input embedding matrix.
        self.classifier.weight = self.embeddings.weight

    def get_parser(self):
        return self.r2d2.parser

    def from_pretrain(self, model_path, strict=True):
        load_model(self, model_path, strict=strict)
        # Re-establish weight tying after loading a checkpoint.
        self._tie_weights()
    def _append_eos_label(self, eos_labels, chunk_input_ids, chunk_masks, next_token_indices, max_input_len):
        # Append one column to the targets and scatter each sequence's EOS label at its
        # first padding position; keep one extra position of next-token indices.
        chunk_masks = (chunk_masks.sum(dim=1) > 0).to(int)
        seq_lens = chunk_masks.sum(dim=1)
        temp_ids = torch.zeros((chunk_input_ids.shape[0], chunk_input_ids.shape[1] + 1),
                               dtype=chunk_input_ids.dtype, device=chunk_input_ids.device)
        temp_ids.fill_(-100)
        temp_ids[:, :-1] = chunk_input_ids

        temp_ids.scatter_(1, seq_lens.unsqueeze(1),
                          torch.tensor(eos_labels, device=chunk_input_ids.device).unsqueeze(1))
        chunk_input_ids = temp_ids
        next_token_indices = next_token_indices[:, :max_input_len + 1]
        return next_token_indices, chunk_input_ids
    def forward(self, chunk_input_ids=None, chunk_masks=None, input_ids=None, masks=None, eos_labels=None,
                group_ids=None, atom_spans=None, span_ids=None, external_vocab_ids=None,
                coeff=1.0, temperature=1.0, past_key_values=None):
        batch_size = max(group_ids) + 1
        # Replace the -100 padding labels with a valid token id before embedding lookup.
        r2d2_input_ids = torch.where(chunk_input_ids == -100, 0, chunk_input_ids)
        input_embeddings = self.embeddings(r2d2_input_ids)
        r2d2_embeddings = self.down_scale(input_embeddings)

        max_input_len = (chunk_masks != 0).sum(dim=1).max().to('cpu', non_blocking=True)

        # Run the R2D2 inside-outside module over the chunked inputs.
        ctx, outside_tgt, ldr_repr, position_ids, tgt_ids, token_indices, ext_ids, split_targets, l_height = \
            self.r2d2(r2d2_input_ids, chunk_masks, input_ids, masks, r2d2_embeddings, group_ids,
                      max_input_len, atom_spans=atom_spans, coeff=coeff, temperature=temperature,
                      span_ids=span_ids, eos_labels=eos_labels, external_vocab_ids=external_vocab_ids)

        if self.training:
            # Auxiliary structural losses: parser loss and inside-outside token reconstruction.
            parser_loss = self.r2d2.parser_loss(ctx)
            outside_embeddings = self.r2d2.outside_embeddings(ctx)
            io_dense = self.insideoutside_dense(outside_embeddings)
            outside_logits = self.classifier(io_dense)
            insideoutside_loss = F.cross_entropy(outside_logits, outside_tgt)
        else:
            parser_loss = insideoutside_loss = 0

        logits = action_logits = None
        gpt_loss = action_loss = 0
        past_kv = None
        hidden_states = None

        if self.enable_gpt:
            if past_key_values is not None:
                action_past_kv, gen_past_kv = past_key_values
            else:
                action_past_kv = gen_past_kv = None
            # Project the r2d2 representations up to the GPT hidden size and overwrite
            # the token positions with the raw token embeddings.
            gpt_input = self.up_scale(ldr_repr).clone()
            gpt_input.scatter_(1, token_indices.unsqueeze(2).repeat(1, 1, input_embeddings.shape[-1]).clone(),
                               input_embeddings.to(gpt_input.dtype))

            bos_emb = self.bos_embedding.unsqueeze(0).repeat(batch_size, 1)
            cat_input = torch.cat([bos_emb.unsqueeze(1), gpt_input], dim=1)

            # The action layers run over [BOS; inputs]; a 2-way head scores REDUCE vs.
            # non-REDUCE actions.
            outputs = self.action_layers(inputs_embeds=cat_input, position_ids=position_ids,
                                         past_key_values=action_past_kv)
            action_logits = self.action_mlp(outputs.last_hidden_state)

            action_tgt = torch.where(tgt_ids == self.r2d2.reduce_id, 1, 0)
            action_tgt = torch.where(tgt_ids != -1, action_tgt, -1)

            # Gather hidden states at the non-REDUCE positions; these drive next-token prediction.
            next_token_indices = (tgt_ids != self.r2d2.reduce_id).int().argsort(dim=-1, descending=True, stable=True)
            if eos_labels is None:
                truncated_len = max_input_len
                next_token_indices = next_token_indices[:, :truncated_len]
            else:
                next_token_indices, chunk_input_ids = self._append_eos_label(
                    eos_labels, chunk_input_ids, chunk_masks, next_token_indices, max_input_len)

            next_token_indices_reformat = next_token_indices.unsqueeze(2).repeat(1, 1, self.embedding_dim)
            generation_inputs = outputs.last_hidden_state.gather(1, next_token_indices_reformat)

            token_outputs = self.generation_layers(inputs_embeds=generation_inputs, past_key_values=gen_past_kv)

            hidden_states = token_outputs.last_hidden_state
            logits = self.classifier(self.dense(hidden_states))

            gpt_loss = F.cross_entropy(logits.permute(0, 2, 1), chunk_input_ids, ignore_index=-100)
            action_loss = F.cross_entropy(action_logits.permute(0, 2, 1), action_tgt, ignore_index=-1)
            past_kv = (outputs.past_key_values, token_outputs.past_key_values)

        return R2D2GenOutput(struct_loss=insideoutside_loss + l_height,
                             non_struct_loss=0.5 * gpt_loss + action_loss + parser_loss,
                             non_struct_loss_fullscale=gpt_loss + action_loss + parser_loss,
                             logits=logits,
                             action_logits=action_logits,
                             hidden_states=hidden_states,
                             tgt_ids=chunk_input_ids,
                             gpt_loss=gpt_loss,
                             action_loss=action_loss,
                             inside_outside_loss=insideoutside_loss,
                             parser_loss=parser_loss,
                             past_kv=past_kv,
                             splits=split_targets,
                             loss=action_loss + gpt_loss + parser_loss + insideoutside_loss + l_height)
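
# Illustrative training step (hypothetical batch/optimizer): the aggregated `loss`
# field sums the action, token, parser, inside-outside and height terms, so a single
# backward call covers all of them.
#
#   optimizer.zero_grad()
#   out = model(chunk_input_ids=batch["chunk_input_ids"], chunk_masks=batch["chunk_masks"],
#               input_ids=batch["input_ids"], masks=batch["masks"],
#               group_ids=batch["group_ids"], eos_labels=batch.get("eos_labels"))
#   out.loss.backward()
#   optimizer.step()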


class FastGenerativeR2D2_discriminant_glue(FastGenerativeR2D2):

    def _append_eos_label(self, eos_labels, chunk_input_ids, chunk_masks, next_token_indices, max_input_len):
        # Discriminative GLUE variant: extend the targets by one column but leave the
        # appended position at -100 (ignored) instead of scattering an EOS label.
        chunk_masks = (chunk_masks.sum(dim=1) > 0).to(int)
        seq_lens = chunk_masks.sum(dim=1)
        temp_ids = torch.zeros((chunk_input_ids.shape[0], chunk_input_ids.shape[1] + 1),
                               dtype=chunk_input_ids.dtype, device=chunk_input_ids.device)
        temp_ids.fill_(-100)
        temp_ids[:, :-1] = chunk_input_ids

        chunk_input_ids = temp_ids
        next_token_indices = next_token_indices[:, :max_input_len + 1]
        return next_token_indices, chunk_input_ids