import transformers
import torch.nn as nn


class BertClassificationModel(nn.Module):
    """Sequence classifier: a pretrained BERT encoder plus one linear head.

    The head is applied to the second element of BERT's output (the pooled
    [CLS] representation) and returns unnormalized class scores (logits).
    No softmax is applied here.

    Args:
        pretrained_weights: Name or path passed to
            ``transformers.BertModel.from_pretrained``. Defaults to
            ``"bert-base-chinese"``.
        num_labels: Number of output classes (width of the linear head).
            Defaults to 3.
        freeze_bert: Keyword-only. If True, BERT's parameters are excluded
            from gradient updates and only the linear head is trained.
            Defaults to False, which fine-tunes the whole encoder.
    """

    def __init__(self, pretrained_weights="bert-base-chinese", num_labels=3,
                 *, freeze_bert=False):
        super().__init__()
        self.bert = transformers.BertModel.from_pretrained(pretrained_weights)
        # Explicitly set trainability of the encoder; with the default
        # (freeze_bert=False) every BERT parameter is fine-tuned.
        for param in self.bert.parameters():
            param.requires_grad = not freeze_bert
        # Size the head from the loaded model's config instead of hard-coding
        # 768, so non-"base" checkpoints (e.g. large, hidden_size=1024) work.
        self.dense = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        """Return classification logits for a batch of token ids.

        Args:
            input_ids: Token ids for BERT. Presumably shaped
                (batch, seq_len); confirm against the caller.
            token_type_ids: Optional segment ids with the same shape as
                ``input_ids``. ``None`` lets BERT use its default.
            attention_mask: Optional padding mask with the same shape as
                ``input_ids``. ``None`` lets BERT use its default.

        Returns:
            The output of the linear head applied to BERT's pooled output.
            This is one score per class in ``num_labels``.
        """
        bert_output = self.bert(input_ids=input_ids,
                                token_type_ids=token_type_ids,
                                attention_mask=attention_mask)
        # Index 1 is BertModel's pooler_output: the [CLS] hidden state after
        # BERT's dense+tanh pooling layer, not the raw [CLS] hidden state.
        # Integer indexing works for both tuple and ModelOutput return types.
        pooled_output = bert_output[1]
        return self.dense(pooled_output)