HighCWu's picture
init commit.
c68160d
import os
import numpy as np
import torch
import torch.nn as nn
class Generator(nn.Module):
    """Fully connected GAN generator mapping 128-d latents to 1x28x28 images.

    Architecture adapted from https://github.com/safwankdb/Vanilla-GAN:
    three linear layers (128 -> 256 -> 512 -> 784), LeakyReLU(0.2) after the
    first two and Tanh after the last, so outputs lie in [-1, 1].

    The submodule names (``fc0``, ``fc1``, ``fc2``), each an ``nn.Sequential``
    whose first entry is the ``nn.Linear``, determine the ``state_dict`` keys
    (``fc0.0.weight`` etc.). Do not rename or restructure them, or existing
    pretrained checkpoints will no longer load.
    """

    def __init__(self):
        super().__init__()
        self.n_features = 128  # latent (input) dimensionality
        self.n_out = 784  # flattened output size, 28 * 28 pixels
        self.fc0 = nn.Sequential(
            nn.Linear(self.n_features, 256),
            nn.LeakyReLU(0.2),
        )
        self.fc1 = nn.Sequential(
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
        )
        self.fc2 = nn.Sequential(
            # Use n_out rather than a second hard-coded 784 so the two agree.
            nn.Linear(512, self.n_out),
            nn.Tanh(),
        )

    def forward(self, x):
        """Map latent vectors to single-channel 28x28 images.

        Args:
            x: float tensor whose last dimension is ``n_features`` (128),
                presumably shaped (batch, 128).

        Returns:
            Tensor reshaped to (N, 1, 28, 28) with values in [-1, 1], where
            N is the number of 784-element outputs produced.
        """
        x = self.fc0(x)
        x = self.fc1(x)
        x = self.fc2(x)
        return x.view(-1, 1, 28, 28)
def create_mnist_inference():
    """Load the pretrained MNIST generator and wrap it for batched inference.

    Weights are read from ``mnist_generator.pretrained`` in the same directory
    as this file. The model is moved to CUDA when available (otherwise CPU)
    and put in eval mode.

    Returns:
        dict with ``'name'``, the ``'generator'`` callable, ``'latent_dim'``
        (128, matching ``Generator.n_features``), and the settings ``'fps'``,
        ``'batch_size'``, ``'strength'``, ``'max_duration'`` and
        ``'use_peak'``, whose meaning is defined by the consumer of this dict.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    weights_path = os.path.join(
        os.path.dirname(__file__),
        'mnist_generator.pretrained'
    )
    mnist = Generator()
    # NOTE(review): torch.load unpickles the checkpoint, which can execute
    # arbitrary code; only ship trusted weight files here. On torch >= 1.13,
    # consider passing weights_only=True.
    state = torch.load(weights_path, map_location='cpu')
    mnist.load_state_dict(state)
    mnist.to(device)
    mnist.eval()

    @torch.inference_mode()
    def mnist_generator(latents):
        """Render a batch of latent vectors to RGB uint8 images.

        Args:
            latents: iterable of numpy arrays, presumably each of shape
                (128,); they are stacked into one batch and cast to float32.

        Returns:
            list with one (28, 28, 3) uint8 array per latent, the grayscale
            output replicated across three channels. An empty input yields
            an empty list.
        """
        latents = list(latents)
        if not latents:
            # torch.stack / np.stack reject empty sequences; an empty batch
            # simply produces no images.
            return []
        # Stack on the host so there is a single host->device transfer
        # instead of one per latent.
        batch = torch.from_numpy(np.stack(latents)).float().to(device)
        images = mnist(batch)  # (B, 1, 28, 28), in [-1, 1] from the Tanh
        # Map [-1, 1] -> [0, 255] for the whole batch, then one device->host
        # copy instead of a sync per image.
        gray = ((images[:, 0] + 1) * 127.5).clamp(0, 255).cpu().numpy()
        gray = gray.astype(np.uint8)  # truncating cast, same as np.uint8(...)
        rgb = np.repeat(gray[..., np.newaxis], 3, axis=-1)  # (B, 28, 28, 3)
        return list(rgb)

    return {
        'name': 'MNIST',
        'generator': mnist_generator,
        'latent_dim': 128,
        'fps': 20,
        'batch_size': 8,
        'strength': 0.75,
        'max_duration': 30,
        'use_peak': True
    }