Spaces:
Runtime error
Runtime error
| from torch import nn, optim | |
| import torch | |
class Photoz_network(nn.Module):
    """Mixture-density network that outputs a Gaussian-mixture description per input.

    A shared fully connected trunk maps the input features to a 10-d
    representation. Three independent heads then produce, for each of
    ``num_gauss`` mixture components, a location (``mu``), a width-related
    output (``sigma``) and a normalised log mixing weight.

    Args:
        num_gauss: Number of mixture components produced by every head.
        dropout_prob: Dropout probability applied after every hidden Linear
            layer (trunk and heads). ``0`` makes dropout an identity.
        num_features: Size of the last input dimension. Defaults to 6, the
            value that used to be hard-coded.

    The layer ordering inside every ``nn.Sequential`` matches the original
    hand-written layout exactly. Existing checkpoints therefore keep loading
    (``features.0.weight`` ... ``features.12.bias``,
    ``measure_mu.0.weight`` ... ``measure_mu.3.bias``, etc.).
    """

    def __init__(self, num_gauss=10, dropout_prob=0, num_features=6):
        super().__init__()
        # Trunk widths: num_features -> 10 -> 20 -> 50 -> 20 -> 10.
        self.features = self._mlp((num_features, 10, 20, 50, 20, 10), dropout_prob)
        # Heads are created in the original order (mu, coeffs, sigma). This
        # keeps module registration order, and thus seeded init, unchanged.
        self.measure_mu = self._mlp((10, 20, num_gauss), dropout_prob)
        self.measure_coeffs = self._mlp((10, 20, num_gauss), dropout_prob)
        self.measure_sigma = self._mlp((10, 20, num_gauss), dropout_prob)

    @staticmethod
    def _mlp(widths, dropout_prob):
        """Build ``Linear -> Dropout -> ReLU`` stages over ``widths``, ending in a bare Linear."""
        layers = []
        last_stage = len(widths) - 2
        for stage, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Linear(n_in, n_out))
            if stage < last_stage:
                # No activation after the final Linear: heads emit raw values.
                layers.extend((nn.Dropout(dropout_prob), nn.ReLU()))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return ``(mu, sigma, logmix_coeff)`` for input ``x``.

        Args:
            x: Tensor whose last dimension has size ``num_features``;
                typically ``(batch, num_features)``.

        Returns:
            Three tensors shaped like ``x`` with the last dimension replaced
            by ``num_gauss``:

            - ``mu``: raw component locations.
            - ``sigma``: raw head output. It is NOT constrained to be
              positive here; presumably the caller's loss treats it as a
              log-width (TODO confirm).
            - ``logmix_coeff``: log mixing weights, normalised so that their
              ``exp`` sums to 1 over the last dimension.
        """
        f = self.features(x)
        mu = self.measure_mu(f)
        sigma = self.measure_sigma(f)
        # log_softmax == logits - logsumexp(logits). Unlike the previous
        # `[:, None]` form (2-D only), it works for any number of leading dims.
        logmix_coeff = torch.log_softmax(self.measure_coeffs(f), dim=-1)
        return mu, sigma, logmix_coeff