akhaliq (HF staff) committed
Commit 2fc77da
1 Parent(s): 1e61989

Update app.py

Files changed (1)
  1. app.py (+108 −8)
app.py CHANGED
@@ -1,19 +1,119 @@
+ import sys
  import os
  import gradio as gr
  from PIL import Image
  
- os.system("git clone https://github.com/AK391/projected_gan.git")
+ os.system("git clone https://github.com/autonomousvision/projected_gan.git")
  
- os.chdir("projected_gan")
+ sys.path.append("projected_gan")
  
- os.mkdir("outputs")
  
  
- def inference(seeds):
-     os.system("python gen_images.py --outdir=./outputs/ --seeds="+str(int(seeds))+" --network=https://s3.eu-central-1.amazonaws.com/avg-projects/projected_gan/models/pokemon.pkl")
-     seeds = int(seeds)
-     image = Image.open(f"./outputs/seed{seeds:04d}.png")
-     return image
+ """Generate images using pretrained network pickle."""
+ 
+ import os
+ import re
+ from typing import List, Optional, Tuple, Union
+ 
+ import click
+ import dnnlib
+ import numpy as np
+ import PIL.Image
+ import torch
+ 
+ import legacy
+ 
+ #----------------------------------------------------------------------------
+ 
+ def parse_range(s: Union[str, List]) -> List[int]:
+     '''Parse a comma-separated list of numbers or ranges and return a list of ints.
+     Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]
+     '''
+     if isinstance(s, list): return s
+     ranges = []
+     range_re = re.compile(r'^(\d+)-(\d+)$')
+     for p in s.split(','):
+         m = range_re.match(p)
+         if m:
+             ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
+         else:
+             ranges.append(int(p))
+     return ranges
+ 
+ #----------------------------------------------------------------------------
+ 
+ def parse_vec2(s: Union[str, Tuple[float, float]]) -> Tuple[float, float]:
+     '''Parse a floating point 2-vector of syntax 'a,b'.
+     Example:
+         '0,1' returns (0,1)
+     '''
+     if isinstance(s, tuple): return s
+     parts = s.split(',')
+     if len(parts) == 2:
+         return (float(parts[0]), float(parts[1]))
+     raise ValueError(f'cannot parse 2-vector {s}')
+ 
+ #----------------------------------------------------------------------------
+ 
+ def make_transform(translate: Tuple[float, float], angle: float):
+     m = np.eye(3)
+     s = np.sin(angle/360.0*np.pi*2)
+     c = np.cos(angle/360.0*np.pi*2)
+     m[0][0] = c
+     m[0][1] = s
+     m[0][2] = translate[0]
+     m[1][0] = -s
+     m[1][1] = c
+     m[1][2] = translate[1]
+     return m
+ 
+ #----------------------------------------------------------------------------
+ 
+ 
+ def generate_images(seeds):
+     """Generate images using pretrained network pickle.
+     Examples:
+     \b
+     # Generate an image using pre-trained AFHQv2 model ("Ours" in Figure 1, left).
+     python gen_images.py --outdir=out --trunc=1 --seeds=2 \\
+         --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl
+     \b
+     # Generate uncurated images with truncation using the MetFaces-U dataset
+     python gen_images.py --outdir=out --trunc=0.7 --seeds=600-605 \\
+         --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-metfacesu-1024x1024.pkl
+     """
+ 
+     device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+     with dnnlib.util.open_url('https://s3.eu-central-1.amazonaws.com/avg-projects/projected_gan/models/pokemon.pkl') as f:
+         G = legacy.load_network_pkl(f)['G_ema'].to(device)  # type: ignore
+ 
+     # Labels.
+     label = torch.zeros([1, G.c_dim], device=device)
+ 
+     # Generate images.
+     for seed_idx, seed in enumerate(seeds):
+         print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
+         z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device).float()
+ 
+         # Construct an inverse rotation/translation matrix and pass to the generator. The
+         # generator expects this matrix as an inverse to avoid potentially failing numerical
+         # operations in the network.
+         if hasattr(G.synthesis, 'input'):
+             m = make_transform('0,0', 0)
+             m = np.linalg.inv(m)
+             G.synthesis.input.transform.copy_(torch.from_numpy(m))
+ 
+         img = G(z, label, truncation_psi=1, noise_mode='const')
+         img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
+         pilimg = PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')
+         return pilimg
+ 
+ 
+ def inference(seedin):
+     listseed = [int(seedin)]
+     output = generate_images(listseed)
+     return output
  
  title = "Projected GAN"
  description = "Gradio demo for Projected GANs Converge Faster, Pokemon. To use it, add a seed or click one of the examples to load them. Read more at the links below. We're getting a lot of traffic from Hacker News, so we added 10 cached examples"
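The hunk ends before the Gradio wiring, so the Interface construction itself is not part of this commit. As a rough sketch only (the component choices, example seeds, and keyword arguments below are assumptions, not the file's actual code), the new inference function would typically be hooked up along these lines:

# Sketch only, not part of this commit. Assumes a recent gr.Interface API; the
# real app.py may use the older gr.inputs/gr.outputs style and other settings.
examples = [[i] for i in range(10)]  # hypothetical cached example seeds

demo = gr.Interface(
    fn=inference,                          # maps a single integer seed to a PIL image
    inputs=gr.Number(label="seed", value=0, precision=0),
    outputs=gr.Image(type="pil", label="generated image"),
    title=title,
    description=description,
    examples=examples,
    cache_examples=True,                   # would match the "10 cached examples" note
)
demo.launch()

With cached examples, Gradio pre-renders the listed seeds once, so clicking an example serves the stored image instead of running the generator again.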