Edit model card
YAML Metadata Warning: empty or missing yaml metadata in repo card (https://huggingface.co/docs/hub/model-cards#model-card-metadata)

Checkpoint model trained from this source code "https://github.com/THUDM/P-tuning-v2" import torch from torch._C import NoopLogger import torch.nn import torch.nn.functional as F from torch import Tensor

from transformers import BertModel, BertPreTrainedModel from transformers import RobertaModel, RobertaPreTrainedModel from transformers.modeling_outputs import SequenceClassifierOutput, BaseModelOutput, Seq2SeqLMOutput

import torch

class PrefixEncoder(torch.nn.Module):
    r"""The torch.nn model to encode the prefix.

    Input shape: (batch-size, prefix-length)

    Output shape: (batch-size, prefix-length, 2*layers*hidden)

    When ``config.prefix_projection`` is true, each prefix index is embedded
    into ``hidden_size`` dims and reparameterised through a two-layer MLP
    (Linear -> Tanh -> Linear). Otherwise the embedding table stores the
    ``2 * num_hidden_layers * hidden_size`` output vector directly.

    Config attributes read: ``prefix_projection``, ``pre_seq_len``,
    ``hidden_size``, ``num_hidden_layers`` and, when projecting,
    ``prefix_hidden_size``.
    """

    def __init__(self, config):
        super().__init__()
        self.prefix_projection = config.prefix_projection
        # Factor 2 per layer: presumably one key and one value vector each.
        out_dim = config.num_hidden_layers * 2 * config.hidden_size

        # Submodule names (`embedding`, `trans`) must stay as-is so saved
        # checkpoints keep loading with the same state_dict keys.
        if self.prefix_projection:
            # Use a two-layer MLP to encode the prefix
            self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size)
            self.trans = torch.nn.Sequential(
                torch.nn.Linear(config.hidden_size, config.prefix_hidden_size),
                torch.nn.Tanh(),
                torch.nn.Linear(config.prefix_hidden_size, out_dim),
            )
        else:
            self.embedding = torch.nn.Embedding(config.pre_seq_len, out_dim)

    def forward(self, prefix: torch.Tensor) -> torch.Tensor:
        """Map prefix indices to flattened per-layer vectors.

        Args:
            prefix: Integer indices in ``[0, pre_seq_len)`` with shape
                ``(batch-size, prefix-length)``.

        Returns:
            Tensor of shape ``(batch-size, prefix-length, 2*layers*hidden)``.
        """
        if self.prefix_projection:
            return self.trans(self.embedding(prefix))
        return self.embedding(prefix)

