import torch
import torch.nn as nn
import torch.nn.functional as F


class MLPEncoder(nn.Module):
    """VAE-style encoder mapping a bag-of-words vector to a topic distribution."""

    def __init__(self, vocab_size, num_topic, hidden_dim, dropout):
        super().__init__()
        # Two-layer MLP over the bag-of-words input.
        self.fc11 = nn.Linear(vocab_size, hidden_dim)
        self.fc12 = nn.Linear(hidden_dim, hidden_dim)
        # Separate heads for the Gaussian posterior's mean and log-variance.
        self.fc21 = nn.Linear(hidden_dim, num_topic)
        self.fc22 = nn.Linear(hidden_dim, num_topic)
        self.fc1_drop = nn.Dropout(dropout)
        self.z_drop = nn.Dropout(dropout)

        # Batch norm with a frozen scale (weight) stabilizes training;
        # only the learnable shift (bias) is updated.
        self.mean_bn = nn.BatchNorm1d(num_topic, affine=True)
        self.mean_bn.weight.requires_grad = False
        self.logvar_bn = nn.BatchNorm1d(num_topic, affine=True)
        self.logvar_bn.weight.requires_grad = False

    def reparameterize(self, mu, logvar):
        # Reparameterization trick: sample z = mu + sigma * eps so gradients
        # flow through mu and logvar; at eval time, return the mean directly.
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mu + eps * std
        else:
            return mu

    def forward(self, x):
        e1 = F.softplus(self.fc11(x))
        e1 = F.softplus(self.fc12(e1))
        e1 = self.fc1_drop(e1)

        mu = self.mean_bn(self.fc21(e1))
        logvar = self.logvar_bn(self.fc22(e1))

        # Sample a latent vector and project it onto the simplex to obtain
        # the document-topic distribution theta.
        theta = self.reparameterize(mu, logvar)
        theta = F.softmax(theta, dim=1)
        theta = self.z_drop(theta)

        return theta, mu, logvar
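

# Minimal usage sketch (not part of the original module): the dimensions and
# input below are assumed for illustration. It runs the encoder in eval mode,
# where dropout is disabled and reparameterize() returns the posterior mean,
# and shows the analytic KL term that a VAE-style objective would typically
# compute from the returned mu and logvar against a standard normal prior.
if __name__ == "__main__":
    encoder = MLPEncoder(vocab_size=2000, num_topic=50, hidden_dim=256, dropout=0.2)
    encoder.eval()  # dropout off; deterministic theta

    bow = torch.rand(8, 2000)  # hypothetical batch of 8 bag-of-words vectors
    with torch.no_grad():
        theta, mu, logvar = encoder(bow)

    print(theta.shape)  # torch.Size([8, 50]); each row sums to 1
    # KL( N(mu, sigma^2) || N(0, I) ), averaged over the batch:
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1).mean()
    print(kl.item())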