i72sijia committed on
Commit e8feed3
1 Parent(s): 27605da

Update app.py

Files changed (1)
  1. app.py +63 -149
app.py CHANGED
@@ -1,166 +1,80 @@
-import sys
-import pickle
 import os
 import numpy as np
-import PIL.Image
-import IPython.display
-from IPython.display import Image
-import matplotlib.pyplot as plt
-
-import gradio as gr
-
-sys.path.insert(0, "/StyleGAN2-GANbanales")
-
-import dnnlib
-import dnnlib.tflib as tflib
-
-##############################################################################
-# Generation functions
-
-def seed2vec(Gs, seed):
-    rnd = np.random.RandomState(seed)
-    return rnd.randn(1, *Gs.input_shape[1:])
-
-def init_random_state(Gs, seed):
-    rnd = np.random.RandomState(seed)
-    noise_vars = [var for name, var in Gs.components.synthesis.vars.items() if name.startswith('noise')]
-    tflib.set_vars({var: rnd.randn(*var.shape.as_list()) for var in noise_vars}) # [height, width]
-
-def generate_image(Gs, z, truncation_psi, prefix="image", save=False, show=False):
-    # Render images for dlatents initialized from random seeds.
-    Gs_kwargs = {
-        'output_transform': dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True),
-        'randomize_noise': False
-    }
-    if truncation_psi is not None:
-        Gs_kwargs['truncation_psi'] = truncation_psi
-
-    label = np.zeros([1] + Gs.input_shapes[1][1:])
-    images = Gs.run(z, label, **Gs_kwargs) # [minibatch, height, width, channel]
-
-    if save == True:
-        path = f"{prefix}.png"
-        PIL.Image.fromarray(images[0], 'RGB').save(path)
-
-    if show == True:
-        return images[0]
-
-##############################################################################
-# Function concatenate
-
-def concatenate(img_array):
-
-    zeros = np.zeros([256,256,3], dtype=np.uint8)
-    zeros.fill(255)
-    white_img = zeros
-
-    # 1 - 2 images
-    if len(img_array) <= 2:
-        row_img = img_array[0]
-        for i in img_array[1:]:
-            row_img = np.hstack((row_img, i))
-
-        final_img = row_img
-
-    # 3 - 4 images
-    elif len(img_array) >= 3 and len(img_array) <= 4:
-        row1_img = img_array[0]
-        for i in img_array[1:2]:
-            row1_img = np.hstack((row1_img, i))
-
-        row2_img = img_array[2]
-        for i in img_array[3:]:
-            row2_img = np.hstack((row2_img, i))
-
-        for i in range(4-len(img_array)):
-            row2_img = np.hstack((row2_img, white_img))
-
-        final_img = np.vstack((row1_img, row2_img))
-
-    # 5 - 6 images
-    elif len(img_array) >= 4 and len(img_array) <= 6:
-        row1_img = img_array[0]
-        for i in img_array[1:3]:
-            row1_img = cv2.hconcat([row1_img, i])
-
-        row2_img = img_array[3]
-        for i in img_array[4:]:
-            row2_img = cv2.hconcat([row2_img, i])
-
-        for i in range(6-len(img_array)):
-            row2_img = cv2.hconcat([row2_img, white_img])
-
-        final_img = cv2.vconcat([row1_img, row2_img])
-
-    # 7 - 9 images
-    elif len(img_array) >= 7:
-        row1_img = img_array[0]
-        for i in img_array[1:3]:
-            row1_img = cv2.hconcat([row1_img, i])
-
-        row2_img = img_array[3]
-        for i in img_array[4:6]:
-            row2_img = cv2.hconcat([row2_img, i])
-
-        row3_img = img_array[6]
-        for i in img_array[7:9]:
-            row3_img = cv2.hconcat([row3_img, i])
-
-        for i in range(9-len(img_array)):
-            row3_img = cv2.hconcat([row3_img, white_img])
-
-        final_img = cv2.vconcat([row1_img, row2_img])
-        final_img = cv2.vconcat([final_img, row3_img])
-
-    return final_img
-
-##############################################################################
-# Function initiate
-def initiate(seed, n_imgs, text):
-
-    pkl_file = "networks/experimento_2.pkl"
-    tflib.init_tf()
-
-    with open(pkl_file, 'rb') as pickle_file:
-        _G, _D, Gs = pickle.load(pickle_file)
-
-    img_array = []
-    first_seed = seed
-
-    for i in range(seed, seed+n_imgs):
-        init_random_state(Gs, 10)
-        z = seed2vec(Gs, seed)
-        img = generate_image(Gs, z, 1.0, show=True)
-
-        img_array.append(img)
-        seed+=1
-
-    final_img = concatenate(img_array)
-
-    return final_img, "Imágenes generadas"
-
-##############################################################################
-# Gradio code
-
-iface = gr.Interface(
-    fn=initiate,
-    inputs=[gr.inputs.Slider(0, 99999999, "image"), gr.inputs.Slider(1, 9, "images"), "text"],
-    outputs=["image", "text"],
-    examples=[
-        [40, 1, "Edificios al anochecer"],
-        [37, 1, "Fuente de día"],
-        [426, 1, "Edificios con cielo oscuro"],
-        [397, 1, "Edificios de día"],
-        [395, 1, "Edificios desde anfiteatro"],
-        [281, 1, "Edificios con luces encendidas"],
-        [230, 1, "Edificios con luces encendidas y vegetación"],
-        [221, 1, "Edificios con vegetación"],
-        [214, 1, "Edificios al atardecer con luces encendidas"],
-        [198, 1, "Edificio al anochecer con luces en el pasillo"]
-    ],
-    title="GANbanales",
-    description="Una GAN para generar imágenes del campus universitario de Rabanales, Córdoba."
-)
-
-if __name__ == "__main__":
-    app, local_url, share_url = iface.launch(debug=True)
 
+import gradio as gr
 import os
 import numpy as np
+import torch
+import pickle
+import types
+
+from huggingface_hub import hf_hub_url, cached_download
+
+TOKEN = os.environ['TOKEN']
+
+# NOTE: the original line was truncated and the Hub repo hosting network.pkl is
+# not shown in this commit; MODEL_REPO below is a placeholder for it.
+MODEL_REPO = "<user>/<repo>"
+
+with open(cached_download(hf_hub_url(MODEL_REPO, 'network.pkl'), use_auth_token=TOKEN), 'rb') as f:
+    G = pickle.load(f)['G_ema']  # torch.nn.Module
+
+device = torch.device("cpu")
+
+if torch.cuda.is_available():
+    device = torch.device("cuda")
+    G = G.to(device)
+else:
+    # On CPU, patch the generator so every forward pass runs in fp32
+    # (StyleGAN2-ADA otherwise tries to use fp16 layers).
+    _old_forward = G.forward
+
+    def _new_forward(self, *args, **kwargs):
+        kwargs["force_fp32"] = True
+        return _old_forward(*args, **kwargs)
+
+    G.forward = types.MethodType(_new_forward, G)
+
+    _old_synthesis_forward = G.synthesis.forward
+
+    def _new_synthesis_forward(self, *args, **kwargs):
+        kwargs["force_fp32"] = True
+        return _old_synthesis_forward(*args, **kwargs)
+
+    G.synthesis.forward = types.MethodType(_new_synthesis_forward, G.synthesis)
+
+####################################################################
+# Image generation
+
+def generate(num_images, interpolate):
+    # num_images > 1 guards against dividing by zero below.
+    if interpolate and num_images > 1:
+        z1 = torch.randn([1, G.z_dim])  # endpoints of the latent interpolation
+        z2 = torch.randn([1, G.z_dim])
+        zs = torch.cat([z1 + (z2 - z1) * i / (num_images - 1) for i in range(num_images)], 0)
+    else:
+        zs = torch.randn([num_images, G.z_dim])  # independent latent codes
+    with torch.no_grad():
+        zs = zs.to(device)
+        img = G(zs, None, force_fp32=True, noise_mode='const')
+        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
+    return img.cpu().numpy()
+
+####################################################################
+# Graphical User Interface
+
+demo = gr.Blocks()
+
+def infer(num_images, interpolate):
+    img = generate(round(num_images), interpolate)
+    imgs = list(img)
+    return imgs
+
+with demo:
+    gr.Markdown(
+        """
+        # EmojiGAN
+        Generate Emojis with AI (StyleGAN2-ADA). Made by [mfrashad](https://github.com/mfrashad)
+        """)
+    images_num = gr.inputs.Slider(default=1, label="Num Images", minimum=1, maximum=16, step=1)
+    interpolate = gr.inputs.Checkbox(default=False, label="Interpolate")
+    submit = gr.Button("Generate")
+
+    out = gr.Gallery()
+
+    submit.click(fn=infer,
+                 inputs=[images_num, interpolate],
+                 outputs=out)
+
+demo.launch()
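
A quick way to sanity-check the rewritten torch pipeline outside the Gradio UI is to call generate() directly and save its output. The following is a minimal sketch, not part of the commit: it assumes the generator G has been loaded as above, that Pillow is installed, and the sample_{i}.png file names are arbitrary.

import PIL.Image

# Hypothetical smoke test: a 4-step interpolation between two random latents.
imgs = generate(4, interpolate=True)  # uint8 array of shape (4, H, W, 3)
for i, arr in enumerate(imgs):
    PIL.Image.fromarray(arr, 'RGB').save(f"sample_{i}.png")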