File size: 3,929 Bytes
45b5843
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174304e
45b5843
 
 
 
4427b5b
45b5843
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78abd74
45b5843
 
 
 
174304e
45b5843
174304e
78abd74
 
45b5843
174304e
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from transformers import PreTrainedModel
from .configuration_ganbert import GanBertConfig
from .gan import Generator,Discriminator
from transformers import PretrainedConfig
import logging
import datasets
from datasets import load_dataset
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datasets import load_metric
import transformers
import torch
import io
import torch.nn.functional as F
import random
import numpy as np
import time
import math
import datetime
import torch.nn as nn
from torch.utils.data import Dataset,TensorDataset, DataLoader, RandomSampler, SequentialSampler
from transformers import (
    AutoModel,
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    default_data_collator,
    set_seed,
    get_constant_schedule_with_warmup,
    Trainer,TrainingArguments,EarlyStoppingCallback)

from datasets import Dataset
import torch.nn as nn
import torch.nn.functional as F
import sys
from typing import List, Optional, Tuple, Union

class GAN(PreTrainedModel):
    """GAN-BERT style model: a pretrained transformer encoder plus a
    Generator/Discriminator pair whose layers match the encoder's hidden width.

    ``forward`` currently runs only the transformer encoder. The generator and
    discriminator are constructed here but are not invoked by ``forward``
    (presumably an external training loop drives them — TODO confirm).
    """

    config_class = GanBertConfig
    # Candidate encoder checkpoints; ``config.model_number`` indexes this list.
    all_checkpoints = [
        'bert-base-multilingual-cased',
        'sagorsarker/bangla-bert-base',
        'neuralspace-reverie/indic-transformers-bn-bert',
        'neuralspace-reverie/indic-transformers-bn-roberta',
        'distilbert-base-multilingual-cased',
        'neuralspace-reverie/indic-transformers-bn-distilbert',
        'monsoon-nlp/bangla-electra',
        'csebuetnlp/banglabert',
        'neuralspace-reverie/indic-transformers-bn-xlmroberta',
    ]

    def __init__(self, config):
        """Build the encoder, generator and discriminator described by ``config``.

        Args:
            config: a ``GanBertConfig``. Fields read here: ``model_number``,
                ``num_hidden_layers_g``, ``num_hidden_layers_d``,
                ``pos_class_weight``, ``device``, ``noise_size`` and
                ``out_dropout_rate``.

        Raises:
            IndexError: if ``config.model_number`` is out of range for
                ``all_checkpoints``.
        """
        super().__init__(config)

        self.model_name = self.all_checkpoints[config.model_number]
        self.parent_config = AutoConfig.from_pretrained(self.model_name)
        self.hidden_size = int(self.parent_config.hidden_size)
        # Every hidden layer of G and D is as wide as the encoder's hidden size.
        self.hidden_levels_g = [self.hidden_size] * config.num_hidden_layers_g
        self.hidden_levels_d = [self.hidden_size] * config.num_hidden_layers_d
        self.label_list = [0, 1, 2]
        # Per-label weights aligned with ``label_list``; the middle entry comes
        # from ``config.pos_class_weight``. dtype is forced to float because an
        # all-int list would otherwise produce an int64 tensor, which cannot be
        # used as a floating-point loss weight.
        self.class_weight = torch.tensor(
            [10, config.pos_class_weight, 5],
            dtype=torch.float,
            device=config.device,
        )

        # Generator maps noise_size-dim noise to encoder-sized vectors; the
        # discriminator classifies encoder-sized vectors into len(label_list) labels.
        self.generator = Generator(
            noise_size=config.noise_size,
            output_size=self.hidden_size,
            hidden_sizes=self.hidden_levels_g,
            dropout_rate=config.out_dropout_rate,
        )
        self.discriminator = Discriminator(
            input_size=self.hidden_size,
            hidden_sizes=self.hidden_levels_d,
            num_labels=len(self.label_list),
            dropout_rate=config.out_dropout_rate,
        )
        self.transformer = AutoModel.from_pretrained(self.model_name, output_attentions=True)
        # Keep a reference to the exact config object the caller passed in.
        self.config = config

        # Move every submodule to the configured CUDA device. Parsing through
        # torch.device also recognises indexed devices such as 'cuda:1', which a
        # plain string comparison against 'cuda' would miss.
        if config.device is not None and torch.device(config.device).type == 'cuda':
            self.to(config.device)

    def forward(self, **kwargs):
        """Run the transformer encoder on ``kwargs`` and return its raw output.

        ``output_hidden_states`` and ``output_attentions`` default to the values
        on ``self.config``. A caller may override either one through ``kwargs``
        without triggering a duplicate-keyword ``TypeError``.

        Returns:
            Whatever ``self.transformer(**kwargs)`` returns, unchanged.
        """
        kwargs.setdefault('output_hidden_states', self.config.output_hidden_states)
        kwargs.setdefault('output_attentions', self.config.output_attentions)
        # TODO: the GAN step (pool the encoder output, append generator samples,
        # run the discriminator) is not wired into forward yet.
        return self.transformer(**kwargs)