class BertPrefixForSequenceClassification(BertPreTrainedModel):
    """BERT sequence classifier with trainable per-layer prefixes (P-tuning v2 style).

    The BERT backbone is frozen (``requires_grad = False``). The trainable
    parameters are those of the prefix encoder and the linear classification
    head. On every forward pass, ``PrefixEncoder`` turns ``pre_seq_len`` fixed
    prefix indices into per-layer tensors. These are handed to BERT as
    ``past_key_values``.

    Config attributes read beyond the standard BERT ones: ``pre_seq_len``,
    plus whatever ``PrefixEncoder`` reads.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.bert = BertModel(config)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)

        # Freeze the backbone; only the prefix encoder and classifier train.
        for param in self.bert.parameters():
            param.requires_grad = False

        self.pre_seq_len = config.pre_seq_len
        self.n_layer = config.num_hidden_layers
        self.n_head = config.num_attention_heads
        self.n_embd = config.hidden_size // config.num_attention_heads

        # Fixed indices 0..pre_seq_len-1 fed to the prefix encoder. Kept as a
        # plain tensor attribute (not a registered buffer) so the checkpoint's
        # state_dict keys are unchanged.
        self.prefix_tokens = torch.arange(self.pre_seq_len).long()
        self.prefix_encoder = PrefixEncoder(config)

        # Report how many parameters live outside the BERT backbone.
        bert_param = sum(p.numel() for p in self.bert.parameters())
        all_param = sum(p.numel() for p in self.parameters())
        total_param = all_param - bert_param
        print('total param is {}'.format(total_param))

    def get_prompt(self, batch_size):
        """Build the per-layer prefix tensors for a batch.

        Returns:
            A tuple of ``n_layer`` tensors, each of shape
            ``(2, batch_size, n_head, pre_seq_len, n_embd)``.
        """
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device)
        past_key_values = self.prefix_encoder(prefix_tokens)
        past_key_values = past_key_values.view(
            batch_size,
            self.pre_seq_len,
            self.n_layer * 2,
            self.n_head,
            self.n_embd,
        )
        past_key_values = self.dropout(past_key_values)
        # (batch, seq, 2*layer, head, embd) -> (2*layer, batch, head, seq, embd),
        # then chunk dim 0 into n_layer pieces of size 2.
        # NOTE(review): relies on BertModel accepting a tuple of stacked
        # per-layer tensors as past_key_values; transformers versions that
        # require Cache objects may reject this — confirm the pinned version.
        return past_key_values.permute([2, 0, 3, 1, 4]).split(2)

    def _compute_loss(self, logits, labels):
        """Return the loss of ``logits`` against ``labels``.

        Infers ``config.problem_type`` on first use. A single label means
        regression. Integer labels with several classes mean single-label
        classification. Anything else means multi-label classification.
        The matching loss is then applied: MSE, cross-entropy or
        BCE-with-logits.

        Raises:
            ValueError: if ``config.problem_type`` holds an unknown value.
        """
        if getattr(self.config, "problem_type", None) is None:
            if self.num_labels == 1:
                self.config.problem_type = "regression"
            elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                self.config.problem_type = "single_label_classification"
            else:
                self.config.problem_type = "multi_label_classification"

        problem_type = self.config.problem_type
        if problem_type == "regression":
            loss_fct = torch.nn.MSELoss()
            if self.num_labels == 1:
                return loss_fct(logits.squeeze(), labels.squeeze())
            return loss_fct(logits, labels)
        if problem_type == "single_label_classification":
            loss_fct = torch.nn.CrossEntropyLoss()
            return loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        if problem_type == "multi_label_classification":
            return torch.nn.BCEWithLogitsLoss()(logits, labels)
        raise ValueError("Unsupported problem_type: {}".format(problem_type))

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Classify a batch, computing the loss when ``labels`` is given.

        Args:
            input_ids: Token ids; assumed ``(batch, seq_len)`` — TODO confirm.
            attention_mask: Mask over the real tokens. All ones when omitted.
                A block of ones covering the prefix is prepended before
                calling BERT.
            labels: Optional targets. When given, ``loss`` is computed
                according to ``config.problem_type`` (inferred on first use).
            The remaining arguments are forwarded to ``BertModel``.

        Returns:
            ``SequenceClassifierOutput`` when ``return_dict`` is true.
            Otherwise a tuple ``(loss?, logits, *extra_bert_outputs)``, where
            ``loss`` is present only when it was computed.

        Raises:
            ValueError: if neither ``input_ids`` nor ``inputs_embeds`` is given.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None:
            batch_size, seq_len = input_ids.shape[0], input_ids.shape[1]
        elif inputs_embeds is not None:
            batch_size, seq_len = inputs_embeds.shape[0], inputs_embeds.shape[1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        past_key_values = self.get_prompt(batch_size=batch_size)

        if attention_mask is None:
            attention_mask = torch.ones(batch_size, seq_len, device=self.bert.device)
        # Every token may attend to all prefix positions. Match the caller's
        # mask dtype/device so the concatenation never mixes types or devices.
        prefix_attention_mask = torch.ones(
            batch_size,
            self.pre_seq_len,
            dtype=attention_mask.dtype,
            device=attention_mask.device,
        )
        attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            past_key_values=past_key_values,
        )

        # Second element of BertModel's output: the pooled representation.
        pooled_output = self.dropout(outputs[1])
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss = self._compute_loss(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

class BertPromptForSequenceClassification(BertPreTrainedModel):
    """BERT sequence classifier with trainable input soft prompts.

    The BERT backbone is frozen (``requires_grad = False``). ``pre_seq_len``
    learned prompt vectors (``prefix_encoder``, an ``Embedding`` of width
    ``hidden_size``) are prepended to the token embeddings. The first real
    token after the prompts is pooled and classified.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel(config)
        # Alias to BERT's embedding module (shared, not a copy).
        self.embeddings = self.bert.embeddings
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)

        # Freeze the backbone; only the prompts and classifier train.
        for param in self.bert.parameters():
            param.requires_grad = False

        self.pre_seq_len = config.pre_seq_len
        self.n_layer = config.num_hidden_layers
        self.n_head = config.num_attention_heads
        self.n_embd = config.hidden_size // config.num_attention_heads

        # Fixed indices 0..pre_seq_len-1 used to look up the prompt vectors.
        self.prefix_tokens = torch.arange(self.pre_seq_len).long()
        self.prefix_encoder = torch.nn.Embedding(self.pre_seq_len, config.hidden_size)

    def get_prompt(self, batch_size):
        """Return prompt embeddings of shape ``(batch_size, pre_seq_len, hidden_size)``."""
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device)
        return self.prefix_encoder(prefix_tokens)

    def _compute_loss(self, logits, labels):
        """Return the loss of ``logits`` against ``labels``.

        Infers ``config.problem_type`` on first use. A single label means
        regression. Integer labels with several classes mean single-label
        classification. Anything else means multi-label classification.
        The matching loss is then applied: MSE, cross-entropy or
        BCE-with-logits.

        Raises:
            ValueError: if ``config.problem_type`` holds an unknown value.
        """
        if getattr(self.config, "problem_type", None) is None:
            if self.num_labels == 1:
                self.config.problem_type = "regression"
            elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                self.config.problem_type = "single_label_classification"
            else:
                self.config.problem_type = "multi_label_classification"

        problem_type = self.config.problem_type
        if problem_type == "regression":
            loss_fct = torch.nn.MSELoss()
            if self.num_labels == 1:
                return loss_fct(logits.squeeze(), labels.squeeze())
            return loss_fct(logits, labels)
        if problem_type == "single_label_classification":
            loss_fct = torch.nn.CrossEntropyLoss()
            return loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        if problem_type == "multi_label_classification":
            return torch.nn.BCEWithLogitsLoss()(logits, labels)
        raise ValueError("Unsupported problem_type: {}".format(problem_type))

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Classify a batch, computing the loss when ``labels`` is given.

        Args:
            input_ids: Token ids; assumed ``(batch, seq_len)`` — TODO confirm.
            attention_mask: Mask over the real tokens. All ones when omitted.
                Ones for the prompt positions are prepended.
            inputs_embeds: Used in place of ``input_ids`` only when
                ``input_ids`` is None; ignored otherwise.
            labels: Optional targets. When given, ``loss`` is computed
                according to ``config.problem_type`` (inferred on first use).

        Returns:
            ``SequenceClassifierOutput`` when ``return_dict`` is true.
            Otherwise a tuple ``(loss?, logits, *extra_bert_outputs)``.

        Raises:
            ValueError: if neither ``input_ids`` nor ``inputs_embeds`` is given.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # NOTE(review): raw_embedding is the full output of BERT's embedding
        # module, and it is then passed to self.bert as inputs_embeds. In the
        # transformers versions known to me, that re-adds position/token-type
        # embeddings, so they are applied twice — confirm. This is kept as-is
        # because the checkpoint was presumably trained this way.
        raw_embedding = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds if input_ids is None else None,
        )
        batch_size = raw_embedding.shape[0]
        prompts = self.get_prompt(batch_size=batch_size)
        # Prepend the soft prompts along the sequence dimension.
        inputs_embeds = torch.cat((prompts, raw_embedding), dim=1)

        if attention_mask is None:
            attention_mask = torch.ones(batch_size, raw_embedding.shape[1], device=raw_embedding.device)
        # Match the caller's mask dtype/device so the concatenation is safe.
        prefix_attention_mask = torch.ones(
            batch_size,
            self.pre_seq_len,
            dtype=attention_mask.dtype,
            device=attention_mask.device,
        )
        attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

        outputs = self.bert(
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Drop the prompt positions, then pool the first real token using
        # BERT's pooler dense layer and activation.
        sequence_output = outputs[0][:, self.pre_seq_len:, :].contiguous()
        first_token_tensor = sequence_output[:, 0]
        pooled_output = self.bert.pooler.dense(first_token_tensor)
        pooled_output = self.bert.pooler.activation(pooled_output)

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss = self._compute_loss(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

image/png

Downloads last month: 308
Safetensors model size: 336M params
Tensor type: F32