# File size: 1,485 Bytes
# 5d2e45e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
import matplotlib.pyplot as plt
import torchvision
import gradio as gr
# Run the generator on the GPU only when CUDA is available.
# torch.cuda.is_available() already returns a bool, so no ternary is needed.
use_gpu = torch.cuda.is_available()

# Load the pretrained Progressive GAN ("celebAHQ-512" variant) from the
# pytorch_GAN_zoo Torch Hub entry point. torch.hub may fetch the repo and
# weights over the network the first time this runs.
model = torch.hub.load('facebookresearch/pytorch_GAN_zoo:hub',
                       'PGAN', model_name='celebAHQ-512',
                       pretrained=True, useGPU=use_gpu)



def pggan(num_images):
    """Sample random images from the PGAN model and draw them as one grid.

    Args:
        num_images: How many images to generate. The Gradio Number input may
            deliver a float, so the value is truncated with ``int()``.

    Returns:
        The ``matplotlib.pyplot`` module, with the image grid drawn on its
        current figure, for the Gradio Image output to render.
        NOTE(review): relies on the legacy Gradio output accepting the pyplot
        module — confirm against the Gradio version in use.

    Raises:
        ValueError: If ``num_images`` is missing or less than 1.
    """
    if num_images is None:
        raise ValueError("num_images must be a positive integer")
    count = int(num_images)
    if count < 1:
        # An empty noise batch would otherwise fail deep inside the model or
        # make_grid with a much less helpful error.
        raise ValueError("num_images must be a positive integer")

    noise, _ = model.buildNoiseData(count)
    with torch.no_grad():
        generated_images = model.test(noise)

    # Clamp to [-1, 1] (presumably the generator's nominal output range), then
    # let make_grid rescale each image independently to [0, 1].
    grid = torchvision.utils.make_grid(
        generated_images.clamp(min=-1, max=1), scale_each=True, normalize=True
    )

    # pyplot keeps global state: without clearing, every request would stack
    # another imshow on the same axes, overlapping grids of different sizes
    # and growing memory across calls.
    plt.clf()
    plt.axis("off")
    # make_grid returns (C, H, W); imshow expects (H, W, C) on the CPU.
    plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
    return plt


# Gradio components: one numeric input, one image output (legacy
# gr.inputs / gr.outputs API).
inputs = gr.inputs.Number(label="number of images")
outputs = gr.outputs.Image(label="Output Image")

# Page text passed to the Interface.
title = "Progressive Growing of GANs"
description = "Gradio demo for Progressive Growing of GANs (PGAN). To use it, simply add the number of images to generate or click on the examples. Read more below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/1710.10196'>Progressive Growing of GANs for Improved Quality, Stability, and Variation</a> | <a href='https://github.com/facebookresearch/pytorch_GAN_zoo/blob/master/models/progressive_gan.py'>Github Repo</a></p>"

# Clickable examples: one row per image count, 1 through 4.
examples = [[count] for count in range(1, 5)]

# Build the demo around pggan and start the server in debug mode.
gr.Interface(
    pggan,
    inputs,
    outputs,
    title=title,
    description=description,
    article=article,
    analytics_enabled=False,
    examples=examples,
).launch(debug=True)