Updated ganbert.py
Browse files
- ganbert.py +3 -4
ganbert.py
CHANGED
@@ -58,8 +58,6 @@ class GAN(PreTrainedModel):
|
|
58 |
self.model_name = self.all_checkpoints[config.model_number]
|
59 |
self.parent_config = AutoConfig.from_pretrained(self.model_name)
|
60 |
self.hidden_size = int(self.parent_config.hidden_size)
|
61 |
-
self.ns = config.noise_size
|
62 |
-
self.dv = config.device
|
63 |
# Define the number and width of hidden layers
|
64 |
self.hidden_levels_g = [self.hidden_size for i in range(0, config.num_hidden_layers_g)]
|
65 |
self.hidden_levels_d = [self.hidden_size for i in range(0, config.num_hidden_layers_d)]
|
@@ -73,7 +71,7 @@ class GAN(PreTrainedModel):
|
|
73 |
# Put everything in the GPU if available
|
74 |
# print(self.generator,self.discriminator)
|
75 |
self.transformer = AutoModel.from_pretrained(self.model_name,output_attentions=True)
|
76 |
-
|
77 |
if config.device == 'cuda':
|
78 |
self.generator.cuda()
|
79 |
self.discriminator.cuda()
|
@@ -81,7 +79,8 @@ class GAN(PreTrainedModel):
|
|
81 |
def forward(self,**kwargs):
|
82 |
# Encode real data in the Transformer
|
83 |
# real_batch_size = input_ids.shape[0]
|
84 |
-
model_outputs = self.transformer(
|
|
|
85 |
# print('got transformer output')
|
86 |
# hidden_states = torch.mean(model_outputs[0],dim=1)
|
87 |
# noise = torch.zeros(real_batch_size, self.ns, device=self.dv).uniform_(0, 1).to(self.dv)
|
|
|
58 |
self.model_name = self.all_checkpoints[config.model_number]
|
59 |
self.parent_config = AutoConfig.from_pretrained(self.model_name)
|
60 |
self.hidden_size = int(self.parent_config.hidden_size)
|
|
|
|
|
61 |
# Define the number and width of hidden layers
|
62 |
self.hidden_levels_g = [self.hidden_size for i in range(0, config.num_hidden_layers_g)]
|
63 |
self.hidden_levels_d = [self.hidden_size for i in range(0, config.num_hidden_layers_d)]
|
|
|
71 |
# Put everything in the GPU if available
|
72 |
# print(self.generator,self.discriminator)
|
73 |
self.transformer = AutoModel.from_pretrained(self.model_name,output_attentions=True)
|
74 |
+
self.config = config
|
75 |
if config.device == 'cuda':
|
76 |
self.generator.cuda()
|
77 |
self.discriminator.cuda()
|
|
|
79 |
def forward(self,**kwargs):
|
80 |
# Encode real data in the Transformer
|
81 |
# real_batch_size = input_ids.shape[0]
|
82 |
+
model_outputs = self.transformer(output_hidden_states = self.config.output_hidden_states,\
|
83 |
+
output_attentions = self.config.output_attentions,**kwargs)
|
84 |
# print('got transformer output')
|
85 |
# hidden_states = torch.mean(model_outputs[0],dim=1)
|
86 |
# noise = torch.zeros(real_batch_size, self.ns, device=self.dv).uniform_(0, 1).to(self.dv)
|