File size: 1,191 Bytes
c93ed73
7862640
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from cgan import Generator
import gradio as gr
import torch
from torchvision.utils import make_grid
from torchvision.transforms.functional import to_pil_image

# Generator hyperparameters; presumably these must match the values used when
# "generator1.pth" was trained (verify against cgan.py).
latent_dim = 100  # length of the random noise vector fed to the generator
n_classes = 10  # number of selectable class labels (one per MNIST digit)
# NOTE(review): img_size and channels are not referenced elsewhere in this
# file; they appear to mirror the training configuration.
img_size = 32
channels = 1

# Build the generator and restore its trained weights, mapping every tensor
# onto the CPU so the app runs on machines without a GPU.
# NOTE: torch.load unpickles the file — only load trusted checkpoints.
model = Generator()
_cpu_device = torch.device('cpu')
_checkpoint = torch.load("generator1.pth", map_location=_cpu_device)
model.load_state_dict(_checkpoint)
# Switch to inference mode (affects layers such as dropout / batch-norm).
model.eval()


def generate_image(class_idx):
    """Sample one image from the conditional generator for the given class.

    Args:
        class_idx: Class label to condition on, typically the dropdown's
            string value (e.g. ``"3"``) or anything ``int()`` accepts.
            ``None`` or ``""`` (no selection) produces no image.

    Returns:
        A PIL image made from the generator output via
        ``make_grid(..., normalize=True)``, or ``None`` when no class is
        selected.

    Raises:
        ValueError: If ``class_idx`` is not an integer in ``[0, n_classes)``.
    """
    # The interface runs with live=True and the dropdown has no default, so
    # this may be called before a class is chosen; int(None) would raise a
    # TypeError, so return an empty output instead.
    if class_idx is None or class_idx == "":
        return None
    label_idx = int(class_idx)
    # Reject labels the generator was not configured for, instead of letting
    # the model fail with an opaque indexing error.
    if not 0 <= label_idx < n_classes:
        raise ValueError(
            f"class_idx must be in [0, {n_classes}), got {class_idx!r}"
        )
    with torch.no_grad():
        # Fresh standard-normal latent vector (batch of one) on every call,
        # so repeated calls produce different samples.
        noise = torch.randn(1, latent_dim)
        label = torch.tensor([label_idx])
        # Drop the batch dimension; presumably leaves (C, H, W) — confirm
        # against the Generator's output shape.
        gen_img = model(noise, label).squeeze(0)
    return to_pil_image(make_grid(gen_img, normalize=True))


# Build the Gradio UI components.
# NOTE(review): noise_input is created but never passed to the Interface
# below, so this slider has no effect; generate_image always draws its own
# random noise.
noise_input = gr.inputs.Slider(
    minimum=-1.0,
    maximum=1.0,
    default=0,
    step=0.1,
    label="Noise",
)
# One dropdown entry per class label, as strings "0" .. str(n_classes - 1).
class_choices = list(map(str, range(n_classes)))
class_input = gr.inputs.Dropdown(class_choices, label="Class")
output_image = gr.outputs.Image('pil')

# Wire the dropdown to generate_image; live=True asks Gradio to re-run the
# function when the input changes instead of waiting for a submit click.
demo = gr.Interface(
    fn=generate_image,
    inputs=[class_input],
    outputs=output_image,
    title="MNIST Generator",
    description="Generate images of handwritten digits from the MNIST dataset using a GAN.",
    theme="default",
    layout="vertical",
    live=True,
)
demo.launch(debug=True)