Motahar committed on
Commit
4d07349
1 Parent(s): 3e8bd23

Updated ganbert.py

Files changed (1):
  1. ganbert.py +17 -7
ganbert.py CHANGED
@@ -76,10 +76,24 @@ class GAN(PreTrainedModel):
         self.generator.cuda()
         self.discriminator.cuda()
         self.transformer.cuda()
-    def forward(self,b_input_ids,b_input_mask):
+    def forward(self,
+                input_ids: Optional[torch.Tensor] = None,
+                attention_mask: Optional[torch.Tensor] = None,
+                token_type_ids: Optional[torch.Tensor] = None,
+                position_ids: Optional[torch.Tensor] = None,
+                head_mask: Optional[torch.Tensor] = None,
+                inputs_embeds: Optional[torch.Tensor] = None,
+                encoder_hidden_states: Optional[torch.Tensor] = None,
+                encoder_attention_mask: Optional[torch.Tensor] = None,
+                past_key_values: Optional[List[torch.FloatTensor]] = None,
+                use_cache: Optional[bool] = None,
+                output_attentions: Optional[bool] = None,
+                output_hidden_states: Optional[bool] = None,
+                return_dict: Optional[bool] = None,
+                ):
         # Encode real data in the Transformer
-        real_batch_size = b_input_ids.shape[0]
-        model_outputs = self.transformer(b_input_ids, attention_mask=b_input_mask)
+        real_batch_size = input_ids.shape[0]
+        model_outputs = self.transformer(input_ids, attention_mask=attention_mask)
         # print('got transformer output')

         hidden_states = torch.mean(model_outputs[0],dim=1)
@@ -88,7 +102,3 @@ class GAN(PreTrainedModel):
         disciminator_input = torch.cat([hidden_states, gen_rep], dim=0)
         features, logits, probs = self.discriminator(disciminator_input)
         return model_outputs[0]
-
-if __name__ == '__main__':
-    ganconfig = GanBertConfig()
-    clickbaitmodel = GAN(ganconfig)
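
For context, here is a minimal usage sketch of the new forward signature. After this commit the model accepts the standard Hugging Face keyword arguments (input_ids, attention_mask, and so on) instead of the old positional b_input_ids / b_input_mask pair, so it can be called like any other PreTrainedModel or driven by Trainer. The tokenizer checkpoint "bert-base-uncased" below is an assumption for illustration; GAN and GanBertConfig come from this same file (GanBertConfig appears in the __main__ block removed by this commit).

import torch
from transformers import AutoTokenizer

from ganbert import GAN, GanBertConfig

# Assumed backbone checkpoint; the actual checkpoint used by this repo
# is not shown in the diff.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# __init__ already moves the generator, discriminator, and transformer to
# CUDA, so the inputs must be CUDA tensors as well.
model = GAN(GanBertConfig())

batch = tokenizer(
    ["is this a clickbait headline?", "a plain news headline"],
    padding=True,
    return_tensors="pt",
)

with torch.no_grad():
    # Keyword names now match what Trainer passes from a data collator.
    outputs = model(
        input_ids=batch["input_ids"].cuda(),
        attention_mask=batch["attention_mask"].cuda(),
    )

# Per the return statement, forward returns the transformer's last hidden
# states, shaped (batch_size, sequence_length, hidden_size).
print(outputs.shape)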