File size: 1,913 Bytes
4230aba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch
import torch.nn as nn

import random
import numpy as np

import numpy as np

import pandas as pd

import torch.nn.functional as F

from transformers import BertModel, PreTrainedModel
from configuration_cbert import BertCustomConfig

import torch.optim as optim

class BertSentiment(PreTrainedModel):
    """BERT encoder followed by a mean-pooled, two-layer classification head.

    The encoder's per-token hidden states are averaged over dimension 1,
    passed through a ``hidden_size -> hidden_size`` linear layer, dropout
    and ReLU, and finally projected to ``num_labels`` logits.

    The number of labels is read from ``config.hyperparams["num_labels"]``.
    """

    config_class = BertCustomConfig

    def __init__(self, config, weight_path=None):
        """Build the encoder and the classification head.

        Args:
            config: Model configuration. Must expose ``hyperparams`` (a
                mapping containing ``"num_labels"``), ``hidden_size`` and
                ``hidden_dropout_prob``.
            weight_path: Optional identifier/path handed to
                ``BertModel.from_pretrained``. When falsy, the encoder is
                instantiated from ``config`` instead of loaded from
                ``weight_path``.
        """
        super().__init__(config)
        self.config = config
        self.num_labels = self.config.hyperparams["num_labels"]

        if weight_path:
            self.bert = BertModel.from_pretrained(weight_path)
        else:
            self.bert = BertModel(self.config)

        self.dropout = nn.Dropout(self.config.hidden_dropout_prob)
        self.hidden = nn.Linear(self.config.hidden_size, self.config.hidden_size)
        self.classifier = nn.Linear(self.config.hidden_size, self.num_labels)
        nn.init.xavier_normal_(self.hidden.weight)
        nn.init.xavier_normal_(self.classifier.weight)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, graphEmbeddings=None):
        """Compute classification logits for a batch of token sequences.

        Args:
            input_ids: Token ids forwarded to the BERT encoder.
            token_type_ids: Optional segment ids forwarded to the encoder.
            attention_mask: Optional attention mask forwarded to the encoder.
            labels: Accepted for interface compatibility; not used here
                (no loss is computed).
            graphEmbeddings: Accepted for interface compatibility; not used.

        Returns:
            The raw (un-normalised) logits produced by the final linear
            layer, whose last dimension has size ``num_labels``.
        """
        # BertModel.forward takes (input_ids, attention_mask, token_type_ids, ...),
        # so the mask and the segment ids must be passed by keyword; passing
        # them positionally in this method's parameter order would swap them.
        encoder_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=False,
        )
        # Index rather than unpack: the tuple grows beyond two elements when
        # the config enables output_hidden_states / output_attentions.
        sequence_output = encoder_outputs[0]

        # NOTE(review): this is a plain mean over dimension 1 and therefore
        # also averages positions that attention_mask marks as padding.
        pooled_output = torch.mean(sequence_output, 1)
        pooled_output = self.hidden(pooled_output)
        pooled_output = self.dropout(pooled_output)
        pooled_output = F.relu(pooled_output)
        logits = self.classifier(pooled_output)

        return logits