|
|
|
import transformers |
|
import torch.nn as nn |
|
|
|
class BertClassificationModel(nn.Module):
    """BERT encoder followed by a single linear classification layer.

    The encoder is loaded from a pretrained checkpoint and all of its
    parameters are left trainable, so the whole network is fine-tuned
    together with the linear head.
    """

    def __init__(
        self,
        pretrained_weights: str = "bert-base-chinese",
        num_labels: int = 3,
    ) -> None:
        """Load the pretrained encoder and build the classification head.

        Args:
            pretrained_weights: Checkpoint name or path handed to
                ``transformers.BertModel.from_pretrained``.
            num_labels: Number of output classes, i.e. the width of the
                logits produced by the linear head.
        """
        super().__init__()

        self.bert = transformers.BertModel.from_pretrained(pretrained_weights)

        # Explicitly mark every encoder weight as trainable so the encoder
        # is fine-tuned rather than used as a frozen feature extractor.
        for param in self.bert.parameters():
            param.requires_grad = True

        # Size the head from the encoder's own configuration instead of a
        # literal 768, so checkpoints with another hidden width also work.
        self.dense = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, token_type_ids, attention_mask):
        """Run the encoder and classify its pooled output.

        Args:
            input_ids: Token id tensor forwarded to the encoder
                (presumably shaped (batch, seq_len) -- confirm with caller).
            token_type_ids: Segment id tensor forwarded to the encoder.
            attention_mask: Attention mask tensor forwarded to the encoder.

        Returns:
            The output of the linear head; its last dimension has one
            entry per class (unnormalised scores, no softmax is applied).
        """
        bert_output = self.bert(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )

        # Element 1 of the encoder output is used as the sentence-level
        # representation (the pooled [CLS] output per the BertModel docs).
        # Positional indexing is kept because it works whether the encoder
        # returns a plain tuple or a ModelOutput object.
        bert_cls_hidden_state = bert_output[1]

        return self.dense(bert_cls_hidden_state)