import torch
import torch.nn as nn
from transformers import PreTrainedModel
from configs import BertVAEConfig
from transformers.models.bert.modeling_bert import BertEncoder, BertModel


class BertVAE(PreTrainedModel):
    config_class = BertVAEConfig

    def __init__(self, config):
        super().__init__(config)
        # Trainable VAE encoder/decoder stacks built from BERT encoder layers.
        self.encoder = BertEncoder(config)
        # Frozen pretrained BERT acts as a fixed feature extractor.
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        # Heads projecting encoder hidden states to the latent mean and log-variance.
        self.fc_mu = nn.Linear(config.hidden_size, config.hidden_size)
        self.fc_var = nn.Linear(config.hidden_size, config.hidden_size)
        # Classification heads on the [CLS] position of the encoder and decoder.
        self.enc_cls = nn.Linear(config.hidden_size, config.position_num)
        self.dec_cls = nn.Linear(config.hidden_size, config.position_num)
        self.decoder = BertEncoder(config)
        for p in self.bert.parameters():
            p.requires_grad = False

    def encode(self, input_ids, **kwargs):
        '''
        input_ids: (batch_size, seq_len)
        Returns the latent mean and log-variance, each of shape
        (batch_size, seq_len, hidden_size).
        '''
        x = self.bert(input_ids).last_hidden_state
        outputs = self.encoder(x, **kwargs)
        hidden_state = outputs.last_hidden_state
        mu = self.fc_mu(hidden_state)
        log_var = self.fc_var(hidden_state)
        return mu, log_var

    def encoder_cls(self, input_ids, **kwargs):
        '''
        input_ids: (batch_size, seq_len)
        Returns classification logits from the encoder's [CLS] position.
        '''
        x = self.bert(input_ids).last_hidden_state
        outputs = self.encoder(x, **kwargs)
        hidden_state = outputs.last_hidden_state
        return self.enc_cls(hidden_state[:, 0, :])

    def decoder_cls(self, z, **kwargs):
        '''
        z: latent tensor of shape (batch_size, seq_len, hidden_size)
        Returns classification logits from the decoder's [CLS] position.
        '''
        outputs = self.decoder(z, **kwargs)
        hidden_state = outputs.last_hidden_state
        return self.dec_cls(hidden_state[:, 0, :])

    def reparameterize(self, mu, log_var):
        # Standard reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I).
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, **kwargs):
        '''
        z: latent tensor of shape (batch_size, seq_len, hidden_size)
        Returns decoded hidden states of the same shape.
        '''
        outputs = self.decoder(z, **kwargs)
        return outputs.last_hidden_state

    def forward(self, input_ids, position=None, **kwargs):
        '''
        input_ids: (batch_size, seq_len)
        Returns (reconstructed hidden states, mu, log_var).
        `position` is accepted for interface compatibility but unused here.
        '''
        mu, log_var = self.encode(input_ids, **kwargs)
        z = self.reparameterize(mu, log_var)
        return self.decode(z, **kwargs), mu, log_var

    def _elbo(self, x, x_hat, mu, log_var):
        '''
        Compute the negative ELBO (reconstruction loss plus weighted KL term).
        x: target tensor of shape (batch_size, seq_len, hidden_size)
        x_hat: reconstruction of shape (batch_size, seq_len, hidden_size)
        mu: mean tensor of shape (batch_size, seq_len, hidden_size)
        log_var: log-variance tensor of shape (batch_size, seq_len, hidden_size)
        '''
        # Reconstruction term: MSE between decoded and frozen-BERT hidden states.
        recon_loss = nn.functional.mse_loss(x_hat, x, reduction='mean')
        # KL divergence to the standard normal prior, summed over the hidden
        # dimension and averaged over batch and sequence positions.
        kl_loss = torch.mean(
            -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=-1)
        )
        # The KL term is down-weighted by a factor of 0.1 (beta-VAE style weighting).
        return recon_loss + 0.1 * kl_loss

    def elbo(self, input_ids, **kwargs):
        '''
        Compute the negative ELBO loss for a batch of token ids.
        input_ids: (batch_size, seq_len)
        '''
        x = self.bert(input_ids, **kwargs).last_hidden_state
        outputs = self.encoder(x, **kwargs)
        hidden_state = outputs.last_hidden_state
        mu = self.fc_mu(hidden_state)
        log_var = self.fc_var(hidden_state)
        z = self.reparameterize(mu, log_var)
        outputs = self.decoder(z, **kwargs)
        x_hat = outputs.last_hidden_state
        return self._elbo(x, x_hat, mu, log_var)

    def reconstruct(self, input_ids, **kwargs):
        '''
        Encode input_ids and decode the reparameterized latent.
        Returns hidden states of shape (batch_size, seq_len, hidden_size).
        '''
        return self.forward(input_ids, **kwargs)[0]

    def sample(self, num_samples, device, **kwargs):
        '''
        Decode `num_samples` latent sequences drawn from the standard normal prior.
        Returns hidden states of shape (num_samples, max_position_embeddings, hidden_size).
        '''
        z = torch.randn(
            num_samples, self.config.max_position_embeddings, self.config.hidden_size
        ).to(device)
        return self.decode(z, **kwargs)
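

# --- Usage sketch (not part of the original module; assumptions noted below) ---
# A minimal smoke test showing how the pieces fit together: tokenize a sentence
# with the matching 'bert-base-uncased' tokenizer, compute the negative ELBO
# training loss, and decode a couple of samples from the prior. It assumes
# BertVAEConfig mirrors BertConfig defaults (hidden_size=768, etc.) and accepts
# a `position_num` field for the classification heads; adjust to the real config.
if __name__ == '__main__':
    from transformers import BertTokenizerFast

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Assumed constructor arguments for the custom config class.
    config = BertVAEConfig(position_num=4)
    model = BertVAE(config).to(device)

    tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
    batch = tokenizer(
        ['the model reconstructs contextual embeddings'],
        return_tensors='pt',
        padding=True,
    )
    input_ids = batch['input_ids'].to(device)

    # Negative ELBO used as the training objective; only the VAE encoder,
    # decoder, and linear heads receive gradients (BERT is frozen).
    loss = model.elbo(input_ids)
    loss.backward()
    print('loss:', loss.item())

    # Decode latent sequences drawn from the standard normal prior.
    with torch.no_grad():
        samples = model.sample(num_samples=2, device=device)
    print('sampled hidden states:', samples.shape)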