Motahar committed
Commit 174304e
1 Parent(s): 4427b5b

Updated ganbert.py

Files changed (1):
  1. ganbert.py +10 -24
ganbert.py CHANGED
@@ -30,6 +30,7 @@ from transformers import (
     set_seed,
     get_constant_schedule_with_warmup,
     Trainer,TrainingArguments,EarlyStoppingCallback)
+
 from datasets import Dataset
 import torch.nn as nn
 import torch.nn.functional as F
@@ -77,29 +78,14 @@ class GAN(PreTrainedModel):
         self.generator.cuda()
         self.discriminator.cuda()
         self.transformer.cuda()
-    def forward(self,
-                input_ids: Optional[torch.Tensor] = None,
-                attention_mask: Optional[torch.Tensor] = None,
-                token_type_ids: Optional[torch.Tensor] = None,
-                position_ids: Optional[torch.Tensor] = None,
-                head_mask: Optional[torch.Tensor] = None,
-                inputs_embeds: Optional[torch.Tensor] = None,
-                encoder_hidden_states: Optional[torch.Tensor] = None,
-                encoder_attention_mask: Optional[torch.Tensor] = None,
-                past_key_values: Optional[List[torch.FloatTensor]] = None,
-                use_cache: Optional[bool] = None,
-                output_attentions: Optional[bool] = None,
-                output_hidden_states: Optional[bool] = None,
-                ):
+    def forward(self, **kwargs):
         # Encode real data in the Transformer
-        real_batch_size = input_ids.shape[0]
-        model_outputs = self.transformer(input_ids, attention_mask=attention_mask)
+        # real_batch_size = input_ids.shape[0]
+        model_outputs = self.transformer(**kwargs)
         # print('got transformer output')
-
-        hidden_states = torch.mean(model_outputs[0], dim=1)
-        noise = torch.zeros(real_batch_size, self.ns, device=self.dv).uniform_(0, 1).to(self.dv)
-        gen_rep = self.generator(noise)
-        disciminator_input = torch.cat([hidden_states, gen_rep], dim=0)
-        features, logits, probs = self.discriminator(disciminator_input)
-        return model_outputs[0]
+        # hidden_states = torch.mean(model_outputs[0], dim=1)
+        # noise = torch.zeros(real_batch_size, self.ns, device=self.dv).uniform_(0, 1).to(self.dv)
+        # gen_rep = self.generator(noise)
+        # disciminator_input = torch.cat([hidden_states, gen_rep], dim=0)
+        # features, logits, probs = self.discriminator(disciminator_input)
+        return model_outputs
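
A minimal usage sketch of the rewritten forward is shown below. It assumes the GAN wrapper has already been instantiated with a BERT-style backbone (the constructor is not part of this diff), and the checkpoint name and variable names are hypothetical. After this commit, forward simply passes its keyword arguments to self.transformer and returns the raw transformer output instead of model_outputs[0].

    from transformers import AutoTokenizer
    import torch

    # Hypothetical setup: `model` is an already-constructed GAN instance
    # whose `transformer` is a BERT-style encoder; constructor arguments
    # are not shown in this diff.
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    inputs = tokenizer("an example sentence", return_tensors="pt")
    inputs = {k: v.cuda() for k, v in inputs.items()}  # the model moves its modules to CUDA in __init__

    with torch.no_grad():
        outputs = model(**inputs)  # forward now passes **kwargs straight to self.transformer
    # `outputs` is whatever self.transformer returns (the full transformer
    # output), since the generator/discriminator branch is commented out.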