Spaces:
Runtime error
Runtime error
| import os | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
class Generator(nn.Module):
    """Fully connected Vanilla-GAN generator producing 28x28 single-channel images.

    Layers: ``n_features -> 256 -> 512 -> n_out``, with LeakyReLU(0.2) after
    the two hidden layers and Tanh on the output, so pixel values lie in
    [-1, 1].

    Refer to https://github.com/safwankdb/Vanilla-GAN
    """

    def __init__(self):
        super().__init__()
        self.n_features = 128  # length of the input latent vector
        self.n_out = 784       # 28 * 28 output pixels, reshaped in forward()
        # The attribute names fc0/fc1/fc2 and the Sequential layout define the
        # state_dict keys ("fc0.0.weight", ...). Renaming them would break
        # loading previously saved weights.
        self.fc0 = nn.Sequential(
            nn.Linear(self.n_features, 256),
            nn.LeakyReLU(0.2),
        )
        self.fc1 = nn.Sequential(
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
        )
        self.fc2 = nn.Sequential(
            nn.Linear(512, self.n_out),
            nn.Tanh(),
        )

    def forward(self, x):
        """Map latents of shape (B, n_features) to images of shape (B, 1, 28, 28)."""
        x = self.fc0(x)
        x = self.fc1(x)
        x = self.fc2(x)
        return x.view(-1, 1, 28, 28)
def create_mnist_inference():
    """Load the pretrained MNIST generator and return its configuration dict.

    The weights are read from ``mnist_generator.pretrained``, located in the
    same directory as this module, and loaded into a ``Generator``. The
    network is moved to CUDA when available (CPU otherwise) and switched to
    eval mode.

    Returns:
        dict with keys ``name``, ``generator`` (callable, see
        ``mnist_generator`` below), ``latent_dim``, ``fps``, ``batch_size``,
        ``strength``, ``max_duration`` and ``use_peak``.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    mnist = Generator()
    weights_path = os.path.join(
        os.path.dirname(__file__),
        'mnist_generator.pretrained'
    )
    # Load onto CPU first so weights saved from a GPU also load on CPU-only hosts.
    state = torch.load(weights_path, map_location='cpu')
    mnist.load_state_dict(state)
    mnist.to(device)
    mnist.eval()

    def mnist_generator(latents):
        """Render one RGB uint8 image per latent vector.

        Args:
            latents: iterable of array-likes, each convertible to a float
                tensor of length ``mnist.n_features`` (presumably 1-D numpy
                arrays; TODO confirm with callers).

        Returns:
            list of ``np.uint8`` arrays of shape (28, 28, 3): the grayscale
            output replicated across three channels. Returns an empty list
            for empty input.
        """
        tensors = [torch.from_numpy(np.asarray(latent)).float() for latent in latents]
        if not tensors:
            # torch.stack rejects an empty sequence.
            return []
        # Stack on the host and transfer to the device once for the whole batch.
        batch = torch.stack(tensors).to(device)
        # Inference only. Without no_grad the output would require grad and
        # .numpy() below would raise a RuntimeError.
        with torch.no_grad():
            out = mnist(batch)
        # Take channel 0 of the (B, 1, 28, 28) output and map the Tanh range
        # [-1, 1] to [0, 255]; the uint8 cast truncates toward zero.
        gray = ((out[:, 0] + 1) * 127.5).clamp(0, 255).cpu().numpy()
        gray = gray.astype(np.uint8)
        rgb = np.repeat(gray[..., np.newaxis], 3, axis=-1)
        return list(rgb)

    return {
        'name': 'MNIST',
        'generator': mnist_generator,
        'latent_dim': mnist.n_features,
        'fps': 20,
        'batch_size': 8,
        'strength': 0.75,
        'max_duration': 30,
        'use_peak': True
    }