JayR7 commited on
Commit
bb5c8d5
1 Parent(s): 4943f16

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +170 -0
app.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ rom utils import download_url
2
+ import argparse
3
+ import numpy as np
4
+ import PIL.Image
5
+ import dnnlib
6
+ import dnnlib.tflib as tflib
7
+ import re
8
+ import sys
9
+ from io import BytesIO
10
+ import IPython.display
11
+ from math import ceil
12
+ from PIL import Image, ImageDraw
13
+ import os
14
+ import pickle
15
+ from utils import log_progress, imshow, create_image_grid, show_animation
16
+ import imageio
17
+ import glob
18
+ import gdown
19
+ import gradio as gr
20
+
21
+ class Rasm:
22
+
23
+ def __init__(self, mode = 'calligraphy'):
24
+
25
+ if mode == 'calligraphy':
26
+ url = 'https://drive.google.com/uc?id=138fdURGxdkOwZq7IWvnrGLcfo5VI8O1R'
27
+
28
+ else:
29
+ url = 'https://drive.google.com/uc?id=13h-alXGI0hbNOJy1qbmeoroXZSPBHEG2'
30
+
31
+ output = 'model.pkl'
32
+ print('Downloading networks from "%s"...' %url)
33
+ gdown.download(url, output, quiet=False)
34
+ dnnlib.tflib.init_tf()
35
+ with dnnlib.util.open_url(output) as fp:
36
+ self._G, self._D, self.Gs = pickle.load(fp)
37
+ self.noise_vars = [var for name, var in self.Gs.components.synthesis.vars.items() if name.startswith('noise')]
38
+
39
+ # Generates a list of images, based on a list of latent vectors (Z), and a list (or a single constant) of truncation_psi's.
40
+ def generate_images_in_w_space(self, dlatents, truncation_psi):
41
+ Gs_kwargs = dnnlib.EasyDict()
42
+ Gs_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
43
+ Gs_kwargs.randomize_noise = False
44
+ Gs_kwargs.truncation_psi = truncation_psi
45
+ # dlatent_avg = self.Gs.get_var('dlatent_avg') # [component]
46
+
47
+ imgs = []
48
+ for _, dlatent in log_progress(enumerate(dlatents), name = "Generating images"):
49
+ #row_dlatents = (dlatent[np.newaxis] - dlatent_avg) * np.reshape(truncation_psi, [-1, 1, 1]) + dlatent_avg
50
+ # dl = (dlatent-dlatent_avg)*truncation_psi + dlatent_avg
51
+ row_images = self.Gs.components.synthesis.run(dlatent, **Gs_kwargs)
52
+ imgs.append(PIL.Image.fromarray(row_images[0], 'RGB'))
53
+ return imgs
54
+
55
+ def generate_images(self, zs, truncation_psi, class_idx = None):
56
+ Gs_kwargs = dnnlib.EasyDict()
57
+ Gs_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
58
+ Gs_kwargs.randomize_noise = False
59
+ if not isinstance(truncation_psi, list):
60
+ truncation_psi = [truncation_psi] * len(zs)
61
+
62
+ imgs = []
63
+ label = np.zeros([1] + self.Gs.input_shapes[1][1:])
64
+ if class_idx is not None:
65
+ label[:, class_idx] = 1
66
+ else:
67
+ label = None
68
+ for z_idx, z in log_progress(enumerate(zs), size = len(zs), name = "Generating images"):
69
+ Gs_kwargs.truncation_psi = truncation_psi[z_idx]
70
+ noise_rnd = np.random.RandomState(1) # fix noise
71
+ tflib.set_vars({var: noise_rnd.randn(*var.shape.as_list()) for var in self.noise_vars}) # [height, width]
72
+ images = self.Gs.run(z, label, **Gs_kwargs) # [minibatch, height, width, channel]
73
+ imgs.append(PIL.Image.fromarray(images[0], 'RGB'))
74
+ return imgs
75
+
76
+ def generate_from_zs(self, zs, truncation_psi = 0.5):
77
+ Gs_kwargs = dnnlib.EasyDict()
78
+ Gs_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
79
+ Gs_kwargs.randomize_noise = False
80
+ if not isinstance(truncation_psi, list):
81
+ truncation_psi = [truncation_psi] * len(zs)
82
+
83
+ for z_idx, z in log_progress(enumerate(zs), size = len(zs), name = "Generating images"):
84
+ Gs_kwargs.truncation_psi = truncation_psi[z_idx]
85
+ noise_rnd = np.random.RandomState(1) # fix noise
86
+ tflib.set_vars({var: noise_rnd.randn(*var.shape.as_list()) for var in self.noise_vars}) # [height, width]
87
+ images = self.Gs.run(z, None, **Gs_kwargs) # [minibatch, height, width, channel]
88
+ img = PIL.Image.fromarray(images[0], 'RGB')
89
+ imshow(img)
90
+
91
+ def generate_random_zs(self, size):
92
+ seeds = np.random.randint(2**32, size=size)
93
+ zs = []
94
+ for _, seed in enumerate(seeds):
95
+ rnd = np.random.RandomState(seed)
96
+ z = rnd.randn(1, *self.Gs.input_shape[1:]) # [minibatch, component]
97
+ zs.append(z)
98
+ return zs
99
+
100
+
101
+ def generate_zs_from_seeds(self, seeds):
102
+ zs = []
103
+ for _, seed in enumerate(seeds):
104
+ rnd = np.random.RandomState(seed)
105
+ z = rnd.randn(1, *self.Gs.input_shape[1:]) # [minibatch, component]
106
+ zs.append(z)
107
+ return zs
108
+
109
+ # Generates a list of images, based on a list of seed for latent vectors (Z), and a list (or a single constant) of truncation_psi's.
110
+ def generate_images_from_seeds(self, seeds, truncation_psi):
111
+ ima = self.generate_images(self.generate_zs_from_seeds(seeds), truncation_psi)[0]
112
+ return ima, imshow(ima)
113
+
114
+ def generate_randomly(self, truncation_psi = 0.5):
115
+ ima, dis = self.generate_images_from_seeds(np.random.randint(4294967295, size=1), truncation_psi=truncation_psi)
116
+ return ima, dis
117
+
118
+ def generate_grid(self, truncation_psi = 0.7):
119
+ seeds = np.random.randint((2**32 - 1), size=9)
120
+ return create_image_grid(self.generate_images(self.generate_zs_from_seeds(seeds), truncation_psi), 0.7 , 3)
121
+
122
+ def generate_animation(self, size = 9, steps = 10, trunc_psi = 0.5):
123
+ seeds = list(np.random.randint((2**32) - 1, size=size))
124
+ seeds = seeds + [seeds[0]]
125
+ zs = self.generate_zs_from_seeds(seeds)
126
+
127
+ imgs = self.generate_images(self.interpolate(zs, steps = steps), trunc_psi)
128
+ movie_name = 'animation.mp4'
129
+ with imageio.get_writer(movie_name, mode='I') as writer:
130
+ for image in log_progress(list(imgs), name = "Creating animation"):
131
+ writer.append_data(np.array(image))
132
+ return show_animation(movie_name)
133
+
134
+ def convertZtoW(self, latent, truncation_psi=0.7, truncation_cutoff=9):
135
+ dlatent = self.Gs.components.mapping.run(latent, None) # [seed, layer, component]
136
+ dlatent_avg = self.Gs.get_var('dlatent_avg') # [component]
137
+ for i in range(truncation_cutoff):
138
+ dlatent[0][i] = (dlatent[0][i]-dlatent_avg)*truncation_psi + dlatent_avg
139
+
140
+ return dlatent
141
+
142
+ def interpolate(self, zs, steps = 10):
143
+ out = []
144
+ for i in range(len(zs)-1):
145
+ for index in range(steps):
146
+ fraction = index/float(steps)
147
+ out.append(zs[i+1]*fraction + zs[i]*(1-fraction))
148
+ return out
149
+
150
+
151
+ #-------------------- Rasm Demo--------------------------
152
+
153
+ def model(mode, output):
154
+ model=rasm.Rasm(mode=mode)
155
+ if output=='Generate Art Randomly':
156
+ ima,res= model.generate_randomly()
157
+ elif output=='Generate Art Grid':
158
+ ima = model.generate_grid()
159
+ elif output=='Generate Art Animation':
160
+ ima = model.generate_animation(size = 2, steps = 20)
161
+ return ima
162
+
163
+ imageout=gr.outputs.Image(model,
164
+ [
165
+ gr.Radio(["calligraphy", "mosaics"],label="Type of Arbic Art"),
166
+ gr.Radio(["Generate Art Randomly", "Generate Art Grid", "Generate Art Animation"],label="How do you prefer the output visualization" ),
167
+ ],
168
+ outputs=imageout
169
+ )
170
+ demo.launch()