# coding=utf-8
# Copyright (c) 2024 Ant Group
# Author: Xiang Hu
import torch.nn as nn
import torch
import torch.nn.functional as F
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Optional
from transformers import PreTrainedModel, PretrainedConfig, GPT2Config
from transformers.modeling_outputs import MaskedLMOutput
from lmup.gpst.gpt2_flash_attn import GPT2Model
#from lmup.gpst.config import *
from lmup.gpst.r2d2_insideoutside import *
import copy
def load_model(model, model_path, strict=True):
    """Load a checkpoint into `model`, stripping any DataParallel/DDP 'module.' prefix."""
    state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
    transferred_state_dict = {}
    for k, v in state_dict.items():
        new_k = k.replace('module.', '')
        transferred_state_dict[new_k] = v
    model.load_state_dict(transferred_state_dict, strict=strict)
def index_sanity_hook(module, input, output):
    """Debug hook: flag integer tensors whose values look like out-of-range indices."""
    def check(tensor):
        if isinstance(tensor, torch.Tensor):
            if tensor.dtype == torch.long or "int" in str(tensor.dtype):
                if tensor.max() > 10000 or tensor.min() < -10000:  # adjust threshold if needed
                    print(f"[!] Suspicious index in {module.__class__.__name__}: min={tensor.min().item()}, max={tensor.max().item()}, shape={tensor.shape}")
# Check all inputs
for item in input:
if isinstance(item, (tuple, list)):
for sub in item:
check(sub)
else:
check(item)
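# Usage sketch for the debug hook above (mirrors the registration commented out in
# GPST.__init__ below):
#   for name, module in model.named_modules():
#       module.register_forward_hook(index_sanity_hook)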
@dataclass(kw_only=True)
class R2D2GenOutput:
    # Plain `= None` defaults: a trailing comma here would silently turn each
    # default into a 1-tuple at class-body level.
    struct_loss: Optional[torch.FloatTensor] = None
    non_struct_loss: Optional[torch.FloatTensor] = None
    non_struct_loss_fullscale: Optional[torch.FloatTensor] = None
    action_logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[torch.FloatTensor] = None
    cls_hidden_states: Optional[torch.FloatTensor] = None
    tgt_ids: Optional[torch.LongTensor] = None
    pred: Optional[torch.FloatTensor] = None
    splits: Optional[torch.LongTensor] = None
    gpt_loss: Optional[torch.FloatTensor] = None
    action_loss: Optional[torch.FloatTensor] = None
    inside_outside_loss: Optional[torch.FloatTensor] = None
    parser_loss: Optional[torch.FloatTensor] = None
    glue_finetune_loss: Optional[torch.FloatTensor] = None
    past_kv: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    loss: Optional[torch.FloatTensor] = None
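# The composite loss fields are populated by FastGenerativeR2D2.forward below:
# loss = action_loss + gpt_loss + parser_loss + inside_outside_loss + l_height,
# while struct_loss / non_struct_loss split the same terms so the structural
# (inside-outside + tree height) part can be weighted separately.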
class GPSTConfig(PretrainedConfig):
model_type = "gpst"
    def __init__(self, r2d2=None, gpt=None, **kwargs):
self.gptconfig = gpt
self.r2d2config = r2d2
super().__init__(**kwargs)
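# A minimal construction sketch (the extra GPT fields `action_layer_num` and
# `dense_hidden_factor` are assumptions read off GPST.__init__ below):
#   gpt = GPT2Config(n_layer=12, n_embd=768).to_dict()
#   gpt.update(action_layer_num=2, dense_hidden_factor=4)
#   r2d2 = {"ext_vocab_size": 0}  # plus whatever InsideOutsideModule expects
#   model = GPST(GPSTConfig(r2d2=r2d2, gpt=gpt))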
class GPST(PreTrainedModel):
config_class = GPSTConfig
def __init__(self, config, gradient_checkpoint=False):
super().__init__(config)
self.config = config
self.r2d2_config = PretrainedConfig.from_dict(config.r2d2config)
self.gpt_config = GPT2Config.from_dict(config.gptconfig)
self.vocab_size = self.gpt_config.vocab_size
total_layer = self.gpt_config.n_layer
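        # Split the GPT-2 stack in two: the first `action_layer_num` blocks form the
        # action transformer (shift/reduce decisions); the remaining
        # `n_layer - action_layer_num` blocks form the token-generation transformer.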
action_transformers = GPT2Model(copy.deepcopy(self.gpt_config), no_embedding=True, no_layer_norm=True, n_layers_manual=self.gpt_config.action_layer_num)
action_transformers.gradient_checkpointing = gradient_checkpoint
self.gpt_config.n_layer = total_layer - self.gpt_config.action_layer_num
self.gpt_config.num_hidden_layers = total_layer - self.gpt_config.action_layer_num
gpt_transformers = GPT2Model(self.gpt_config, no_embedding=True, no_extra_embedding=True)
gpt_transformers.gradient_checkpointing = gradient_checkpoint
r2d2 = InsideOutsideModule(self.r2d2_config)
self.model = FastGenerativeR2D2(
r2d2=r2d2,
action_layers=action_transformers,
generation_layers=gpt_transformers,
vocab_size=self.vocab_size,
r2d2_input_dim=r2d2.input_dim,
embedding_dim=self.gpt_config.n_embd,
ext_vocab_size=self.r2d2_config.ext_vocab_size,
dense_hidden_factor=self.gpt_config.dense_hidden_factor
)
#for name, module in self.model.named_modules():
# #module.register_forward_hook(index_sanity_hook)
# module.register_full_backward_hook(index_sanity_hook)
def get_input_embeddings(self):
return self.model.embeddings
def forward(self, **kwargs):
return self.model(**kwargs)
class FastGenerativeR2D2(nn.Module):
def __init__(self, r2d2, action_layers, generation_layers, vocab_size,
r2d2_input_dim, embedding_dim, dropout_rate=0.2, ext_vocab_size=0,
fix_embeddings=False, dense_hidden_factor=4):
        # embedding_dim is the GPT hidden size; r2d2_input_dim is the (smaller) width
        # fed to R2D2: inputs are down-scaled for R2D2 and up-scaled back for the GPT stacks
super().__init__()
self.embedding_dim = embedding_dim # embedding_dim > r2d2_input_dim
self.r2d2_input_dim = r2d2_input_dim
self.r2d2 = r2d2
self.vocab_size = vocab_size
# self.action_ln = nn.Linear(self.embedding_dim, 2) # judge reduce or predict next token
self.enable_gpt = False
if action_layers is not None and generation_layers is not None:
self.dense_hidden_factor = dense_hidden_factor
self.action_layers = action_layers
self.generation_layers = generation_layers
self.bos_embedding = nn.Parameter(torch.rand(self.embedding_dim))
self.up_scale = nn.Linear(self.r2d2_input_dim, self.embedding_dim)
self.dense = nn.Sequential(nn.Linear(self.embedding_dim, self.dense_hidden_factor * self.embedding_dim),
nn.GELU(),
nn.Dropout(dropout_rate),
nn.Linear(self.dense_hidden_factor * self.embedding_dim, self.embedding_dim))
self.action_mlp = nn.Sequential(nn.LayerNorm(self.embedding_dim),
nn.Linear(self.embedding_dim, self.embedding_dim),
nn.GELU(),
nn.Dropout(dropout_rate),
nn.Linear(self.embedding_dim, 2))
self.enable_gpt = True
self.classifier = nn.Linear(self.embedding_dim, vocab_size, bias=False)
self.embeddings = nn.Embedding(vocab_size, self.embedding_dim)
        self.embeddings.weight.requires_grad = not fix_embeddings  # freeze the embedding matrix if requested
self.down_scale = nn.Linear(self.embedding_dim, self.r2d2_input_dim)
self.insideoutside_dense = nn.Sequential(
nn.Linear(r2d2_input_dim, self.dense_hidden_factor * r2d2_input_dim),
nn.GELU(),
nn.Dropout(dropout_rate),
nn.Linear(self.dense_hidden_factor * r2d2_input_dim, self.embedding_dim)
)
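        # down_scale / up_scale bridge the two widths: token embeddings (embedding_dim)
        # are projected down before entering R2D2, and R2D2 representations are projected
        # back up for the GPT stacks; insideoutside_dense maps outside representations to
        # embedding_dim so the tied classifier can score the inside-outside targets.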
# self.parallel_stream = torch.cuda.Stream()
self._init_weights()
self._tie_weights()
def _init_weights(self):
if self.enable_gpt:
self.bos_embedding.data.normal_(mean=0, std=0.02)
self.embeddings.weight.data.normal_(mean=0, std=0.02)
def _tie_weights(self):
self.classifier.weight = self.embeddings.weight
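    # Standard tied input/output embeddings: the LM head (`classifier`) shares the
    # embedding matrix. from_pretrain re-ties after loading as a safeguard.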
def get_parser(self):
return self.r2d2.parser
def from_pretrain(self, model_path, strict=True):
load_model(self, model_path, strict=strict)
self._tie_weights()
def _append_eos_label(self, eos_labels, chunk_input_ids, chunk_masks, next_token_indices, max_input_len):
chunk_masks = (chunk_masks.sum(dim=1) > 0).to(int)
seq_lens = chunk_masks.sum(dim=1) # (N)
temp_ids = torch.zeros((chunk_input_ids.shape[0], chunk_input_ids.shape[1] + 1), dtype=chunk_input_ids.dtype, device=chunk_input_ids.device)
temp_ids.fill_(-100)
temp_ids[:, :-1] = chunk_input_ids
        # Comment out the next line to get the discriminative variant
        # (FastGenerativeR2D2_discriminant_glue below overrides this method to do so).
temp_ids.scatter_(1, seq_lens.unsqueeze(1), torch.tensor(eos_labels, device=chunk_input_ids.device).unsqueeze(1))
chunk_input_ids = temp_ids
next_token_indices = next_token_indices[:, :max_input_len + 1]
return next_token_indices, chunk_input_ids
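    # Worked example for _append_eos_label (hypothetical values): a chunk_input_ids row
    # [5, 8, -100] with seq_len 2 and eos label 2 is first padded to [5, 8, -100, -100],
    # then scatter_ writes the eos at position seq_len, giving targets [5, 8, 2, -100].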
    def forward(self, chunk_input_ids=None, chunk_masks=None, input_ids=None, masks=None, eos_labels=None, group_ids=None,
atom_spans=None, span_ids=None, external_vocab_ids=None,
coeff=1.0, temperature=1.0, past_key_values=None):
batch_size = max(group_ids) + 1
r2d2_input_ids = torch.where(chunk_input_ids == -100, 0, chunk_input_ids)
input_embeddings = self.embeddings(r2d2_input_ids)
r2d2_embeddings = self.down_scale(input_embeddings)
# max_input_len = chunk_input_ids.shape[1]
max_input_len = (chunk_masks != 0).sum(dim=1).max().to('cpu', non_blocking=True)
ctx, outside_tgt, ldr_repr, position_ids, tgt_ids, token_indices, ext_ids, split_targets, l_height = \
self.r2d2(r2d2_input_ids, chunk_masks, input_ids, masks, r2d2_embeddings, group_ids,
max_input_len, atom_spans=atom_spans, coeff=coeff, temperature=temperature, span_ids=span_ids,
eos_labels=eos_labels, external_vocab_ids=external_vocab_ids)
if self.training:
# with torch.cuda.stream(self.parallel_stream):
parser_loss = self.r2d2.parser_loss(ctx)
outside_embeddings = self.r2d2.outside_embeddings(ctx) # (num non-padding tokens in batch) x embedding_dim
io_dense = self.insideoutside_dense(outside_embeddings) # (num non-padding tokens in batch) x embedding_dim
outside_logits = self.classifier(io_dense) # (num non-padding tokens in batch) x voc_size
insideoutside_loss = F.cross_entropy(outside_logits, outside_tgt)
else:
parser_loss = insideoutside_loss = 0
logits = action_logits = None
gpt_loss = action_loss = 0
past_kv = None
hidden_states = None
if self.enable_gpt:
if past_key_values is not None:
action_past_kv, gen_past_kv = past_key_values
else:
action_past_kv = gen_past_kv = None
gpt_input = self.up_scale(ldr_repr).clone() # ldr_repr: batch_size x (2*max_seq_len - 1) x HP dim; gpt_input: batch_size x (2*max_seq_len - 1) x emb_dim
gpt_input.scatter_(1, token_indices.unsqueeze(2).repeat(1, 1, input_embeddings.shape[-1]).clone(),
input_embeddings.to(gpt_input.dtype)) # inserting values of input_embeddings at token_indices
# ext_embedding = self.ext_embeds(ext_ids)
# gpt_input = gpt_input + ext_embedding
bos_emb = self.bos_embedding.unsqueeze(0).repeat(batch_size, 1)
# position ids already considered <bos>
cat_input = torch.cat([bos_emb.unsqueeze(1), gpt_input], dim=1) # batch_size x 2*max_seq_len x emb_dim # old comment: (group_size, L + 1, dim) where L is (2*max_seq_len - 1)
# cat_input = self.layer_norm(cat_input)
# cat_input = self.norm(cat_input)
outputs = self.action_layers(inputs_embeds=cat_input, position_ids=position_ids, past_key_values=action_past_kv) # (B, L, dim)
action_logits = self.action_mlp(outputs.last_hidden_state) # (batch_size x 2*max_seq_len x emb_dim) to (B, 2*max_seq_len, 2)
            # Before these lines, tgt_ids has shape (batch_size x 2*max_seq_len) and contains
            # reduce_token_id (default: 50257) at REDUCE positions; padding is -1.
            # First where: action_tgt becomes 0/1 for generate/reduce (padding still 0).
            # Second where: padding positions are reset to -1 (ignored by the action loss).
action_tgt = torch.where(tgt_ids == self.r2d2.reduce_id, 1, 0) # REDUCE: 1, SHIFT:0
action_tgt = torch.where(tgt_ids != -1, action_tgt, -1)
# print(action_tgt)
next_token_indices = (tgt_ids != self.r2d2.reduce_id).int().argsort(dim=-1, descending=True, stable=True) # (B, L)
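            # The stable descending argsort over the boolean mask lists the indices of all
            # non-REDUCE (token) positions first, in original order. E.g. (hypothetical row)
            # tgt_ids [t0, t1, R, t2, R] -> mask [1, 1, 0, 1, 0] -> indices [0, 1, 3, 2, 4].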
if eos_labels is None:
#truncated_len = max_input_len if self.training else max_input_len + 1
truncated_len = max_input_len
next_token_indices = next_token_indices[:, :truncated_len] # batch_size x longest_sequence;
else:
next_token_indices, chunk_input_ids = self._append_eos_label(eos_labels, chunk_input_ids, chunk_masks, next_token_indices, max_input_len)
# outputs.last_hidden_state: batch_size x 2*max_seq_len x HP dim
# generation_inputs: batch_size x max_seq_len x HP dim
# next_token_indices_reformat: batch_size x max_seq_len x HP dim
next_token_indices_reformat = next_token_indices.unsqueeze(2).repeat(1, 1, self.embedding_dim)
generation_inputs = outputs.last_hidden_state.gather(1, next_token_indices_reformat)
# token_pos_ids = position_ids.gather(1, next_token_indices)
# gather outputs to predict the next token
# token_outputs: LM output. Last hidden_state has batch_size x max_seq_len x HP dim
token_outputs = self.generation_layers(inputs_embeds=generation_inputs, past_key_values=gen_past_kv)
hidden_states = token_outputs.last_hidden_state
logits = self.classifier(self.dense(hidden_states)) # new: batch_size x max_seq_len x voc_size old: (group_size, L + 1, vocab)
            # Token prediction loss + shift/reduce action loss.
gpt_loss = F.cross_entropy(logits.permute(0, 2, 1), chunk_input_ids, ignore_index=-100)
action_loss = F.cross_entropy(action_logits.permute(0, 2, 1), action_tgt, ignore_index=-1)
past_kv = (outputs.past_key_values, token_outputs.past_key_values)
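            # Pair the caches of the two stacks for incremental decoding; this matches the
            # unpacking of `past_key_values` at the top of this branch.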
# torch.cuda.synchronize()
# return loss + lm_loss + parser_loss, split_targets
return R2D2GenOutput(struct_loss=insideoutside_loss + l_height,
non_struct_loss=0.5 * gpt_loss + action_loss + parser_loss,
non_struct_loss_fullscale=gpt_loss + action_loss + parser_loss,
logits=logits,
action_logits=action_logits,
hidden_states=hidden_states,
tgt_ids=chunk_input_ids,
gpt_loss=gpt_loss,
action_loss=action_loss,
inside_outside_loss=insideoutside_loss,
parser_loss=parser_loss,
past_kv=past_kv,
splits=split_targets,
loss=action_loss+gpt_loss+parser_loss+insideoutside_loss+l_height)
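# A minimal training-step sketch (argument tensors are assumptions; see the forward
# signature above for the full set of optional inputs):
#   out = model(chunk_input_ids=ids, chunk_masks=chunk_masks, input_ids=flat_ids,
#               masks=flat_masks, group_ids=group_ids)
#   (out.struct_loss + out.non_struct_loss).backward()  # or simply out.loss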
class FastGenerativeR2D2_discriminant_glue(FastGenerativeR2D2):
def _append_eos_label(self, eos_labels, chunk_input_ids, chunk_masks, next_token_indices, max_input_len):
chunk_masks = (chunk_masks.sum(dim=1) > 0).to(int)
seq_lens = chunk_masks.sum(dim=1) # (N)
temp_ids = torch.zeros((chunk_input_ids.shape[0], chunk_input_ids.shape[1] + 1), dtype=chunk_input_ids.dtype, device=chunk_input_ids.device)
temp_ids.fill_(-100)
temp_ids[:, :-1] = chunk_input_ids
# temp_ids.scatter_(1, seq_lens.unsqueeze(1), torch.tensor(eos_labels, device=chunk_input_ids.device).unsqueeze(1))
chunk_input_ids = temp_ids
next_token_indices = next_token_indices[:, :max_input_len + 1]
        return next_token_indices, chunk_input_ids
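# The GLUE variant differs from the base class only in that it never writes the eos
# label into the padded targets (the scatter_ stays commented out), so the extra slot
# keeps -100 and is ignored by the cross-entropy loss in the discriminative setting.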