import torch
import torch.nn as nn
from transformers import AutoTokenizer, ElectraModel


class Classifier(nn.Module):
    """Aspect-sentiment classifier: concatenates the [CLS] representation
    with an averaged entity-span representation and classifies with an MLP."""

    def __init__(self, base_model, num_labels, device, tokenizer):
        super().__init__()
        self.num_labels = num_labels
        self.device = device
        # Load the ELECTRA encoder named by `base_model` (e.g. 'beomi/KcELECTRA-base')
        # rather than a hard-coded checkpoint; ElectraModel carries no classification
        # head, so `num_labels` only matters in the final linear layer below.
        self.electra = ElectraModel.from_pretrained(base_model)
        self.electra.resize_token_embeddings(len(tokenizer))

        hidden = self.electra.config.hidden_size
        self.fc1 = nn.Linear(hidden, 256)   # projects the [CLS] token
        self.fc2 = nn.Linear(hidden, 512)   # projects the pooled entity span
        self.fc3 = nn.Linear(256 + 512, num_labels)
        self.dropout = nn.Dropout(0.1)

    def forward(self, input_ids, attention_mask, entity_mask):
        outputs = self.electra(input_ids=input_ids,
                               attention_mask=attention_mask,
                               output_hidden_states=True)
        last_hidden_state = outputs.last_hidden_state

        # Mean-pool the hidden states over the entity span, then project.
        masked_last_hidden = self.entity_average(last_hidden_state, entity_mask)
        masked_last_hidden = self.fc2(masked_last_hidden)

        # Project the full sequence and keep only the [CLS] position.
        last_hidden_state = self.fc1(last_hidden_state)
        entity_outputs = torch.cat([last_hidden_state[:, 0, :], masked_last_hidden], dim=-1)

        outputs = torch.tanh(entity_outputs)
        outputs = self.dropout(outputs)
        outputs = self.fc3(outputs)
        return outputs

    @staticmethod
    def entity_average(hidden_output, e_mask):
        """Mean-pool `hidden_output` over positions where `e_mask` is non-zero."""
        e_mask_unsqueeze = e_mask.unsqueeze(1)  # [b, 1, seq_len]
        # Clamp to 1 so an all-zero entity mask cannot cause division by zero.
        length_tensor = (e_mask != 0).sum(dim=1).unsqueeze(1).clamp(min=1)  # [b, 1]
        # [b, 1, seq_len] x [b, seq_len, dim] -> [b, 1, dim] -> [b, dim]
        sum_vector = torch.bmm(e_mask_unsqueeze.float(), hidden_output).squeeze(1)
        avg_vector = sum_vector.float() / length_tensor.float()  # broadcasting
        return avg_vector
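
# --- Usage sketch (illustrative, not part of the original pipeline) ---
# A minimal forward pass, assuming the 'beomi/KcELECTRA-base' tokenizer and a
# binary label space. The entity span chosen below (token positions 1..3) is
# hypothetical; in a real ABSA pipeline `entity_mask` would mark the sub-word
# positions of the aspect term with 1s.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained("beomi/KcELECTRA-base")
    model = Classifier("beomi/KcELECTRA-base", num_labels=2,
                       device=device, tokenizer=tokenizer).to(device)

    enc = tokenizer("배터리 수명이 정말 좋아요",  # "The battery life is really good"
                    return_tensors="pt", padding="max_length",
                    max_length=32, truncation=True)
    # Hypothetical entity mask: flag a few sub-word positions as the aspect span.
    entity_mask = torch.zeros_like(enc["input_ids"])
    entity_mask[:, 1:4] = 1

    logits = model(enc["input_ids"].to(device),
                   enc["attention_mask"].to(device),
                   entity_mask.to(device))
    print(logits.shape)  # -> torch.Size([1, 2])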