
Checkpoint trained from the source code at https://github.com/THUDM/P-tuning-v2:

```python
import torch
import torch.nn
import torch.nn.functional as F
from torch import Tensor

from transformers import BertModel, BertPreTrainedModel
from transformers import RobertaModel, RobertaPreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput, BaseModelOutput, Seq2SeqLMOutput
```

The `PrefixEncoder` module that encodes the trainable prefix:

```python
class PrefixEncoder(torch.nn.Module):
    r'''
    The torch.nn model to encode the prefix

    Input shape: (batch-size, prefix-length)

    Output shape: (batch-size, prefix-length, 2*layers*hidden)
    '''
    def __init__(self, config):
        super().__init__()
        self.prefix_projection = config.prefix_projection

        if self.prefix_projection:
            # Use a two-layer MLP to encode the prefix
            self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size)
            self.trans = torch.nn.Sequential(
                torch.nn.Linear(config.hidden_size, config.prefix_hidden_size),
                torch.nn.Tanh(),
                torch.nn.Linear(config.prefix_hidden_size, config.num_hidden_layers * 2 * config.hidden_size)
            )
        else:
            self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_hidden_layers * 2 * config.hidden_size)

    def forward(self, prefix: torch.Tensor):
        if self.prefix_projection:
            prefix_tokens = self.embedding(prefix)
            past_key_values = self.trans(prefix_tokens)
        else:
            past_key_values = self.embedding(prefix)
        return past_key_values
```
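A minimal shape check for `PrefixEncoder`, assuming illustrative config values (`pre_seq_len=20`, BERT-large-like sizes); these are not necessarily the values used for this checkpoint:

```python
import torch
from types import SimpleNamespace

# Hypothetical config for illustration only.
cfg = SimpleNamespace(
    prefix_projection=False,   # embedding-only prefix (no MLP reparameterization)
    pre_seq_len=20,
    hidden_size=1024,
    prefix_hidden_size=512,
    num_hidden_layers=24,
)

encoder = PrefixEncoder(cfg)
prefix = torch.arange(cfg.pre_seq_len).long().unsqueeze(0)  # (batch-size=1, prefix-length)
out = encoder(prefix)
print(out.shape)  # torch.Size([1, 20, 49152]) == (batch, prefix-length, 2 * layers * hidden)
```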

The prefix-tuning classifier (`BertPrefixForSequenceClassification`), which keeps BERT frozen and trains only the prefix encoder and the classification head:

```python
class BertPrefixForSequenceClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.bert = BertModel(config)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)

        # Freeze the BERT backbone; only the prefix encoder and classifier are trained
        for param in self.bert.parameters():
            param.requires_grad = False

        self.pre_seq_len = config.pre_seq_len
        self.n_layer = config.num_hidden_layers
        self.n_head = config.num_attention_heads
        self.n_embd = config.hidden_size // config.num_attention_heads

        self.prefix_tokens = torch.arange(self.pre_seq_len).long()
        self.prefix_encoder = PrefixEncoder(config)

        # Count the trainable parameters (everything except the frozen backbone)
        bert_param = 0
        for name, param in self.bert.named_parameters():
            bert_param += param.numel()
        all_param = 0
        for name, param in self.named_parameters():
            all_param += param.numel()
        total_param = all_param - bert_param
        print('total param is {}'.format(total_param))  # 9860105

    def get_prompt(self, batch_size):
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device)
        past_key_values = self.prefix_encoder(prefix_tokens)

        past_key_values = past_key_values.view(
            batch_size,
            self.pre_seq_len,
            self.n_layer * 2,
            self.n_head,
            self.n_embd
        )

        past_key_values = self.dropout(past_key_values)
        past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
        return past_key_values
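    # After permute + split(2), `past_key_values` is a tuple of `n_layer` tensors,
    # each of shape (2, batch_size, n_head, pre_seq_len, n_embd): index 0 holds the
    # prefix keys and index 1 the prefix values that every BERT layer attends to
    # in addition to the real tokens.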

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size = input_ids.shape[0]
        # Build the per-layer prefix key/values and extend the attention mask accordingly
        past_key_values = self.get_prompt(batch_size=batch_size)
        prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.bert.device)
        attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            past_key_values=past_key_values,
        )

        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None  # note: this version does not compute a training loss even when labels are passed
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
```
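A usage sketch for the prefix-tuning classifier. The base model name, `pre_seq_len`, and label count are illustrative assumptions, and the snippet assumes a `transformers` version compatible with the P-tuning-v2 repository (where per-layer prefixes are passed through `past_key_values`):

```python
from transformers import BertConfig, BertTokenizer

# Assumed values for illustration; the actual checkpoint's config may differ.
config = BertConfig.from_pretrained("bert-large-uncased", num_labels=2)
config.pre_seq_len = 20
config.prefix_projection = False
config.prefix_hidden_size = 512

model = BertPrefixForSequenceClassification(config)  # backbone is randomly initialized here
tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")

batch = tokenizer(["an example sentence"], return_tensors="pt")
outputs = model(**batch)
print(outputs.logits.shape)  # torch.Size([1, 2])
```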

The prompt-tuning variant (`BertPromptForSequenceClassification`), which keeps BERT frozen and prepends trainable prompt embeddings to the input embeddings:

```python
class BertPromptForSequenceClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel(config)
        self.embeddings = self.bert.embeddings
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)

        for param in self.bert.parameters():
            param.requires_grad = False

        self.pre_seq_len = config.pre_seq_len
        self.n_layer = config.num_hidden_layers
        self.n_head = config.num_attention_heads
        self.n_embd = config.hidden_size // config.num_attention_heads

        self.prefix_tokens = torch.arange(self.pre_seq_len).long()
        self.prefix_encoder = torch.nn.Embedding(self.pre_seq_len, config.hidden_size)

    def get_prompt(self, batch_size):
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device)
        prompts = self.prefix_encoder(prefix_tokens)
        return prompts

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size = input_ids.shape[0]
        raw_embedding = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
        )
        # Prepend the trainable prompt embeddings and extend the attention mask accordingly
        prompts = self.get_prompt(batch_size=batch_size)
        inputs_embeds = torch.cat((prompts, raw_embedding), dim=1)
        prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.bert.device)
        attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

        outputs = self.bert(
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Drop the prompt positions and pool the first real token ([CLS])
        sequence_output = outputs[0]
        sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous()
        first_token_tensor = sequence_output[:, 0]
        pooled_output = self.bert.pooler.dense(first_token_tensor)
        pooled_output = self.bert.pooler.activation(pooled_output)

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            # Standard loss selection as in the upstream P-tuning-v2 / transformers code
            if self.config.problem_type == "regression":
                loss_fct = torch.nn.MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = torch.nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = torch.nn.BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
```
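For comparison, a sketch of the prompt-tuning variant, which prepends trainable prompt embeddings at the embedding level instead of injecting per-layer key/value prefixes; the values are again illustrative assumptions:

```python
import torch
from transformers import BertConfig, BertTokenizer

# Assumed values for illustration only.
config = BertConfig.from_pretrained("bert-large-uncased", num_labels=2)
config.pre_seq_len = 20

model = BertPromptForSequenceClassification(config)
tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")

batch = tokenizer(["another example sentence"], return_tensors="pt")
labels = torch.tensor([1])
outputs = model(**batch, labels=labels)
print(outputs.loss, outputs.logits.shape)  # scalar cross-entropy loss, torch.Size([1, 2])
```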
