File size: 1,191 Bytes
c93ed73 7862640 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
from cgan import Generator
import gradio as gr
import torch
from torchvision.utils import make_grid
from torchvision.transforms.functional import to_pil_image
# Generator hyper-parameters.
latent_dim = 100  # length of the random noise vector passed to the generator
n_classes = 10    # number of class labels offered in the "Class" dropdown
img_size = 32     # not referenced elsewhere in this file
channels = 1      # not referenced elsewhere in this file
# NOTE(review): latent_dim / n_classes presumably must match the sizes baked
# into cgan.Generator() and the checkpoint -- confirm against cgan.py.


def _load_generator(weights_path):
    """Build a Generator, restore its weights from *weights_path*, and return it.

    The checkpoint tensors are mapped onto the CPU. The returned module is
    switched to evaluation mode.
    """
    net = Generator()
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    state = torch.load(weights_path, map_location=torch.device('cpu'))
    net.load_state_dict(state)
    # Put any dropout / batch-norm layers into inference behaviour.
    net.eval()
    return net


model = _load_generator("generator1.pth")
def generate_image(class_idx):
    """Sample a single image of the requested class from the generator.

    Args:
        class_idx: Class label as delivered by the "Class" dropdown (a string
            such as "3"). Any value accepted by ``int()`` works.

    Returns:
        A PIL image of the generated sample. ``make_grid(normalize=True)``
        min-max rescales it before conversion.
    """
    # Sampling never needs gradients, so disable autograd tracking.
    with torch.no_grad():
        z = torch.randn(1, latent_dim)           # batch of one noise vector
        labels = torch.tensor([int(class_idx)])  # matching batch of one label
        batch = model(z, labels)
        # Drop the batch axis. This assumes the generator returns
        # (1, C, H, W) images -- TODO confirm against cgan.Generator.
        single = batch.squeeze(0)
        grid = make_grid(single, normalize=True)
        return to_pil_image(grid)
# Create the Gradio interface and launch it.
# Only the class label is user-selectable. generate_image draws its own
# random noise, so no noise control is exposed. (A "Noise" slider used to be
# constructed here but was never wired into `inputs`.)
# NOTE(review): gr.inputs / gr.outputs and the `layout` argument belong to the
# legacy Gradio API (deprecated in Gradio 3, removed in Gradio 4). Migrate to
# gr.Dropdown / gr.Image(type="pil") if the pinned Gradio version is upgraded.
class_input = gr.inputs.Dropdown([str(i) for i in range(n_classes)], label="Class")
output_image = gr.outputs.Image('pil')
gr.Interface(
    fn=generate_image,
    inputs=[class_input],
    outputs=output_image,
    title="MNIST Generator",
    description="Generate images of handwritten digits from the MNIST dataset using a GAN.",
    theme="default",
    layout="vertical",
    live=True,  # re-run generate_image automatically when the selected class changes
).launch(debug=True)