File size: 743 Bytes
fe7fef5
8ba144e
 
fe7fef5
 
 
 
 
 
 
8ba144e
fe7fef5
 
 
 
 
 
8ba144e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18

import transformers
import torch.nn as nn

class BertClassificationModel(nn.Module):
    """Sequence classifier: a pretrained BERT encoder plus a linear head.

    The linear head is applied to element ``[1]`` of the BERT output, which is
    the pooled output produced by BERT's pooler layer from the ``[CLS]``
    token's final hidden state.

    Args:
        pretrained_weights: Model name or path passed to
            ``transformers.BertModel.from_pretrained``. Defaults to
            ``"bert-base-chinese"``, the value previously hard-coded here.
        num_labels: Number of output classes (width of the linear head).
            Defaults to 3, the value previously hard-coded here.
        freeze_bert: Keyword-only. If True, the BERT encoder's parameters get
            ``requires_grad = False`` so only the linear head is trained.
            Defaults to False, i.e. the whole encoder is fine-tuned as before.
    """

    def __init__(self, pretrained_weights="bert-base-chinese", num_labels=3, *, freeze_bert=False):
        super().__init__()
        self.bert = transformers.BertModel.from_pretrained(pretrained_weights)
        # Make encoder trainability explicit and configurable; the original
        # always set True, which is also the default for freshly loaded params.
        for param in self.bert.parameters():
            param.requires_grad = not freeze_bert
        # Size the head from the loaded model's config rather than hard-coding
        # 768, so checkpoints with a different hidden size also work.
        self.dense = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        """Return classification logits for a batch of token sequences.

        Args:
            input_ids: Token ids for the encoder (presumably a
                ``(batch, seq_len)`` LongTensor — confirm against callers).
            token_type_ids: Optional segment ids; ``None`` lets BERT apply its
                own default.
            attention_mask: Optional padding mask; ``None`` lets BERT apply
                its own default.

        Returns:
            The output of the linear head applied to BERT's pooled output,
            i.e. unnormalized logits with ``num_labels`` entries per example
            (presumably shaped ``(batch, num_labels)``).
        """
        bert_output = self.bert(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        # Index 1 is the pooled output (pooler layer over the [CLS] state), not
        # the raw [CLS] hidden state; indexing works for both tuple and
        # ModelOutput return types.
        pooled_output = bert_output[1]
        return self.dense(pooled_output)