File size: 1,892 Bytes
c68160d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import os
import numpy as np
import torch
import torch.nn as nn


class Generator(nn.Module):
    """Fully connected GAN generator that produces single-channel 28x28 images.

    Architecture follows https://github.com/safwankdb/Vanilla-GAN: three
    linear layers (128 -> 256 -> 512 -> 784). The first two layers use
    LeakyReLU(0.2) and the last uses Tanh, so output pixels lie in [-1, 1].

    The submodule names (``fc0``/``fc1``/``fc2``, each a ``Sequential`` whose
    ``Linear`` sits at index 0) determine the ``state_dict`` keys. Do not
    rename or reorder them, or pretrained checkpoints will no longer load.
    """

    def __init__(self):
        super().__init__()
        self.n_features = 128  # latent (input) dimension
        self.n_out = 784  # flattened output size: 1 * 28 * 28
        self.fc0 = nn.Sequential(
            nn.Linear(self.n_features, 256),
            nn.LeakyReLU(0.2),
        )
        self.fc1 = nn.Sequential(
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
        )
        # Use n_out rather than a second hard-coded 784 so the declared size
        # and the actual layer width cannot drift apart.
        self.fc2 = nn.Sequential(
            nn.Linear(512, self.n_out),
            nn.Tanh(),
        )

    def forward(self, x):
        """Map latents with last dimension ``n_features`` to images of shape (-1, 1, 28, 28)."""
        x = self.fc0(x)
        x = self.fc1(x)
        x = self.fc2(x)
        # n_out == 1 * 28 * 28, so each flattened output row becomes one 1x28x28 image.
        return x.view(-1, 1, 28, 28)


def create_mnist_inference():
    """Build the MNIST image generator callable and its settings dict.

    Loads pretrained ``Generator`` weights from ``mnist_generator.pretrained``
    in this file's directory. The model is moved to CUDA when available
    (CPU otherwise) and switched to eval mode.

    Returns:
        dict with keys ``name``, ``generator``, ``latent_dim``, ``fps``,
        ``batch_size``, ``strength``, ``max_duration`` and ``use_peak``.
        ``generator`` maps a sequence of numpy latent vectors to a list of
        ``(28, 28, 3)`` ``uint8`` images. The other values are static
        settings read by the caller; their meaning is not visible here.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    mnist = Generator()
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code. Only ship trusted checkpoints here. On torch >= 1.13, consider
    # passing weights_only=True.
    state = torch.load(
        os.path.join(os.path.dirname(__file__), 'mnist_generator.pretrained'),
        # Remap saved tensors onto CPU first; the model is moved to `device` below.
        map_location='cpu',
    )
    mnist.load_state_dict(state)
    mnist.to(device)
    mnist.eval()

    @torch.inference_mode()
    def mnist_generator(latents):
        """Render one RGB ``uint8`` image per latent vector.

        Args:
            latents: sequence (or 2-D array) of numpy latent vectors. Each is
                presumably of length 128 (``latent_dim``); TODO confirm with
                callers.

        Returns:
            list of ``(28, 28, 3)`` ``uint8`` arrays, one per latent, in order.
            An empty input yields an empty list.
        """
        latents = list(latents)
        if not latents:
            # Stacking an empty batch raises; no latents simply means no images.
            return []
        # Stack on the host and make one device transfer for the whole batch,
        # instead of one copy per latent.
        batch = torch.from_numpy(np.stack(latents)).float().to(device)
        # Generator output is (N, 1, 28, 28) in [-1, 1] (Tanh); drop the channel axis.
        images = mnist(batch)[:, 0]
        # Map [-1, 1] -> [0, 255] and make a single device->host copy.
        # astype truncates toward zero, exactly like np.uint8(...) did.
        images = ((images + 1) * 127.5).clamp(0, 255).cpu().numpy().astype(np.uint8)
        # Replicate the grayscale plane into three identical RGB channels.
        return [np.stack([img] * 3, -1) for img in images]

    return {
        'name': 'MNIST',
        'generator': mnist_generator,
        'latent_dim': 128,
        'fps': 20,
        'batch_size': 8,
        'strength': 0.75,
        'max_duration': 30,
        'use_peak': True
    }