Motahar committed on
Commit
78abd74
1 Parent(s): 174304e

Updated ganbert.py

Browse files
Files changed (1) hide show
  1. ganbert.py +3 -4
ganbert.py CHANGED
@@ -58,8 +58,6 @@ class GAN(PreTrainedModel):
58
  self.model_name = self.all_checkpoints[config.model_number]
59
  self.parent_config = AutoConfig.from_pretrained(self.model_name)
60
  self.hidden_size = int(self.parent_config.hidden_size)
61
- self.ns = config.noise_size
62
- self.dv = config.device
63
  # Define the number and width of hidden layers
64
  self.hidden_levels_g = [self.hidden_size for i in range(0, config.num_hidden_layers_g)]
65
  self.hidden_levels_d = [self.hidden_size for i in range(0, config.num_hidden_layers_d)]
@@ -73,7 +71,7 @@ class GAN(PreTrainedModel):
73
  # Put everything in the GPU if available
74
  # print(self.generator,self.discriminator)
75
  self.transformer = AutoModel.from_pretrained(self.model_name,output_attentions=True)
76
-
77
  if config.device == 'cuda':
78
  self.generator.cuda()
79
  self.discriminator.cuda()
@@ -81,7 +79,8 @@ class GAN(PreTrainedModel):
81
  def forward(self,**kwargs):
82
  # Encode real data in the Transformer
83
  # real_batch_size = input_ids.shape[0]
84
- model_outputs = self.transformer(**kwargs)
 
85
  # print('got transformer output')
86
  # hidden_states = torch.mean(model_outputs[0],dim=1)
87
  # noise = torch.zeros(real_batch_size, self.ns, device=self.dv).uniform_(0, 1).to(self.dv)
 
58
  self.model_name = self.all_checkpoints[config.model_number]
59
  self.parent_config = AutoConfig.from_pretrained(self.model_name)
60
  self.hidden_size = int(self.parent_config.hidden_size)
 
 
61
  # Define the number and width of hidden layers
62
  self.hidden_levels_g = [self.hidden_size for i in range(0, config.num_hidden_layers_g)]
63
  self.hidden_levels_d = [self.hidden_size for i in range(0, config.num_hidden_layers_d)]
 
71
  # Put everything in the GPU if available
72
  # print(self.generator,self.discriminator)
73
  self.transformer = AutoModel.from_pretrained(self.model_name,output_attentions=True)
74
+ self.config = config
75
  if config.device == 'cuda':
76
  self.generator.cuda()
77
  self.discriminator.cuda()
 
79
  def forward(self,**kwargs):
80
  # Encode real data in the Transformer
81
  # real_batch_size = input_ids.shape[0]
82
+ model_outputs = self.transformer(output_hidden_states = self.config.output_hidden_states,\
83
+ output_attentions = self.config.output_attentions,**kwargs)
84
  # print('got transformer output')
85
  # hidden_states = torch.mean(model_outputs[0],dim=1)
86
  # noise = torch.zeros(real_batch_size, self.ns, device=self.dv).uniform_(0, 1).to(self.dv)