import torch

from transformers import AutoConfig, AutoModel, AutoTokenizer, PreTrainedModel

from .configuration_ganbert import GanBertConfig
from .gan import Generator, Discriminator

class GAN(PreTrainedModel):
    """GAN-BERT style model: a pretrained transformer encoder whose pooled
    output is scored by a Discriminator against fake representations produced
    by a noise-driven Generator (both defined in .gan)."""

    config_class = GanBertConfig

    # Candidate encoder checkpoints; config.model_number indexes into this list.
    all_checkpoints = [
        'bert-base-multilingual-cased',
        'sagorsarker/bangla-bert-base',
        'neuralspace-reverie/indic-transformers-bn-bert',
        'neuralspace-reverie/indic-transformers-bn-roberta',
        'distilbert-base-multilingual-cased',
        'neuralspace-reverie/indic-transformers-bn-distilbert',
        'monsoon-nlp/bangla-electra',
        'csebuetnlp/banglabert',
        'neuralspace-reverie/indic-transformers-bn-xlmroberta',
    ]

    def __init__(self, config):
        super().__init__(config)

        self.model_name = self.all_checkpoints[config.model_number]
        self.parent_config = AutoConfig.from_pretrained(self.model_name)
        self.hidden_size = int(self.parent_config.hidden_size)
        self.ns = config.noise_size
        self.dv = config.device

        # Generator and discriminator hidden layers mirror the encoder's hidden size.
        self.hidden_levels_g = [self.hidden_size] * config.num_hidden_layers_g
        self.hidden_levels_d = [self.hidden_size] * config.num_hidden_layers_d
        self.label_list = [0, 1, 2]
        # Per-class loss weights, kept for an external training loop
        # (not used inside this module).
        self.class_weight = torch.tensor([10, config.pos_class_weight, 5], device=config.device)

        self.generator = Generator(
            noise_size=config.noise_size,
            output_size=self.hidden_size,
            hidden_sizes=self.hidden_levels_g,
            dropout_rate=config.out_dropout_rate,
        )
        self.discriminator = Discriminator(
            input_size=self.hidden_size,
            hidden_sizes=self.hidden_levels_d,
            num_labels=len(self.label_list),
            dropout_rate=config.out_dropout_rate,
        )

        self.transformer = AutoModel.from_pretrained(self.model_name, output_attentions=True)

        if config.device == 'cuda':
            self.generator.cuda()
            self.discriminator.cuda()
            self.transformer.cuda()

    def forward(self, b_input_ids, b_input_mask):
        real_batch_size = b_input_ids.shape[0]

        # Encode the real batch and mean-pool the last hidden states over tokens.
        model_outputs = self.transformer(b_input_ids, attention_mask=b_input_mask)
        hidden_states = torch.mean(model_outputs[0], dim=1)

        # Draw uniform noise and let the generator map it to fake representations.
        noise = torch.zeros(real_batch_size, self.ns, device=self.dv).uniform_(0, 1)
        gen_rep = self.generator(noise)

        # Score real and generated representations in one discriminator pass; the
        # first real_batch_size rows of each output correspond to the real inputs.
        discriminator_input = torch.cat([hidden_states, gen_rep], dim=0)
        features, logits, probs = self.discriminator(discriminator_input)
        return features, logits, probs

if __name__ == '__main__':
    ganconfig = GanBertConfig()
    clickbaitmodel = GAN(ganconfig)
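
    # Illustrative smoke test, a minimal sketch rather than part of the original
    # script: it assumes the configured checkpoint is downloadable and that its
    # tokenizer matches the encoder. The Bangla sample sentence is hypothetical.
    clickbaitmodel.eval()
    tokenizer = AutoTokenizer.from_pretrained(clickbaitmodel.model_name)
    sample = tokenizer("এটি কি একটি ক্লিকবেইট শিরোনাম?", return_tensors='pt')
    sample = {k: v.to(ganconfig.device) for k, v in sample.items()}
    with torch.no_grad():
        features, logits, probs = clickbaitmodel(sample['input_ids'], sample['attention_mask'])
    # The first row of each output scores the real input; the second row scores
    # the generator's fake representation.
    print(probs)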