# cbert / model_cbert.py
# (Hugging Face Hub page residue: author yonichi, commit 4230aba "add model",
#  file size 1.91 kB, page links "raw / history / blame", scan status "No virus")
import torch
import torch.nn as nn
import random
import numpy as np
import numpy as np
import pandas as pd
import torch.nn.functional as F
from transformers import BertModel, PreTrainedModel
from configuration_cbert import BertCustomConfig
import torch.optim as optim
class BertSentiment(PreTrainedModel):
    """BERT encoder with a mean-pooled two-layer classification head.

    The encoder's last hidden states are averaged over the sequence
    dimension (skipping padded positions when an attention mask is given),
    passed through a hidden ``Linear`` layer, dropout and ReLU, and finally
    projected to ``config.hyperparams["num_labels"]`` logits.
    """

    config_class = BertCustomConfig

    def __init__(self, config, weight_path=None):
        """Build the encoder and the classification head.

        Args:
            config: a ``BertCustomConfig``; must provide ``hidden_size``,
                ``hidden_dropout_prob`` and a ``hyperparams`` mapping with a
                ``"num_labels"`` entry.
            weight_path: optional model name or path handed to
                ``BertModel.from_pretrained``. When falsy, the encoder is
                constructed directly from ``config``.
        """
        super().__init__(config)
        self.config = config
        self.num_labels = config.hyperparams["num_labels"]
        if weight_path:
            # NOTE(review): the head below is sized from config.hidden_size;
            # assumes the checkpoint at weight_path uses the same hidden size.
            self.bert = BertModel.from_pretrained(weight_path)
        else:
            self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.hidden = nn.Linear(config.hidden_size, config.hidden_size)
        self.classifier = nn.Linear(config.hidden_size, self.num_labels)
        # Xavier-normal init for the head's weight matrices; biases keep
        # nn.Linear's default initialisation.
        nn.init.xavier_normal_(self.hidden.weight)
        nn.init.xavier_normal_(self.classifier.weight)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None,
                labels=None, graphEmbeddings=None):
        """Return classification logits for a batch of token ids.

        Args:
            input_ids: token ids forwarded to ``BertModel``.
            token_type_ids: optional segment ids forwarded to ``BertModel``.
            attention_mask: optional mask forwarded to ``BertModel``; also
                used to exclude padded positions from the mean pooling.
            labels: accepted for interface compatibility; unused (no loss is
                computed here).
            graphEmbeddings: accepted for interface compatibility; unused.

        Returns:
            The output of ``self.classifier``: one row of ``num_labels``
            logits per example.
        """
        # BertModel.forward's positional order is (input_ids, attention_mask,
        # token_type_ids, ...). Pass by keyword so the two are not swapped.
        # NOTE(review): checkpoints trained with the old positional call saw
        # these two inputs swapped; re-validate them after this fix.
        encoder_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=False,
        )
        # With return_dict=False the first element is the last hidden state.
        sequence_output = encoder_outputs[0]
        pooled_output = self._mean_pool(sequence_output, attention_mask)
        pooled_output = self.hidden(pooled_output)
        pooled_output = self.dropout(pooled_output)
        pooled_output = F.relu(pooled_output)
        return self.classifier(pooled_output)

    @staticmethod
    def _mean_pool(hidden_states, attention_mask=None):
        """Average ``hidden_states`` over dim 1, ignoring masked-out positions."""
        if attention_mask is None:
            return torch.mean(hidden_states, 1)
        mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
        summed = (hidden_states * mask).sum(dim=1)
        # clamp avoids division by zero for a row that is entirely padding
        counts = mask.sum(dim=1).clamp(min=1e-9)
        return summed / counts