klima7 commited on
Commit
4dd2c81
1 Parent(s): c501b3e

Add model files

Browse files
Files changed (3) hide show
  1. gaugan.py +178 -0
  2. weights/encoder.h5 +3 -0
  3. weights/generator.h5 +3 -0
gaugan.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import tensorflow as tf
3
+ import tensorflow_addons as tfa
4
+ import keras
5
+ from keras import Model, Sequential, initializers
6
+ from keras.layers import Layer, Conv2D, LeakyReLU, Dropout
7
+
8
+
9
+ class SPADE(Layer):
10
+ def __init__(self, filters: int, epsilon=1e-5, **kwargs):
11
+ super().__init__(**kwargs)
12
+ self.epsilon = epsilon
13
+ self.conv = Conv2D(128, 3, padding="same", activation="relu")
14
+ self.conv_gamma = Conv2D(filters, 3, padding="same")
15
+ self.conv_beta = Conv2D(filters, 3, padding="same")
16
+
17
+ def build(self, input_shape):
18
+ self.resize_shape = input_shape[1:3]
19
+
20
+ def call(self, input_tensor, raw_mask):
21
+ mask = tf.image.resize(raw_mask, self.resize_shape, method="nearest")
22
+ x = self.conv(mask)
23
+ gamma = self.conv_gamma(x)
24
+ beta = self.conv_beta(x)
25
+ mean, var = tf.nn.moments(input_tensor, axes=(0, 1, 2), keepdims=True)
26
+ std = tf.sqrt(var + self.epsilon)
27
+ normalized = (input_tensor - mean) / std
28
+ output = gamma * normalized + beta
29
+ return output
30
+
31
+ def get_config(self):
32
+ return {
33
+ "epsilon": self.epsilon,
34
+ "conv": self.conv,
35
+ "conv_gamma": self.conv_gamma,
36
+ "conv_beta": self.conv_beta
37
+ }
38
+
39
+
40
+ class ResBlock(Layer):
41
+ def __init__(self, filters: int, **kwargs):
42
+ super().__init__(**kwargs)
43
+ self.filters = filters
44
+
45
+ def build(self, input_shape):
46
+ input_filter = input_shape[-1]
47
+ self.spade_1 = SPADE(input_filter)
48
+ self.spade_2 = SPADE(self.filters)
49
+ self.conv_1 = Conv2D(self.filters, 3, padding="same")
50
+ self.conv_2 = Conv2D(self.filters, 3, padding="same")
51
+ self.leaky_relu = LeakyReLU(0.2)
52
+ self.learned_skip = False
53
+
54
+ if self.filters != input_filter:
55
+ self.learned_skip = True
56
+ self.spade_3 = SPADE(input_filter)
57
+ self.conv_3 = Conv2D(self.filters, 3, padding="same")
58
+
59
+ def call(self, input_tensor, mask):
60
+ x = self.spade_1(input_tensor, mask)
61
+ x = self.conv_1(self.leaky_relu(x))
62
+ x = self.spade_2(x, mask)
63
+ x = self.conv_2(self.leaky_relu(x))
64
+ skip = (
65
+ self.conv_3(self.leaky_relu(self.spade_3(input_tensor, mask)))
66
+ if self.learned_skip
67
+ else input_tensor
68
+ )
69
+ output = skip + x
70
+ return output
71
+
72
+ def get_config(self):
73
+ return {"filters": self.filters}
74
+
75
+
76
+ class Downsample(Layer):
77
+ def __init__(self,
78
+ channels: int,
79
+ kernels: int,
80
+ strides: int = 2,
81
+ apply_norm=True,
82
+ apply_activation=True,
83
+ apply_dropout=False,
84
+ **kwargs
85
+ ):
86
+ super().__init__(**kwargs)
87
+ self.channels = channels
88
+ self.kernels = kernels
89
+ self.strides = strides
90
+ self.apply_norm = apply_norm
91
+ self.apply_activation = apply_activation
92
+ self.apply_dropout = apply_dropout
93
+
94
+ def build(self, input_shape):
95
+ self.block = Sequential([
96
+ Conv2D(
97
+ self.channels,
98
+ self.kernels,
99
+ strides=self.strides,
100
+ padding="same",
101
+ use_bias=False,
102
+ kernel_initializer=initializers.GlorotNormal(),
103
+ )])
104
+ if self.apply_norm:
105
+ self.block.add(tfa.layers.InstanceNormalization())
106
+ if self.apply_activation:
107
+ self.block.add(LeakyReLU(0.2))
108
+ if self.apply_dropout:
109
+ self.block.add(Dropout(0.5))
110
+
111
+ def call(self, inputs):
112
+ return self.block(inputs)
113
+
114
+ def get_config(self):
115
+ return {
116
+ "channels": self.channels,
117
+ "kernels": self.kernels,
118
+ "strides": self.strides,
119
+ "apply_norm": self.apply_norm,
120
+ "apply_activation": self.apply_activation,
121
+ "apply_dropout": self.apply_dropout,
122
+ }
123
+
124
+
125
+ class GaussianSampler(Layer):
126
+ def __init__(self, latent_dim: int, **kwargs):
127
+ super().__init__(**kwargs)
128
+ self.latent_dim = latent_dim
129
+
130
+ def call(self, inputs):
131
+ means, variance = inputs
132
+ epsilon = tf.random.normal(
133
+ shape=(tf.shape(means)[0], self.latent_dim), mean=0.0, stddev=1.0
134
+ )
135
+ samples = means + tf.exp(0.5 * variance) * epsilon
136
+ return samples
137
+
138
+ def get_config(self):
139
+ return {"latent_dim": self.latent_dim}
140
+
141
+
142
+ class GauganPredictor():
143
+
144
+ CLASSES = (
145
+ 'unknown','wall', 'sky', 'tree', 'road', 'grass', 'earth',
146
+ 'mountain', 'plant', 'water', 'sea', 'field', 'fence', 'rock',
147
+ 'sand', 'path', 'river', 'flower', 'hill', 'palm', 'tower',
148
+ 'dirt', 'land', 'waterfall', 'lake'
149
+ )
150
+
151
+ def __init__(self, model_g_path: str, model_e_path: str = None) -> None:
152
+ custom_objects = {
153
+ 'ResBlock': ResBlock,
154
+ 'Downsample': Downsample,
155
+ }
156
+ if model_e_path is not None:
157
+ self.encoder: Model = keras.models.load_model(model_e_path, custom_objects=custom_objects)
158
+ self.sampler = GaussianSampler(256)
159
+ self.gen: Model = keras.models.load_model(
160
+ model_g_path, custom_objects=custom_objects)
161
+
162
+ def __call__(self, im: np.ndarray, z=None) -> np.ndarray:
163
+ if len(im.shape) == 3:
164
+ im = im[np.newaxis]
165
+ if z is None:
166
+ z = tf.random.normal((im.shape[0], 256))
167
+ tmp = self.gen.predict_on_batch([z, im])
168
+ x = np.array((tmp + 1) * 127.5, np.uint8)
169
+ return x
170
+
171
+ def predict_reference(self, im: np.ndarray, reference_im: np.ndarray) -> np.ndarray:
172
+ if len(im.shape) == 3:
173
+ im = im[np.newaxis]
174
+ reference_im = reference_im[np.newaxis]
175
+ mean, variance = self.encoder(reference_im)
176
+ z = self.sampler([mean, variance])
177
+ x = np.array((self.gen.predict_on_batch([z, im]) + 1) * 127.5, np.uint8)
178
+ return x
weights/encoder.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:163e382f0102f1c1356a178b41c1c2234b7cdf13a340116ec2eabda53a535078
3
+ size 82794536
weights/generator.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c89cec334ff70b6e49b25ddda6e10b2e0bb06d1d6c7cb7b134d6bf682f559607
3
+ size 342490352