dog_gans / app.py
gianTheo's picture
Update app.py
784f6ea verified
import torchvision.utils as vutils
import torchvision.transforms as T
from PIL import Image
import gradio as gr
import torch
from model import generator
def generate_images(num_images, z_dim=100):
    """Sample ``num_images`` images from ``generator`` and tile them into one picture.

    Args:
        num_images: How many images to generate. Gradio's ``gr.Number`` may
            deliver this as a float (or ``None`` when the field is cleared), so
            it is validated and coerced to ``int`` here.
        z_dim: Length of each latent vector. 100 matches the latent size used
            for the warm-up call elsewhere in this file.

    Returns:
        A ``PIL.Image.Image`` with all samples laid out by
        ``torchvision.utils.make_grid`` (2-pixel padding, min-max normalized).

    Raises:
        gr.Error: If ``num_images`` is missing or not a positive whole number.
            Gradio shows the message to the user.
    """
    # torch.randn rejects float sizes, and a zero-sized batch cannot be tiled,
    # so reject anything that is not a positive whole number up front.
    if num_images is None or num_images < 1 or num_images != int(num_images):
        raise gr.Error("Please enter a positive whole number of dogs to generate.")
    num_images = int(num_images)

    # One latent vector per image, shaped (N, z_dim, 1, 1) like the warm-up
    # call in this file. NOTE(review): created on CPU; assumes the generator
    # lives on CPU too. Confirm if it is ever moved to a GPU.
    noise = torch.randn(num_images, z_dim, 1, 1)

    generator.eval()  # inference behaviour for layers such as BatchNorm/Dropout, if present
    with torch.no_grad():
        fake_images = generator(noise).cpu()

    # make_grid returns a (C, H, W) tensor. normalize=True rescales it into
    # [0, 1], which is the range ToPILImage expects for float tensors.
    grid = vutils.make_grid(fake_images, padding=2, normalize=True)
    return T.ToPILImage()(grid)
# Startup forward pass through the generator. Its result is not used anywhere
# else in this file, so it serves only to exercise the model once at import
# time (a broken model then fails at launch rather than on the first request).
# Run it in eval mode without autograd: otherwise the module-level
# `fake_images` keeps a full computation graph alive, and any BatchNorm layers
# (if the model has them; TODO confirm in model.py) would update their running
# statistics from this throwaway batch.
batch_size = 16
z = torch.randn(batch_size, 100, 1, 1)
generator.eval()
with torch.no_grad():
    fake_images = generator(z)
# Output component: the generated samples tiled into a single PIL image.
out = gr.Image(type="pil", label="Generated Dogs")

# Input component: how many dogs to generate.
#   - precision=0 makes Gradio round the value and pass it as an int.
#   - minimum=1 / value=1 match the "positive integer" label (0 would produce
#     an empty batch).
#   - maximum bounds the batch size, because the app is exposed publicly via
#     share=True and every request runs a forward pass in this process.
inp = gr.Number(
    value=1,
    minimum=1,
    maximum=64,
    precision=0,
    label="Enter number of Dogs to Generate (positive integer)",
)

demo = gr.Interface(
    fn=generate_images,
    inputs=inp,
    outputs=out,
    allow_flagging="never",
)
# Launch the Gradio server only when this file is run as a script
# (e.g. `python app.py`). Importing the module, for example to reuse
# `generate_images`, then does not start a server. share=True asks Gradio
# for a public share link.
if __name__ == "__main__":
    demo.launch(share=True)