Yewon committed on
Commit
f4b9f63
·
1 Parent(s): 800a5b6

feat: first dist

Browse files
vae.bin/config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "BertVAE"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "classifier_dropout": null,
7
+ "hidden_act": "gelu",
8
+ "hidden_dropout_prob": 0.1,
9
+ "hidden_size": 768,
10
+ "initializer_range": 0.02,
11
+ "intermediate_size": 3072,
12
+ "layer_norm_eps": 1e-12,
13
+ "max_position_embeddings": 512,
14
+ "model_type": "bert_vae",
15
+ "num_attention_heads": 12,
16
+ "num_hidden_layers": 3,
17
+ "pad_token_id": 0,
18
+ "position_embedding_type": "absolute",
19
+ "position_num": 4,
20
+ "torch_dtype": "float32",
21
+ "transformers_version": "4.19.2",
22
+ "type_vocab_size": 2,
23
+ "use_cache": true,
24
+ "vocab_size": 30522
25
+ }
vae.bin/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:608e15a087931b4ecb1ecead87c0830cbaead9094062c191ee7fc6a4e581ad33
3
+ size 612894285
vae.bin/training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cbae0b4d5681105f7086b42c6969b2a30f31bac7ab4f5b16b61231a0f068bab2
3
+ size 3195
vae.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import PreTrainedModel
4
+ from configs import BertVAEConfig
5
+ from transformers.models.bert.modeling_bert import BertEncoder, BertModel
6
+
7
+
8
class BertVAE(PreTrainedModel):
    """Variational autoencoder over the hidden states of a frozen BERT model.

    A ``bert-base-uncased`` model (all parameters frozen) turns token ids into
    hidden states.  A trainable ``BertEncoder`` stack maps those states to a
    per-position mean and log-variance, and a second ``BertEncoder`` stack
    decodes a sampled latent back into hidden-state space.  Two linear heads
    (``enc_cls`` / ``dec_cls``) classify the first sequence position of the
    encoder / decoder output into ``config.position_num`` classes.

    NOTE(review): ``**kwargs`` given to the public methods are forwarded
    unchanged to ``BertEncoder.forward``.  A 2-D ``attention_mask`` in the
    form ``BertModel`` accepts is presumably not the form ``BertEncoder``
    expects -- confirm before passing masks through these methods.
    """

    config_class = BertVAEConfig

    def __init__(self, config):
        """Build the encoder/decoder stacks, the heads, and the frozen BERT.

        Args:
            config: configuration providing at least ``hidden_size`` and
                ``position_num`` plus whatever ``BertEncoder`` reads.
        """
        super().__init__(config)
        self.encoder = BertEncoder(config)
        # The backbone checkpoint name is fixed here, independent of `config`.
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.fc_mu = nn.Linear(config.hidden_size, config.hidden_size)
        self.fc_var = nn.Linear(config.hidden_size, config.hidden_size)
        self.enc_cls = nn.Linear(config.hidden_size, config.position_num)
        self.dec_cls = nn.Linear(config.hidden_size, config.position_num)
        self.decoder = BertEncoder(config)

        # The BERT backbone is a fixed feature extractor: exclude it from
        # gradient updates.
        for p in self.bert.parameters():
            p.requires_grad = False

    def _bert_hidden(self, input_ids, **bert_kwargs):
        """Return the frozen BERT's last hidden state for ``input_ids``."""
        return self.bert(input_ids, **bert_kwargs).last_hidden_state

    def _posterior(self, x, **kwargs):
        """Run the encoder stack on ``x`` and return ``(mu, log_var)``."""
        hidden_state = self.encoder(x, **kwargs).last_hidden_state
        return self.fc_mu(hidden_state), self.fc_var(hidden_state)

    def encode(self, input_ids, **kwargs):
        """Encode token ids into the parameters of the latent distribution.

        Args:
            input_ids: token ids, expected shape (batch_size, seq_len).
            **kwargs: forwarded to the encoder stack only (not to BERT).

        Returns:
            Tuple ``(mu, log_var)``, one vector per sequence position.
        """
        x = self._bert_hidden(input_ids)
        return self._posterior(x, **kwargs)

    def encoder_cls(self, input_ids, **kwargs):
        """Classify the input from the encoder's first-position output.

        Args:
            input_ids: token ids, expected shape (batch_size, seq_len).
            **kwargs: forwarded to the encoder stack only (not to BERT).

        Returns:
            Logits from ``enc_cls`` with ``config.position_num`` outputs.
        """
        x = self._bert_hidden(input_ids)
        hidden_state = self.encoder(x, **kwargs).last_hidden_state
        return self.enc_cls(hidden_state[:, 0, :])

    def decoder_cls(self, z, **kwargs):
        """Classify a latent from the decoder's first-position output.

        Args:
            z: latent tensor, expected shape (batch_size, seq_len, dim).
            **kwargs: forwarded to the decoder stack.

        Returns:
            Logits from ``dec_cls`` with ``config.position_num`` outputs.
        """
        hidden_state = self.decode(z, **kwargs)
        return self.dec_cls(hidden_state[:, 0, :])

    def reparameterize(self, mu, log_var):
        """Sample ``mu + eps * std`` with ``eps ~ N(0, I)``.

        Writing the sample this way keeps it differentiable with respect to
        ``mu`` and ``log_var``.
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, **kwargs):
        """Decode a latent back into hidden-state space.

        Args:
            z: latent tensor, expected shape (batch_size, seq_len, dim).
            **kwargs: forwarded to the decoder stack.

        Returns:
            The decoder stack's last hidden state.
        """
        return self.decoder(z, **kwargs).last_hidden_state

    def forward(self, input_ids, position=None, **kwargs):
        """Encode, sample a latent, and decode.

        Args:
            input_ids: either a tensor of token ids, or a mapping (such as a
                tokenizer's output) that is unpacked into ``encode`` and must
                therefore contain an ``input_ids`` entry.
            position: accepted for interface compatibility; not used here.
            **kwargs: forwarded to both ``encode`` and ``decode``.

        Returns:
            Tuple ``(reconstruction, mu, log_var)``.
        """
        if torch.is_tensor(input_ids):
            # A bare tensor cannot be **-unpacked; pass it positionally.
            mu, log_var = self.encode(input_ids, **kwargs)
        else:
            mu, log_var = self.encode(**input_ids, **kwargs)
        z = self.reparameterize(mu, log_var)
        return self.decode(z, **kwargs), mu, log_var

    def _elbo(self, x, x_hat, mu, log_var, kl_weight=0.1):
        """Compute the negative ELBO used as the training loss.

        Args:
            x: reconstruction target, expected shape (batch_size, seq_len, dim).
            x_hat: reconstruction with the same shape as ``x``.
            mu: posterior mean, expected shape (batch_size, seq_len, dim).
            log_var: posterior log-variance, same shape as ``mu``.
            kl_weight: multiplier applied to the KL term.

        Returns:
            Scalar tensor: mean squared reconstruction error plus
            ``kl_weight`` times the KL divergence to a standard normal.
        """
        recon_loss = nn.functional.mse_loss(x_hat, x, reduction='mean')
        # Sum the KL over the latent dimension only, then average over the
        # remaining (batch, position) axes.  Summing over every axis would
        # make the term grow with batch size while the reconstruction term,
        # being a mean, does not.
        kl_per_position = -0.5 * torch.sum(
            1 + log_var - mu.pow(2) - log_var.exp(), dim=-1
        )
        kl_loss = torch.mean(kl_per_position)
        return recon_loss + kl_loss * kl_weight

    def elbo(self, input_ids, **kwargs):
        """Compute the negative-ELBO loss for a batch of token ids.

        The frozen BERT's hidden states serve both as the encoder input and
        as the reconstruction target.

        Args:
            input_ids: token ids, expected shape (batch_size, seq_len).
            **kwargs: forwarded to BERT, the encoder stack and the decoder
                stack alike.  NOTE(review): confirm every key is valid for
                all three callees.

        Returns:
            Scalar loss tensor (see ``_elbo``).
        """
        x = self._bert_hidden(input_ids, **kwargs)
        mu, log_var = self._posterior(x, **kwargs)
        z = self.reparameterize(mu, log_var)
        x_hat = self.decode(z, **kwargs)
        return self._elbo(x, x_hat, mu, log_var)

    def reconstruct(self, input_ids, **kwargs):
        """Return only the reconstruction produced by ``forward``.

        Args:
            input_ids: anything ``forward`` accepts as its first argument.
            **kwargs: forwarded to ``forward``.
        """
        return self.forward(input_ids, **kwargs)[0]

    def sample(self, num_samples, device, **kwargs):
        """Decode latents drawn from a standard normal prior.

        Args:
            num_samples: number of latent sequences to draw.
            device: device the latent tensor is moved to before decoding.
            **kwargs: forwarded to ``decode``.

        Returns:
            Decoder output for a latent of shape
            (num_samples, config.max_position_embeddings, config.hidden_size).
        """
        z = torch.randn(
            num_samples,
            self.config.max_position_embeddings,
            self.config.hidden_size,
        ).to(device)
        return self.decode(z, **kwargs)
vae_config.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import BertConfig
2
+ from typing import List
3
+
4
+
5
class BertVAEConfig(BertConfig):
    """``BertConfig`` extended with the fields the BERT VAE model reads.

    Differs from the parent only in its ``model_type`` tag, a default of
    three hidden layers, and the extra ``position_num`` field.
    """

    model_type = "bert_vae"
    # NOTE(review): the base config initialiser appears to assign
    # `is_encoder_decoder` per instance from its keyword arguments, which
    # would shadow this class-level value -- confirm it has the intended
    # effect.
    is_encoder_decoder = True

    def __init__(self, num_hidden_layers=3, position_num=4, **kwargs):
        """Create the configuration.

        Args:
            num_hidden_layers: stored as ``self.num_hidden_layers`` after the
                parent initialiser has run.
            position_num: stored as ``self.position_num``.
            **kwargs: passed through to ``BertConfig.__init__``.
        """
        super().__init__(**kwargs)
        # Assigned after the parent initialiser so these values are the ones
        # that end up on the instance.
        self.num_hidden_layers, self.position_num = num_hidden_layers, position_num
18
+