|
import torch |
|
import matplotlib.pyplot as plt |
|
import torchvision |
|
import gradio as gr |
|
# Run on CUDA only when torch can see a GPU. is_available() already returns a
# bool, so no conditional expression is needed; the value is forwarded to the
# hub loader as ``useGPU``.
use_gpu = torch.cuda.is_available()

# Load the pretrained "celebAHQ-512" Progressive GAN from the
# facebookresearch/pytorch_GAN_zoo hub entry point. torch.hub fetches the repo
# and weights into its cache, so the first run needs network access.
model = torch.hub.load(
    "facebookresearch/pytorch_GAN_zoo:hub",
    "PGAN",
    model_name="celebAHQ-512",
    pretrained=True,
    useGPU=use_gpu,
)
|
|
|
|
|
|
|
def pggan(num_images):
    """Sample ``num_images`` faces from the global PGAN ``model`` and plot them as a grid.

    Args:
        num_images: How many images to generate. It is coerced with ``int()``,
            so a float coming from the Gradio ``Number`` input is truncated
            toward zero.

    Returns:
        The ``matplotlib.pyplot`` module, whose current figure holds the image
        grid. NOTE(review): this relies on the legacy Gradio ``Image`` output
        auto-detecting a pyplot return value; confirm against the installed
        gradio version before changing the return type.

    Raises:
        ValueError: If ``num_images`` is missing or, after ``int()`` coercion,
            less than 1. Without this check the failure surfaces as an opaque
            error deep inside noise generation or grid construction.
    """
    if num_images is None:
        raise ValueError("number of images must be provided")
    count = int(num_images)
    if count < 1:
        raise ValueError(f"number of images must be at least 1, got {num_images!r}")

    # buildNoiseData returns a pair; only the noise tensor is used here.
    noise, _ = model.buildNoiseData(count)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        generated_images = model.test(noise)

    # Clamp to [-1, 1] (presumably the generator's nominal output range - TODO
    # confirm), then let make_grid rescale each image to [0, 1] by its own
    # min/max (normalize=True with scale_each=True).
    grid = torchvision.utils.make_grid(
        generated_images.clamp(min=-1, max=1),
        scale_each=True,
        normalize=True,
    )

    # Start from a blank current figure so repeated calls don't stack imshow
    # layers on the same axes.
    plt.clf()
    plt.axis("off")
    # make_grid yields channels-first (C, H, W); imshow expects (H, W, C).
    plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
    return plt
|
|
|
|
|
# --- Gradio UI wiring -------------------------------------------------------
# A single numeric input (how many faces to sample) feeding pggan, whose
# result is shown in one image output.
inputs = gr.inputs.Number(label="number of images")
outputs = gr.outputs.Image(label="Output Image")

# Page text shown above and below the demo.
title = "Progressive Growing of GANs"
description = (
    "Gradio demo for Progressive Growing of GANs (PGAN). "
    "To use it, simply add the number of images to generate "
    "or click on the examples. Read more below."
)
article = (
    "<p style='text-align: center'>"
    "<a href='https://arxiv.org/abs/1710.10196'>"
    "Progressive Growing of GANs for Improved Quality, Stability, and Variation"
    "</a>"
    " | "
    "<a href='https://github.com/facebookresearch/pytorch_GAN_zoo/blob/master/models/progressive_gan.py'>"
    "Github Repo"
    "</a>"
    "</p>"
)

# Clickable example rows: one row per image count, 1 through 4.
examples = [[count] for count in range(1, 5)]

# Build the interface and start the server (blocking, with debug output).
gr.Interface(
    pggan,
    inputs,
    outputs,
    title=title,
    description=description,
    article=article,
    analytics_enabled=False,
    examples=examples,
).launch(debug=True)