| import torch | |
| import matplotlib.pyplot as plt | |
| import torchvision | |
| import gradio as gr | |
# Run the generator on the GPU whenever CUDA is available, otherwise on the CPU.
# torch.cuda.is_available() already returns a bool, so no ternary is needed.
use_gpu = torch.cuda.is_available()

# Load the pretrained Progressive GAN ('PGAN') entry point from the
# facebookresearch/pytorch_GAN_zoo hub repository. torch.hub fetches and caches
# the repository/weights on first use, so this line may hit the network.
# 'celebAHQ-512' selects the checkpoint — presumably the 512px CelebA-HQ faces
# model; confirm against the pytorch_GAN_zoo hub docs.
model = torch.hub.load(
    'facebookresearch/pytorch_GAN_zoo:hub',
    'PGAN',
    model_name='celebAHQ-512',
    pretrained=True,
    useGPU=use_gpu,
)
def pggan(num_images):
    """Sample random images from the PGAN model and plot them as one grid.

    Args:
        num_images: How many images to generate. Gradio's ``Number`` input
            delivers a float, so the value is truncated with ``int()``.

    Returns:
        The ``matplotlib.pyplot`` module whose current figure holds the image
        grid; the Gradio ``Image`` output renders that current figure.

    Raises:
        ValueError: If ``num_images`` is missing or truncates to less than 1.
    """
    # Reject empty/invalid requests up front with a readable message instead
    # of letting them fail deep inside noise generation or make_grid.
    if num_images is None:
        raise ValueError("number of images is required")
    count = int(num_images)
    if count < 1:
        raise ValueError(f"number of images must be at least 1, got {num_images}")

    noise, _ = model.buildNoiseData(count)
    with torch.no_grad():
        generated_images = model.test(noise)

    # Clamp to [-1, 1], then let make_grid rescale each image independently
    # into [0, 1] for display.
    grid = torchvision.utils.make_grid(
        generated_images.clamp(min=-1, max=1), scale_each=True, normalize=True
    )

    # Start from a clean figure on every call. Drawing onto the shared global
    # figure would keep every previous request's image alive as a hidden
    # artist, leaking memory for the lifetime of the server.
    plt.close("all")
    plt.figure()
    plt.axis("off")
    # make_grid returns (C, H, W); imshow expects (H, W, C) on the CPU.
    plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
    return plt
# Gradio UI wiring: one numeric input (how many images to sample) feeding
# pggan(), and one image output that displays whatever pggan() returns.
# NOTE(review): gr.inputs / gr.outputs are the legacy Gradio 2.x component
# namespaces (deprecated in 3.x, removed in 4.x) — pin gradio accordingly.
inputs = gr.inputs.Number(label="number of images")
outputs = gr.outputs.Image(label="Output Image")

# Text shown above (title, description) and below (article) the demo.
title = "Progressive Growing of GANs"
description = (
    "Gradio demo for Progressive Growing of GANs (PGAN). To use it, simply add "
    "the number of images to generate or click on the examples. Read more below."
)
article = (
    "<p style='text-align: center'>"
    "<a href='https://arxiv.org/abs/1710.10196'>Progressive Growing of GANs "
    "for Improved Quality, Stability, and Variation</a> | "
    "<a href='https://github.com/facebookresearch/pytorch_GAN_zoo/blob/master/"
    "models/progressive_gan.py'>Github Repo</a></p>"
)

# One clickable example per image count 1..4; each example row is the list
# of input values, so every count is wrapped in its own single-item list.
examples = [[count] for count in range(1, 5)]

gr.Interface(
    pggan,
    inputs,
    outputs,
    title=title,
    description=description,
    article=article,
    analytics_enabled=False,
    examples=examples,
).launch(debug=True)