"""Custom models for few-shot learning specific operations."""
import torch
import torch.nn as nn
import transformers
import torch.nn.functional as F
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction
from transformers.models.t5.modeling_t5 import T5ForConditionalGeneration
from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertForSequenceClassification, BertModel, BertOnlyMLMHead
from transformers.models.roberta.modeling_roberta import RobertaForSequenceClassification, RobertaModel, RobertaLMHead, RobertaClassificationHead, RobertaPreTrainedModel
from transformers.models.deberta_v2.modeling_deberta_v2 import DebertaV2PreTrainedModel, DebertaV2Model, StableDropout, ContextPooler, DebertaV2OnlyMLMHead
from transformers.models.deberta.modeling_deberta import DebertaPreTrainedModel, DebertaModel, StableDropout, ContextPooler, DebertaOnlyMLMHead
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.models.bert.configuration_bert import BertConfig
import logging
from models.basic_modules.adapter import RobertaAdaModel, BertAdaModel
import os
from models.basic_modules.prefix_encoder import PrefixEncoder
from tools.model_utils.parameter_freeze import ParameterFreeze
freezer = ParameterFreeze()
logger = logging.getLogger(__name__)
# Note: if mask_pos is None, check whether the input template actually contains a <mask> token and whether the data_collator file has been adapted accordingly.
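# ---------------------------------------------------------------------------
# Illustrative usage sketch (comment only, nothing here runs at import time).
# The extra config fields are the ones read by the classes in this file; the
# checkpoint name and the verbalizer token ids below are arbitrary placeholders.
#
#   from transformers import AutoConfig
#   config = AutoConfig.from_pretrained("bert-base-uncased")
#   config.pre_seq_len = 8                 # length of the continuous prompt / prefix
#   config.label_word_list = [2204, 4997]  # placeholder verbalizer token ids, one per label
#   config.use_freezing = True             # freeze the backbone, train prompt + MLM head
#   model = PromptBertPtuningForSequenceClassification.from_pretrained(
#       "bert-base-uncased", config=config)
#   # mask_pos (the [MASK] position in each example) is supplied by the data collator.
# ---------------------------------------------------------------------------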
"""
Vanilla Prompt-tuning BERT
"""
class PromptBertForSequenceClassification(BertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.pre_seq_len = self.config.pre_seq_len
self.hidden_size = self.config.hidden_size
# backbone
self.bert = BertModel(config)
if self.config.use_freezing:
self.bert = freezer.freeze_lm(self.bert)
# mlm head
self.cls = BertOnlyMLMHead(config)
self.init_weights()
# These attributes should be assigned once the model is initialized
self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.bert.device)
# For regression
self.lb = None
self.ub = None
# For label search.
self.return_full_softmax = None
def freeze_backbone(self, use_freezing: bool=True):
if use_freezing:
self.bert = freezer.freeze_lm(self.bert)
else:
self.bert = freezer.unfreeze_lm(self.bert)
def encode(self, input_ids=None, attention_mask=None, token_type_ids=None, mask_pos=None, inputs_embeds=None, return_full_softmax=False):
"""
Encode the inputs and obtain logits at the masked position.
"""
if mask_pos is not None:
mask_pos = mask_pos.squeeze()
# Encode everything
if inputs_embeds is None:
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids
)
else:
outputs = self.bert(
None,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds
)
# Get <mask> token representation
sequence_output, pooled_output = outputs[:2]
sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos]
# Logits over vocabulary tokens
prediction_mask_scores = self.cls(sequence_mask_output)
# Exit early and only return mask logits.
if return_full_softmax:
return prediction_mask_scores
# Return logits for each label
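        # Verbalizer step: for every example, gather the vocabulary logit of each
        # label word; the concatenation has shape [batch_size, num_labels] and is
        # used directly as the classification logits.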
logits = []
for label_id in range(len(self.label_word_list)):
logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1))
logits = torch.cat(logits, -1)
# Regression task
if self.config.num_labels == 1:
logsoftmax = nn.LogSoftmax(-1)
logits = logsoftmax(logits) # Log prob of right polarity
return logits, sequence_mask_output
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
mask_pos=None,
labels=None,
inputs_embeds=None,
block_flag=None,
return_dict=None,
):
logits, sequence_mask_output = self.encode(input_ids, attention_mask, token_type_ids, mask_pos, inputs_embeds)
loss = None
if labels is not None:
if self.num_labels == 1:
# Regression task
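                # The scalar target is rescaled to [0, 1] using the bounds (self.lb, self.ub)
                # and expressed as a soft two-class distribution, so the two label-word
                # logits can be trained with a KL objective.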
loss_fct = nn.KLDivLoss(log_target=True)
labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1)
loss = loss_fct(logits.view(-1, 2), labels)
else:
if labels.shape == logits.shape:
loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32),
labels, reduction="batchmean")
else:
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
output = (logits,)
if self.num_labels == 1:
# Regression output
output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,)
if not return_dict:
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss,
logits=logits,
)
"""
P-tuning BERT
"""
class PromptBertPtuningForSequenceClassification(BertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.pre_seq_len = self.config.pre_seq_len
self.hidden_size = self.config.hidden_size
# backbone
self.bert = BertModel(config)
if self.config.use_freezing:
self.bert = freezer.freeze_lm(self.bert)
# mlm head
self.cls = BertOnlyMLMHead(config)
# prompt encoder
self.prompt_encoder = None
# plm embedding layer
self.backbone_embeddings = self.bert.embeddings.word_embeddings
# prompt embedding layer
self.prompt_embeddings = torch.nn.Embedding(self.pre_seq_len, self.hidden_size)
self.init_weights()
# These attributes should be assigned once the model is initialized
self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.bert.device)
# For regression
self.lb = None
self.ub = None
# For label search.
self.return_full_softmax = None
def freeze_backbone(self, use_freezing: bool=True):
if use_freezing:
self.bert = freezer.freeze_lm(self.bert)
else:
self.bert = freezer.unfreeze_lm(self.bert)
def generate_continuous_prompt_inputs(self, input_ids, block_flag=None, reparameterization=False):
"""
Generate continuous prompt embedding
"""
inputs_embeds = self.backbone_embeddings(input_ids)
batch_size = inputs_embeds.shape[0]
if block_flag is None:
            # by default, only the first token is flagged (set to 1); all others are 0
block_flag = torch.zeros_like(input_ids).long().to(inputs_embeds.device)
block_flag[:, 0] = 1
        replace_embeds = self.prompt_embeddings(
            torch.arange(self.pre_seq_len, device=inputs_embeds.device))
        replace_embeds = replace_embeds.unsqueeze(0)  # [1, pre_seq_len, hidden_size]
if self.prompt_encoder is not None:
replace_embeds = self.prompt_encoder(replace_embeds)
# edit by wjn
        if reparameterization:
            # Overwrite the flagged positions with the prompt embeddings; block_flag
            # is expected to mark exactly pre_seq_len positions in every example.
            blocked_indices = (block_flag == 1).nonzero(as_tuple=False).reshape(
                (batch_size, self.pre_seq_len, 2))[:, :, 1]
            for bidx in range(batch_size):
                for i in range(self.pre_seq_len):
                    inputs_embeds[bidx, blocked_indices[bidx, i], :] = replace_embeds[:, i, :].squeeze()
else:
replace_embeds = replace_embeds.expand(batch_size, self.pre_seq_len, -1).to(inputs_embeds.device)
inputs_embeds = torch.cat((replace_embeds, inputs_embeds), dim=1)
return inputs_embeds
def encode(self, input_ids=None, attention_mask=None, token_type_ids=None, mask_pos=None, inputs_embeds=None, return_full_softmax=False):
"""
Encode the inputs and obtain logits at the masked position.
"""
        batch_size = inputs_embeds.shape[0] if inputs_embeds is not None else input_ids.shape[0]
if mask_pos is not None:
mask_pos = mask_pos.squeeze()
# Encode everything
if inputs_embeds is None:
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids
)
else:
if inputs_embeds.shape[1] == attention_mask.shape[1]:
outputs = self.bert(
None,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds
)
# Get <mask> token representation
sequence_output, pooled_output = outputs[:2]
# sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous()
else:
if attention_mask is not None:
prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).long().to(self.bert.device)
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
if token_type_ids is not None:
prefix_token_type_ids = torch.zeros(batch_size, self.pre_seq_len).long().to(self.bert.device)
token_type_ids = torch.cat((prefix_token_type_ids, token_type_ids), dim=1)
outputs = self.bert(
None,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds
)
# Get <mask> token representation
sequence_output, pooled_output = outputs[:2]
sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous()
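                # Drop the pre_seq_len prompt positions so that mask_pos still indexes
                # into the original (un-prepended) sequence.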
sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos]
# Logits over vocabulary tokens
prediction_mask_scores = self.cls(sequence_mask_output)
# Exit early and only return mask logits.
if return_full_softmax:
return prediction_mask_scores
# Return logits for each label
logits = []
for label_id in range(len(self.label_word_list)):
logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1))
logits = torch.cat(logits, -1)
# Regression task
if self.config.num_labels == 1:
logsoftmax = nn.LogSoftmax(-1)
logits = logsoftmax(logits) # Log prob of right polarity
return logits, sequence_mask_output
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
mask_pos=None,
labels=None,
inputs_embeds=None,
block_flag=None,
return_dict=None,
):
inputs_embeds = self.generate_continuous_prompt_inputs(input_ids, block_flag)
logits, sequence_mask_output = self.encode(input_ids, attention_mask, token_type_ids, mask_pos, inputs_embeds)
loss = None
if labels is not None:
if self.num_labels == 1:
# Regression task
loss_fct = nn.KLDivLoss(log_target=True)
labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1)
loss = loss_fct(logits.view(-1, 2), labels)
else:
if labels.shape == logits.shape:
loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32),
labels, reduction="batchmean")
else:
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
output = (logits,)
if self.num_labels == 1:
# Regression output
output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,)
if not return_dict:
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss,
logits=logits,
)
"""
Prefix-tuning BERT
"""
class PromptBertPrefixForSequenceClassification(BertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.pre_seq_len = self.config.pre_seq_len
self.hidden_size = self.config.hidden_size
self.n_layer = config.num_hidden_layers
self.n_head = config.num_attention_heads
self.n_embd = config.hidden_size // config.num_attention_heads
# backbone
self.bert = BertModel(config)
if self.config.use_freezing:
self.bert = freezer.freeze_lm(self.bert)
# mlm head
self.cls = BertOnlyMLMHead(config)
self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
# plm embedding layer
self.backbone_embeddings = self.bert.embeddings.word_embeddings
# prompt embedding layer
self.prompt_embeddings = torch.nn.Embedding(self.pre_seq_len, self.hidden_size)
# prefix encoder
self.prefix_tokens = torch.arange(self.pre_seq_len).long()
self.prefix_encoder = PrefixEncoder(config)
self.init_weights()
# These attributes should be assigned once the model is initialized
self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.bert.device)
# For regression
self.lb = None
self.ub = None
# For label search.
self.return_full_softmax = None
def freeze_backbone(self, use_freezing: bool=True):
if use_freezing:
self.bert = freezer.freeze_lm(self.bert)
else:
self.bert = freezer.unfreeze_lm(self.bert)
def get_prompt(self, batch_size):
prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device)
past_key_values = self.prefix_encoder(prefix_tokens)
# bsz, seqlen, _ = past_key_values.shape
past_key_values = past_key_values.view(
batch_size,
self.pre_seq_len,
self.n_layer * 2,
self.n_head,
self.n_embd
)
past_key_values = self.dropout(past_key_values)
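        # Rearrange to a tuple of n_layer tensors, each shaped
        # [2, batch_size, n_head, pre_seq_len, n_embd], i.e. the past_key_values
        # format consumed by the backbone's self-attention layers.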
past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
return past_key_values
def embed_encode(self, input_ids):
embedding_output = self.bert.embeddings.word_embeddings(input_ids)
return embedding_output
def encode(self, input_ids=None, attention_mask=None, token_type_ids=None, mask_pos=None, inputs_embeds=None, return_full_softmax=False):
batch_size = input_ids.size(0)
# add prefix for prompt-tuning
past_key_values = self.get_prompt(batch_size=batch_size)
prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.bert.device)
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
if mask_pos is not None:
mask_pos = mask_pos.squeeze()
# Encode everything
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
past_key_values=past_key_values,
)
# Get <mask> token representation
sequence_output, pooled_output = outputs[:2]
# sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous()
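        # No trimming is needed here: the prefix only extends the attention key/values,
        # it does not add positions to sequence_output, so mask_pos indexes the
        # original sequence directly.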
sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos]
# Logits over vocabulary tokens
prediction_mask_scores = self.cls(sequence_mask_output)
# Exit early and only return mask logits.
if return_full_softmax:
return prediction_mask_scores
        # prediction_mask_scores: [batch_size, vocab_size]
# Return logits for each label
logits = []
for label_id in range(len(self.label_word_list)):
logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1))
logits = torch.cat(logits, -1)
# Regression task
if self.config.num_labels == 1:
logsoftmax = nn.LogSoftmax(-1)
logits = logsoftmax(logits) # Log prob of right polarity
return logits, sequence_mask_output
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
mask_pos=None,
labels=None,
inputs_embeds=None,
block_flag=None,
return_dict=None,
):
logits, sequence_mask_output = self.encode(input_ids, attention_mask, token_type_ids, mask_pos, inputs_embeds)
loss = None
if labels is not None:
if self.num_labels == 1:
# Regression task
loss_fct = nn.KLDivLoss(log_target=True)
labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1)
loss = loss_fct(logits.view(-1, 2), labels)
else:
if labels.shape == logits.shape:
loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32),
labels, reduction="batchmean")
else:
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
output = (logits,)
if self.num_labels == 1:
# Regression output
output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,)
if not return_dict:
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss,
logits=logits,
)
"""
Adapter-tuning BERT
"""
class PromptBertAdapterForSequenceClassification(BertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.bert = BertAdaModel(config)
self.cls = BertOnlyMLMHead(config)
self.init_weights()
if self.config.use_freezing:
self.bert = freezer.freeze_lm_component(self.bert, "adapter")
# These attributes should be assigned once the model is initialized
self.model_args = None
self.data_args = None
self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.bert.device)
# For regression
self.lb = None
self.ub = None
# For label search.
self.return_full_softmax = None
def freeze_backbone(self, use_freezing: bool=True):
if use_freezing:
self.bert = freezer.freeze_lm_component(self.bert, "adapter")
else:
self.bert = freezer.unfreeze_lm(self.bert)
def embed_encode(self, input_ids):
embedding_output = self.bert.embeddings.word_embeddings(input_ids)
return embedding_output
def encode(self, input_ids=None, attention_mask=None, token_type_ids=None, mask_pos=None, inputs_embeds=None, return_full_softmax=False):
batch_size = input_ids.size(0)
if mask_pos is not None:
mask_pos = mask_pos.squeeze()
# Encode everything
if inputs_embeds is None:
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids
)
else:
outputs = self.bert(
None,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds
)
# Get <mask> token representation
sequence_output, pooled_output = outputs[:2]
sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos]
# Logits over vocabulary tokens
prediction_mask_scores = self.cls(sequence_mask_output)
# Exit early and only return mask logits.
if return_full_softmax:
return prediction_mask_scores
# Return logits for each label
logits = []
for label_id in range(len(self.label_word_list)):
logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1))
logits = torch.cat(logits, -1)
# Regression task
if self.config.num_labels == 1:
logsoftmax = nn.LogSoftmax(-1)
logits = logsoftmax(logits) # Log prob of right polarity
return logits, sequence_mask_output
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
mask_pos=None,
labels=None,
inputs_embeds=None,
block_flag=None,
return_dict=None,
):
logits, sequence_mask_output = self.encode(input_ids, attention_mask, token_type_ids, mask_pos, inputs_embeds)
loss = None
if labels is not None:
if self.num_labels == 1:
# Regression task
loss_fct = nn.KLDivLoss(log_target=True)
labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1)
loss = loss_fct(logits.view(-1, 2), labels)
else:
if labels.shape == logits.shape:
loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32),
labels, reduction="batchmean")
else:
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
output = (logits,)
if self.num_labels == 1:
# Regression output
output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,)
if not return_dict:
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss,
logits=logits,
)
"""
Vanilla Prompt-tuning RoBERTa
"""
class PromptRobertaForSequenceClassification(RobertaPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.pre_seq_len = self.config.pre_seq_len
self.hidden_size = self.config.hidden_size
# backbone
self.roberta = RobertaModel(config)
if self.config.use_freezing:
self.roberta = freezer.freeze_lm(self.roberta)
# mlm head
self.cls = RobertaLMHead(config)
self.init_weights()
# These attributes should be assigned once the model is initialized
self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.roberta.device)
# For regression
self.lb = None
self.ub = None
# For label search.
self.return_full_softmax = None
def freeze_backbone(self, use_freezing: bool=True):
if use_freezing:
self.roberta = freezer.freeze_lm(self.roberta)
else:
self.roberta = freezer.unfreeze_lm(self.roberta)
def encode(self, input_ids=None, attention_mask=None, token_type_ids=None, mask_pos=None, inputs_embeds=None, return_full_softmax=False):
"""
Encode the inputs and obtain logits at the masked position.
"""
if mask_pos is not None:
mask_pos = mask_pos.squeeze()
# Encode everything
if inputs_embeds is None:
outputs = self.roberta(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids
)
else:
outputs = self.roberta(
None,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds
)
# Get <mask> token representation
sequence_output, pooled_output = outputs[:2]
sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos]
# Logits over vocabulary tokens
prediction_mask_scores = self.cls(sequence_mask_output)
# Exit early and only return mask logits.
if return_full_softmax:
return prediction_mask_scores
# Return logits for each label
logits = []
for label_id in range(len(self.label_word_list)):
logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1))
logits = torch.cat(logits, -1)
# Regression task
if self.config.num_labels == 1:
logsoftmax = nn.LogSoftmax(-1)
logits = logsoftmax(logits) # Log prob of right polarity
return logits, sequence_mask_output
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
mask_pos=None,
labels=None,
inputs_embeds=None,
block_flag=None,
return_dict=None,
):
logits, sequence_mask_output = self.encode(input_ids, attention_mask, token_type_ids, mask_pos, inputs_embeds)
loss = None
if labels is not None:
if self.num_labels == 1:
# Regression task
loss_fct = nn.KLDivLoss(log_target=True)
labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1)
loss = loss_fct(logits.view(-1, 2), labels)
else:
if labels.shape == logits.shape:
loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32),
labels, reduction="batchmean")
else:
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
output = (logits,)
if self.num_labels == 1:
# Regression output
output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,)
if not return_dict:
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss,
logits=logits,
)
"""
P-tuning RoBERTa
"""
class PromptRobertaPtuningForSequenceClassification(RobertaPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.pre_seq_len = self.config.pre_seq_len
self.hidden_size = self.config.hidden_size
# backbone
self.roberta = RobertaModel(config)
if self.config.use_freezing:
self.roberta = freezer.freeze_lm(self.roberta)
# mlm head
self.cls = RobertaLMHead(config)
# prompt encoder
self.prompt_encoder = None
# plm embedding layer
self.backbone_embeddings = self.roberta.embeddings.word_embeddings
# prompt embedding layer
self.prompt_embeddings = torch.nn.Embedding(self.pre_seq_len, self.hidden_size)
self.init_weights()
# These attributes should be assigned once the model is initialized
self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.roberta.device)
# For regression
self.lb = None
self.ub = None
# For label search.
self.return_full_softmax = None
def freeze_backbone(self, use_freezing: bool=True):
if use_freezing:
self.roberta = freezer.freeze_lm(self.roberta)
else:
self.roberta = freezer.unfreeze_lm(self.roberta)
def generate_continuous_prompt_inputs(self, input_ids, block_flag=None, reparameterization=False):
"""
Generate continuous prompt embedding
"""
inputs_embeds = self.backbone_embeddings(input_ids)
batch_size = inputs_embeds.shape[0]
if block_flag is None:
            # by default, only the first token is flagged (set to 1); all others are 0
block_flag = torch.zeros_like(input_ids).long().to(inputs_embeds.device)
block_flag[:, 0] = 1
        replace_embeds = self.prompt_embeddings(
            torch.arange(self.pre_seq_len, device=inputs_embeds.device))
        replace_embeds = replace_embeds.unsqueeze(0)  # [1, pre_seq_len, hidden_size]
if self.prompt_encoder is not None:
replace_embeds = self.prompt_encoder(replace_embeds)
# edit by wjn
        if reparameterization:
            # Overwrite the flagged positions with the prompt embeddings; block_flag
            # is expected to mark exactly pre_seq_len positions in every example.
            blocked_indices = (block_flag == 1).nonzero(as_tuple=False).reshape(
                (batch_size, self.pre_seq_len, 2))[:, :, 1]
            for bidx in range(batch_size):
                for i in range(self.pre_seq_len):
                    inputs_embeds[bidx, blocked_indices[bidx, i], :] = replace_embeds[:, i, :].squeeze()
else:
replace_embeds = replace_embeds.expand(batch_size, self.pre_seq_len, -1).to(inputs_embeds.device)
inputs_embeds = torch.cat((replace_embeds, inputs_embeds), dim=1)
return inputs_embeds
def encode(self, input_ids=None, attention_mask=None, token_type_ids=None, mask_pos=None, inputs_embeds=None, return_full_softmax=False):
"""
Encode the inputs and obtain logits at the masked position.
"""
        batch_size = inputs_embeds.shape[0] if inputs_embeds is not None else input_ids.shape[0]
if mask_pos is not None:
mask_pos = mask_pos.squeeze()
# Encode everything
if inputs_embeds is None:
outputs = self.roberta(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids
)
else:
if inputs_embeds.shape[1] == attention_mask.shape[1]:
outputs = self.roberta(
None,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds
)
# Get <mask> token representation
sequence_output, pooled_output = outputs[:2]
# sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous()
else:
if attention_mask is not None:
prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).long().to(self.roberta.device)
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
if token_type_ids is not None:
prefix_token_type_ids = torch.zeros(batch_size, self.pre_seq_len).long().to(self.roberta.device)
token_type_ids = torch.cat((prefix_token_type_ids, token_type_ids), dim=1)
outputs = self.roberta(
None,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds
)
# Get <mask> token representation
sequence_output, pooled_output = outputs[:2]
sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous()
sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos]
# Logits over vocabulary tokens
prediction_mask_scores = self.cls(sequence_mask_output)
# Exit early and only return mask logits.
if return_full_softmax:
return prediction_mask_scores
# Return logits for each label
logits = []
for label_id in range(len(self.label_word_list)):
logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1))
logits = torch.cat(logits, -1)
# Regression task
if self.config.num_labels == 1:
logsoftmax = nn.LogSoftmax(-1)
logits = logsoftmax(logits) # Log prob of right polarity
return logits, sequence_mask_output
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
mask_pos=None,
labels=None,
inputs_embeds=None,
block_flag=None,
return_dict=None,
):
inputs_embeds = self.generate_continuous_prompt_inputs(input_ids, block_flag)
logits, sequence_mask_output = self.encode(input_ids, attention_mask, token_type_ids, mask_pos, inputs_embeds)
loss = None
if labels is not None:
if self.num_labels == 1:
# Regression task
loss_fct = nn.KLDivLoss(log_target=True)
labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1)
loss = loss_fct(logits.view(-1, 2), labels)
else:
if labels.shape == logits.shape:
loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32),
labels, reduction="batchmean")
else:
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
output = (logits,)
if self.num_labels == 1:
# Regression output
output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,)
if not return_dict:
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss,
logits=logits,
)
"""
Prefix-tuning RoBERTa
"""
class PromptRobertaPrefixForSequenceClassification(RobertaPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.pre_seq_len = self.config.pre_seq_len
self.hidden_size = self.config.hidden_size
self.n_layer = config.num_hidden_layers
self.n_head = config.num_attention_heads
self.n_embd = config.hidden_size // config.num_attention_heads
        # backbone (attribute named `roberta` so that pretrained weights load into it)
        self.roberta = RobertaModel(config)
        if self.config.use_freezing:
            self.roberta = freezer.freeze_lm(self.roberta)
# mlm head
self.cls = RobertaLMHead(config)
self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
# plm embedding layer
        self.backbone_embeddings = self.roberta.embeddings.word_embeddings
# prompt embedding layer
self.prompt_embeddings = torch.nn.Embedding(self.pre_seq_len, self.hidden_size)
# prefix encoder
self.prefix_tokens = torch.arange(self.pre_seq_len).long()
self.prefix_encoder = PrefixEncoder(config)
self.init_weights()
# These attributes should be assigned once the model is initialized
        self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.roberta.device)
# For regression
self.lb = None
self.ub = None
# For label search.
self.return_full_softmax = None
def freeze_backbone(self, use_freezing: bool=True):
        if use_freezing:
            self.roberta = freezer.freeze_lm(self.roberta)
        else:
            self.roberta = freezer.unfreeze_lm(self.roberta)
def get_prompt(self, batch_size):
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.roberta.device)
past_key_values = self.prefix_encoder(prefix_tokens)
# bsz, seqlen, _ = past_key_values.shape
past_key_values = past_key_values.view(
batch_size,
self.pre_seq_len,
self.n_layer * 2,
self.n_head,
self.n_embd
)
past_key_values = self.dropout(past_key_values)
past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
return past_key_values
def embed_encode(self, input_ids):
        embedding_output = self.roberta.embeddings.word_embeddings(input_ids)
return embedding_output
def encode(self, input_ids=None, attention_mask=None, token_type_ids=None, mask_pos=None, inputs_embeds=None, return_full_softmax=False):
batch_size = input_ids.size(0)
# add prefix for prompt-tuning
past_key_values = self.get_prompt(batch_size=batch_size)
        prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.roberta.device)
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
if mask_pos is not None:
mask_pos = mask_pos.squeeze()
# Encode everything
        outputs = self.roberta(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
past_key_values=past_key_values,
)
# Get <mask> token representation
sequence_output, pooled_output = outputs[:2]
# sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous()
sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos]
# Logits over vocabulary tokens
prediction_mask_scores = self.cls(sequence_mask_output)
# Exit early and only return mask logits.
if return_full_softmax:
return prediction_mask_scores
# Return logits for each label
logits = []
for label_id in range(len(self.label_word_list)):
logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1))
logits = torch.cat(logits, -1)
# Regression task
if self.config.num_labels == 1:
logsoftmax = nn.LogSoftmax(-1)
logits = logsoftmax(logits) # Log prob of right polarity
return logits, sequence_mask_output
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
mask_pos=None,
labels=None,
inputs_embeds=None,
block_flag=None,
return_dict=None,
):
logits, sequence_mask_output = self.encode(input_ids, attention_mask, token_type_ids, mask_pos, inputs_embeds)
loss = None
if labels is not None:
if self.num_labels == 1:
# Regression task
loss_fct = nn.KLDivLoss(log_target=True)
labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1)
loss = loss_fct(logits.view(-1, 2), labels)
else:
if labels.shape == logits.shape:
loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32),
labels, reduction="batchmean")
else:
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
output = (logits,)
if self.num_labels == 1:
# Regression output
output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,)
if not return_dict:
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss,
logits=logits,
)
"""
Adapter-tuning RoBERTa
"""
class PromptRobertaAdapterForSequenceClassification(RobertaPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.roberta = RobertaAdaModel(config)
self.cls = RobertaLMHead(config)
self.init_weights()
if self.config.use_freezing:
self.roberta = freezer.freeze_lm_component(self.roberta, "adapter")
# These attributes should be assigned once the model is initialized
self.model_args = None
self.data_args = None
self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.roberta.device)
# For regression
self.lb = None
self.ub = None
# For label search.
self.return_full_softmax = None
def freeze_backbone(self, use_freezing: bool=True):
if use_freezing:
self.roberta = freezer.freeze_lm_component(self.roberta, "adapter")
else:
            self.roberta = freezer.unfreeze_lm(self.roberta)
def embed_encode(self, input_ids):
embedding_output = self.roberta.embeddings.word_embeddings(input_ids)
return embedding_output
def encode(self, input_ids=None, attention_mask=None, token_type_ids=None, mask_pos=None, inputs_embeds=None, return_full_softmax=False):
batch_size = input_ids.size(0)
if mask_pos is not None:
mask_pos = mask_pos.squeeze()
# Encode everything
if inputs_embeds is None:
outputs = self.roberta(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids
)
else:
outputs = self.roberta(
None,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds
)
# Get <mask> token representation
sequence_output, pooled_output = outputs[:2]
sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos]
# Logits over vocabulary tokens
prediction_mask_scores = self.cls(sequence_mask_output)
# Exit early and only return mask logits.
if return_full_softmax:
return prediction_mask_scores
# Return logits for each label
logits = []
for label_id in range(len(self.label_word_list)):
logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1))
logits = torch.cat(logits, -1)
# Regression task
if self.config.num_labels == 1:
logsoftmax = nn.LogSoftmax(-1)
logits = logsoftmax(logits) # Log prob of right polarity
return logits, sequence_mask_output
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
mask_pos=None,
labels=None,
inputs_embeds=None,
block_flag=None,
return_dict=None,
):
logits, sequence_mask_output = self.encode(input_ids, attention_mask, token_type_ids, mask_pos, inputs_embeds)
loss = None
if labels is not None:
if self.num_labels == 1:
# Regression task
loss_fct = nn.KLDivLoss(log_target=True)
labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1)
loss = loss_fct(logits.view(-1, 2), labels)
else:
if labels.shape == logits.shape:
loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32),
labels, reduction="batchmean")
else:
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
output = (logits,)
if self.num_labels == 1:
# Regression output
output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,)
if not return_dict:
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss,
logits=logits,
)
# class DebertaForPromptFinetuning(DebertaPreTrainedModel):
# _keys_to_ignore_on_load_unexpected = [r"pooler"]
# _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
# def __init__(self, config):
# super().__init__(config)
# self.num_labels = config.num_labels
# #self.deberta = DebertaV2Model(config)
# self.deberta = DebertaModel(config)
# self.cls = DebertaOnlyMLMHead(config)
# if self.config.use_freezing:
# self.deberta = freezer.freeze_lm(self.deberta)
# self.pooler = ContextPooler(config)
# output_dim = self.pooler.output_dim
# self.classifier = torch.nn.Linear(output_dim, self.num_labels)
# drop_out = getattr(config, "cls_dropout", None)
# drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
# self.dropout = StableDropout(drop_out)
# classification_list = [self.pooler, self.dropout,self.classifier]
# self.classifier = nn.Sequential(*classification_list)
# # self.cls = DebertaV2OnlyMLMHead(config)
# self.map = nn.Linear(config.hidden_size, config.hidden_size)
# self.init_weights()
# # These attributes should be assigned once the model is initialized
# self.model_args = None
# self.data_args = None
# self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.bert.device)
# self.K = 1
# self.step_size=1e-5
# # import pdb
# # pdb.set_trace()
# #self.step_size=config.step_size
# # For regression
# self.lb = None
# self.ub = None
# self.pre_seq_len = self.config.pre_seq_len
# # For auto label search.
# self.return_full_softmax = None
# def freeze_backbone(self, use_freezing: bool=True):
# if use_freezing:
# self.deberta = freezer.freeze_lm(self.deberta)
# else:
# self.deberta = freezer.unfreeze_lm(self.deberta)
# def embed_encode(self, input_ids):
# embedding_output = self.deberta.embeddings.word_embeddings(input_ids)
# return embedding_output
# def encode(self, input_ids=None, attention_mask=None, token_type_ids=None, mask_pos=None, inputs_embeds=None,
# return_full_softmax=False):
# batch_size = input_ids.size(0)
# if mask_pos is not None:
# mask_pos = mask_pos.squeeze()
# # Encode everything
# if inputs_embeds is None:
# outputs = self.deberta(
# input_ids,
# attention_mask=attention_mask,
# token_type_ids=token_type_ids
# )
# else:
# outputs = self.deberta(
# None,
# attention_mask=attention_mask,
# token_type_ids=token_type_ids,
# inputs_embeds=inputs_embeds
# )
# # Get <mask> token representation
# sequence_output = outputs[0]
# sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous()
# sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos]
# # Logits over vocabulary tokens
# prediction_mask_scores = self.cls(sequence_mask_output)
# # sequence_mask_output = self.lm_head.dense(sequence_mask_output)
# # Exit early and only return mask logits.
# if return_full_softmax:
# return prediction_mask_scores
# # Return logits for each label
# logits = []
# for label_id in range(len(self.label_word_list)):
# logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1))
# logits = torch.cat(logits, -1)
# # Regression task
# if self.config.num_labels == 1:
# logsoftmax = nn.LogSoftmax(-1)
# logits = logsoftmax(logits) # Log prob of right polarity
# if self.model_args.hybrid == 1:
# cls_logits = self.classifier(sequence_output)
# return (logits, cls_logits), sequence_mask_output
# return logits, sequence_mask_output
# def forward(
# self,
# input_ids=None,
# attention_mask=None,
# token_type_ids=None,
# mask_pos=None,
# labels=None,
# inputs_embeds=None,
# fwd_type=0,
# block_flag=None
# ):
# if fwd_type == 2:
# assert inputs_embeds is not None
# return self.encode(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids,
# mask_pos=mask_pos, inputs_embeds=inputs_embeds)
# elif fwd_type == 1:
# return self.embed_encode(input_ids)
# if (self.model_args.prompt_ptuning or self.model_args.prompt_prefix) and block_flag is not None:
# inputs_embeds = self.generate_continuous_prompt_inputs(input_ids, block_flag)
# logits, sequence_mask_output = self.encode(input_ids, attention_mask, token_type_ids, mask_pos, inputs_embeds)
# if self.model_args.hybrid == 1:
# logits = logits[0]
# cls_logits = logits[1]
# loss = None
# if labels is not None:
# if self.num_labels == 1:
# # Regression task
# loss_fct = nn.KLDivLoss(log_target=True)
# labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb),
# (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1)
# loss = loss_fct(logits.view(-1, 2), labels)
# else:
# if labels.shape == logits.shape:
# loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32),
# labels, reduction="batchmean")
# else:
# loss_fct = nn.CrossEntropyLoss()
# loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
# output = (logits,)
# if self.num_labels == 1:
# # Regression output
# output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,)
# return ((loss,) + output) if loss is not None else output
# # add by wjn
# # Prefix-tuning for Deberta
# class DebertaPrefixForPromptFinetuning(DebertaPreTrainedModel):
# def __init__(self, config):
# super().__init__(config)
# self.num_labels = config.num_labels
# #self.deberta = DebertaV2Model(config)
# self.deberta = DebertaModel(config)
# self.cls = DebertaOnlyMLMHead(config)
# self.pooler = ContextPooler(config)
# output_dim = self.pooler.output_dim
# self.classifier = torch.nn.Linear(output_dim, self.num_labels)
# drop_out = getattr(config, "cls_dropout", None)
# drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
# self.dropout = StableDropout(drop_out)
# classification_list = [self.pooler, self.dropout,self.classifier]
# self.classifier = nn.Sequential(*classification_list)
# # self.cls = DebertaV2OnlyMLMHead(config)
# self.map = nn.Linear(config.hidden_size, config.hidden_size)
# self.init_weights()
# if self.config.use_freezing:
# self.deberta = freezer.freeze_lm(self.deberta)
# self.pre_seq_len = config.pre_seq_len
# self.n_layer = config.num_hidden_layers
# self.n_head = config.num_attention_heads
# self.n_embd = config.hidden_size // config.num_attention_heads
# self.prefix_tokens = torch.arange(self.pre_seq_len).long()
# self.prefix_encoder = PrefixEncoder(config)
# # These attributes should be assigned once the model is initialized
# self.model_args = None
# self.data_args = None
# self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.bert.device)
# self.K = 1
# self.step_size=1e-5
# # import pdb
# # pdb.set_trace()
# #self.step_size=config.step_size
# # For regression
# self.lb = None
# self.ub = None
# # For auto label search.
# self.return_full_softmax = None
# def freeze_backbone(self, use_freezing: bool=True):
# if use_freezing:
# self.deberta = freezer.freeze_lm(self.deberta)
# else:
# self.deberta = freezer.unfreeze_lm(self.deberta)
# def get_prompt(self, batch_size):
# prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.deberta.device)
# past_key_values = self.prefix_encoder(prefix_tokens)
# # bsz, seqlen, _ = past_key_values.shape
# past_key_values = past_key_values.view(
# batch_size,
# self.pre_seq_len,
# self.n_layer * 2,
# self.n_head,
# self.n_embd
# )
# past_key_values = self.dropout(past_key_values)
# past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
# return past_key_values
# def get_constrast_loss(self,
# input_ids=None,
# attention_mask=None,
# mask_pos=None,
# labels=None,
# inputs_embeds=None):
# self.cos = nn.CosineSimilarity(dim=-1)
# _, sequence_mask_output_1 = self.encode(input_ids, attention_mask, mask_pos, inputs_embeds)
# _, sequence_mask_output_2 = self.encode(input_ids, attention_mask, mask_pos, inputs_embeds)
# sequence_mask_output_1= self.lm_head.dense(sequence_mask_output_1)
# sequence_mask_output_2 = self.lm_head.dense(sequence_mask_output_2)
# # input_args = [input_ids, attention_mask, mask_pos, labels, None, 1]
# # embed = self.forward(*input_args)
# #
# # vat_args = [input_ids, attention_mask, mask_pos, labels, embed, 2]
# #
# # adv_logits, outputs = self.forward(*vat_args)
# #
# # logit_mask = F.softmax(logits, dim=-1)[torch.arange(adv_logits.size(0)), labels] > 0.7
# #
# # outputs = outputs[logit_mask]
# # seq_outputs = sequence_mask_output[logit_mask]
# # new_label = labels[logit_mask]
# # #
# # #
# # rand_perm = torch.randperm(outputs.size(0))
# # rand_outputs = outputs[rand_perm, :]
# # rand_label = new_label[rand_perm]
# # pair_label = (new_label == rand_label).long()
# #
# # seq_outputs = self.map(seq_outputs)
# # rand_outputs = self.map(rand_outputs)
# pair_labels = (labels.unsqueeze(1) == labels.unsqueeze(0)).float()
# # import pdb
# # pdb.set_trace()
# contra_loss = self.contra_lc(sequence_mask_output_1.unsqueeze(1), sequence_mask_output_2.unsqueeze(0), pair_labels)
# if torch.isnan(contra_loss):
# return 0
# return contra_loss
# def embed_encode(self, input_ids):
# embedding_output = self.deberta.embeddings.word_embeddings(input_ids)
# return embedding_output
# def encode(self, input_ids=None, attention_mask=None, token_type_ids=None, mask_pos=None, inputs_embeds=None, return_full_softmax=False):
# batch_size = input_ids.size(0)
# # add prefix for prompt-tuning
# past_key_values = self.get_prompt(batch_size=batch_size)
# prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.deberta.device)
# attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
# if mask_pos is not None:
# mask_pos = mask_pos.squeeze()
# # Encode everything
# outputs = self.deberta(
# input_ids,
# attention_mask=attention_mask,
# token_type_ids=token_type_ids,
# past_key_values=past_key_values,
# )
# # Get <mask> token representation
# sequence_output, pooled_output = outputs[:2]
# # sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous()
# sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos]
# # Logits over vocabulary tokens
# prediction_mask_scores = self.cls(sequence_mask_output)
# #sequence_mask_output = self.lm_head.dense(sequence_mask_output)
# # Exit early and only return mask logits.
# if return_full_softmax:
# return prediction_mask_scores
# # Return logits for each label
# logits = []
# for label_id in range(len(self.label_word_list)):
# logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1))
# logits = torch.cat(logits, -1)
# # Regression task
# if self.config.num_labels == 1:
# logsoftmax = nn.LogSoftmax(-1)
# logits = logsoftmax(logits) # Log prob of right polarity
# if self.model_args.hybrid == 1:
# cls_logits = self.classifier(sequence_output)
# return (logits, cls_logits), sequence_mask_output
# return logits, sequence_mask_output
# def forward(
# self,
# input_ids=None,
# attention_mask=None,
# token_type_ids=None,
# mask_pos=None,
# labels=None,
# inputs_embeds=None,
# fwd_type=0,
# block_flag=None,
# return_dict=None,
# ):
# if fwd_type == 2:
# assert inputs_embeds is not None
# return self.encode(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids,
# mask_pos=mask_pos, inputs_embeds=inputs_embeds)
# elif fwd_type == 1:
# return self.embed_encode(input_ids)
# if (self.model_args.prompt_ptuning or self.model_args.prompt_prefix) and block_flag is not None:
# inputs_embeds = self.generate_continuous_prompt_inputs(input_ids, block_flag)
# logits, sequence_mask_output = self.encode(input_ids, attention_mask, token_type_ids, mask_pos, inputs_embeds)
# if self.model_args.hybrid == 1:
# logits = logits[0]
# cls_logits = logits[1]
# loss = None
# if labels is not None:
# if self.num_labels == 1:
# # Regression task
# loss_fct = nn.KLDivLoss(log_target=True)
# labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb),
# (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1)
# loss = loss_fct(logits.view(-1, 2), labels)
# else:
# if labels.shape == logits.shape:
# loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32),
# labels, reduction="batchmean")
# else:
# loss_fct = nn.CrossEntropyLoss()
# loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
# output = (logits,)
# if self.num_labels == 1:
# # Regression output
# output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,)
# if not return_dict:
# return ((loss,) + output) if loss is not None else output
# return SequenceClassifierOutput(
# loss=loss,
# logits=logits,
# )
# class Debertav2ForPromptFinetuning(DebertaV2PreTrainedModel):
# _keys_to_ignore_on_load_unexpected = [r"pooler"]
# _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
# def __init__(self, config):
# super().__init__(config)
# self.num_labels = config.num_labels
# self.deberta = DebertaV2Model(config)
# if self.config.use_freezing:
# self.deberta = freezer.freeze_lm(self.deberta)
# self.cls = DebertaV2OnlyMLMHead(config)
# #self.deberta = DebertaModel(config)
# #self.cls = DebertaOnlyMLMHead(config)
# self.pooler = ContextPooler(config)
# output_dim = self.pooler.output_dim
# self.classifier = torch.nn.Linear(output_dim, self.num_labels)
# drop_out = getattr(config, "cls_dropout", None)
# drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
# self.dropout = StableDropout(drop_out)
# classification_list = [self.pooler, self.dropout,self.classifier]
# self.classifier = nn.Sequential(*classification_list)
# # self.cls = DebertaV2OnlyMLMHead(config)
# self.map = nn.Linear(config.hidden_size, config.hidden_size)
# self.init_weights()
# # These attributes should be assigned once the model is initialized
# self.model_args = None
# self.data_args = None
# self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.bert.device)
# self.K = 1
# self.step_size=1e-5
# # import pdb
# # pdb.set_trace()
# #self.step_size=config.step_size
# # For regression
# self.lb = None
# self.ub = None
# self.pre_seq_len = self.config.pre_seq_len
# # For auto label search.
# self.return_full_softmax = None
# def freeze_backbone(self, use_freezing: bool=True):
# if use_freezing:
# self.deberta = freezer.freeze_lm(self.deberta)
# else:
# self.deberta = freezer.unfreeze_lm(self.deberta)
# def embed_encode(self, input_ids):
# embedding_output = self.deberta.embeddings.word_embeddings(input_ids)
# return embedding_output
# def encode(self, input_ids=None, attention_mask=None, mask_pos=None, inputs_embeds=None, return_full_softmax=False):
# batch_size = input_ids.size(0)
# if mask_pos is not None:
# mask_pos = mask_pos.squeeze()
# # Encode everything
# if inputs_embeds is None:
# outputs = self.deberta(
# input_ids,
# attention_mask=attention_mask
# )
# else:
# outputs = self.deberta(
# None,
# attention_mask=attention_mask,
# inputs_embeds=inputs_embeds
# )
# # Get <mask> token representation
# sequence_output = outputs[0]
# sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous()
# sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos]
# # Logits over vocabulary tokens
# prediction_mask_scores = self.cls(sequence_mask_output)
# #sequence_mask_output = self.lm_head.dense(sequence_mask_output)
# # Exit early and only return mask logits.
# if return_full_softmax:
# return prediction_mask_scores
# # Return logits for each label
# logits = []
# for label_id in range(len(self.label_word_list)):
# logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1))
# logits = torch.cat(logits, -1)
# # Regression task
# if self.config.num_labels == 1:
# logsoftmax = nn.LogSoftmax(-1)
# logits = logsoftmax(logits) # Log prob of right polarity
# return logits, sequence_mask_output
# def forward(
# self,
# input_ids=None,
# attention_mask=None,
# mask_pos=None,
# labels=None,
# inputs_embeds=None,
# fwd_type=0,
# block_flag=None,
# return_dict=None
# ):
# if fwd_type == 2:
# assert inputs_embeds is not None
# return self.encode(input_ids=input_ids, attention_mask=attention_mask, mask_pos=mask_pos, inputs_embeds=inputs_embeds)
# elif fwd_type == 1:
# return self.embed_encode(input_ids)
# logits, sequence_mask_output = self.encode(input_ids, attention_mask, mask_pos, inputs_embeds)
# loss = None
# if labels is not None:
# if self.num_labels == 1:
# # Regression task
# loss_fct = nn.KLDivLoss(log_target=True)
# labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1)
# loss = loss_fct(logits.view(-1, 2), labels)
# else:
# if labels.shape == logits.shape:
# loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32),
# labels, reduction="batchmean")
# else:
# loss_fct = nn.CrossEntropyLoss()
# loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
# if self.model_args.hybrid == 1:
# cls_loss = loss_fct(cls_logits.view(-1, cls_logits.size(-1)), labels.view(-1))
# loss = loss + cls_loss
# output = (logits,)
# if self.num_labels == 1:
# # Regression output
# output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,)
# if not return_dict:
# return ((loss,) + output) if loss is not None else output
# return SequenceClassifierOutput(
# loss=loss,
# logits=logits,
# )
# class Debertav2PrefixForPromptFinetuning(DebertaV2PreTrainedModel):
# _keys_to_ignore_on_load_unexpected = [r"pooler"]
# _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
# def __init__(self, config):
# super().__init__(config)
# self.num_labels = config.num_labels
# self.deberta = DebertaV2Model(config)
# self.cls = DebertaV2OnlyMLMHead(config)
# #self.deberta = DebertaModel(config)
# #self.cls = DebertaOnlyMLMHead(config)
# self.pooler = ContextPooler(config)
# output_dim = self.pooler.output_dim
# self.classifier = torch.nn.Linear(output_dim, self.num_labels)
# drop_out = getattr(config, "cls_dropout", None)
# drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
# self.dropout = StableDropout(drop_out)
# classification_list = [self.pooler, self.dropout,self.classifier]
# self.classifier = nn.Sequential(*classification_list)
# # self.cls = DebertaV2OnlyMLMHead(config)
# self.map = nn.Linear(config.hidden_size, config.hidden_size)
# self.init_weights()
# if self.config.use_freezing:
# self.deberta = freezer.freeze_lm(self.deberta)
# self.pre_seq_len = config.pre_seq_len
# self.n_layer = config.num_hidden_layers
# self.n_head = config.num_attention_heads
# self.n_embd = config.hidden_size // config.num_attention_heads
# self.prefix_tokens = torch.arange(self.pre_seq_len).long()
# self.prefix_encoder = PrefixEncoder(config)
# # These attributes should be assigned once the model is initialized
# self.model_args = None
# self.data_args = None
# self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.bert.device)
# self.K = 1
# self.step_size=1e-5
# # import pdb
# # pdb.set_trace()
# #self.step_size=config.step_size
# # For regression
# self.lb = None
# self.ub = None
# # For auto label search.
# self.return_full_softmax = None
# def freeze_backbone(self, use_freezing: bool=True):
# if use_freezing:
# self.deberta = freezer.freeze_lm(self.deberta)
# else:
# self.deberta = freezer.unfreeze_lm(self.deberta)
# def get_prompt(self, batch_size):
# prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.deberta.device)
# past_key_values = self.prefix_encoder(prefix_tokens)
# # bsz, seqlen, _ = past_key_values.shape
# past_key_values = past_key_values.view(
# batch_size,
# self.pre_seq_len,
# self.n_layer * 2,
# self.n_head,
# self.n_embd
# )
# past_key_values = self.dropout(past_key_values)
# past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
# return past_key_values
# def embed_encode(self, input_ids):
# embedding_output = self.deberta.embeddings.word_embeddings(input_ids)
# return embedding_output
# def encode(self, input_ids=None, attention_mask=None, mask_pos=None, inputs_embeds=None, return_full_softmax=False):
# batch_size = input_ids.size(0)
# # add prefix for prompt-tuning
# past_key_values = self.get_prompt(batch_size=batch_size)
# prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.deberta.device)
# attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
# if mask_pos is not None:
# mask_pos = mask_pos.squeeze()
# # Encode everything
# outputs = self.deberta(
# input_ids,
# attention_mask=attention_mask,
# past_key_values=past_key_values,
# )
# # Get <mask> token representation
# sequence_output = outputs[0]
# # sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous()
# sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos]
# # Logits over vocabulary tokens
# prediction_mask_scores = self.cls(sequence_mask_output)
# #sequence_mask_output = self.lm_head.dense(sequence_mask_output)
# # Exit early and only return mask logits.
# if return_full_softmax:
# return prediction_mask_scores
# # Return logits for each label
# logits = []
# for label_id in range(len(self.label_word_list)):
# logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1))
# logits = torch.cat(logits, -1)
# # Regression task
# if self.config.num_labels == 1:
# logsoftmax = nn.LogSoftmax(-1)
# logits = logsoftmax(logits) # Log prob of right polarity
# return logits, sequence_mask_output
# def forward(
# self,
# input_ids=None,
# attention_mask=None,
# mask_pos=None,
# labels=None,
# inputs_embeds=None,
# fwd_type=0,
# block_flag=None,
# return_dict=None,
# ):
# if fwd_type == 2:
# assert inputs_embeds is not None
# return self.encode(input_ids=input_ids, attention_mask=attention_mask, mask_pos=mask_pos, inputs_embeds=inputs_embeds)
# elif fwd_type == 1:
# return self.embed_encode(input_ids)
# logits, sequence_mask_output = self.encode(input_ids, attention_mask, mask_pos, inputs_embeds)
# loss = None
# if labels is not None:
# if self.num_labels == 1:
# # Regression task
# loss_fct = nn.KLDivLoss(log_target=True)
# labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1)
# loss = loss_fct(logits.view(-1, 2), labels)
# else:
# if labels.shape == logits.shape:
# loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32),
# labels, reduction="batchmean")
# else:
# loss_fct = nn.CrossEntropyLoss()
# loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
# if self.model_args.hybrid == 1:
# cls_loss = loss_fct(cls_logits.view(-1, cls_logits.size(-1)), labels.view(-1))
# loss = loss + cls_loss
# output = (logits,)
# if self.num_labels == 1:
# # Regression output
# output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,)
# if not return_dict:
# return ((loss,) + output) if loss is not None else output
# return SequenceClassifierOutput(
# loss=loss,
# logits=logits,
# )