NoteDance committed on
Commit 5646c73
1 Parent(s): 6297344

Upload CLIP.py

Files changed (1)
  1. CLIP.py +350 -0
CLIP.py ADDED
@@ -0,0 +1,350 @@
import tensorflow as tf
from tensorflow.keras.layers import Dense, Conv2D, BatchNormalization, LayerNormalization, MultiHeadAttention
from tensorflow.keras.layers import ZeroPadding2D, AveragePooling2D, Identity
from tensorflow.keras import Model
import numpy as np
from typing import Tuple, Union


class Bottleneck(tf.keras.layers.Layer):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
        super(Bottleneck, self).__init__()
        self.conv1 = Conv2D(planes, 1, use_bias=False)
        self.bn1 = BatchNormalization()
        self.relu1 = tf.nn.relu

        self.zeropadding2d = ZeroPadding2D(padding=1)
        self.conv2 = Conv2D(planes, 3, use_bias=False)
        self.bn2 = BatchNormalization()
        self.relu2 = tf.nn.relu

        self.avgpool = AveragePooling2D(stride, stride, 'VALID') if stride > 1 else Identity()

        self.conv3 = Conv2D(planes * self.expansion, 1, use_bias=False)
        self.bn3 = BatchNormalization()
        self.relu3 = tf.nn.relu

        self.downsample = None
        self.stride = stride

        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
            self.downsample = tf.keras.Sequential()
            self.downsample.add(AveragePooling2D(stride, stride, 'VALID'))
            self.downsample.add(Conv2D(planes * self.expansion, 1, strides=1, use_bias=False))
            self.downsample.add(BatchNormalization())

    def __call__(self, x):
        identity = x

        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.zeropadding2d(out)
        out = self.relu2(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu3(out)
        return out


class AttentionPool2d:
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        self.positional_embedding = tf.Variable(tf.random.normal([spacial_dim ** 2 + 1, embed_dim]) / embed_dim ** 0.5)
        self.k_proj = Dense(embed_dim)
        self.q_proj = Dense(embed_dim)
        self.v_proj = Dense(embed_dim)
        self.c_proj = Dense(output_dim or embed_dim)
        self.num_heads = num_heads

    def __call__(self, x):
        shape = x.shape
        batch_size = shape[0]
        height = shape[1]
        width = shape[2]
        channels = shape[3]
        new_shape = (batch_size, height * width, channels)
        x = tf.transpose(tf.reshape(x, new_shape), (1, 0, 2))  # NHWC -> (HW)NC
        x = tf.concat([tf.reduce_mean(x, axis=0, keepdims=True), x], axis=0)  # (HW+1)NC
        x = x + tf.cast(self.positional_embedding[:, None, :], x.dtype)  # (HW+1)NC
        tgt_len, bsz, embed_dim = x.shape
        query = self.q_proj(x[:1])  # the mean token is the single query
        key = self.k_proj(x)
        value = self.v_proj(x)
        # move the batch dimension to the front before splitting heads, so that heads are
        # carved out of the channel dimension rather than mixing tokens across the batch
        query = tf.reshape(tf.transpose(query, (1, 0, 2)), [bsz, 1, self.num_heads, -1])
        query = tf.transpose(query, [0, 2, 1, 3])
        query = tf.multiply(query, 1.0 / tf.math.sqrt(float(embed_dim // self.num_heads)))  # scale by the per-head width
        key = tf.reshape(tf.transpose(key, (1, 0, 2)), [bsz, tgt_len, self.num_heads, -1])
        key = tf.transpose(key, [0, 2, 3, 1])
        value = tf.reshape(tf.transpose(value, (1, 0, 2)), [bsz, tgt_len, self.num_heads, -1])
        value = tf.transpose(value, [0, 2, 1, 3])
        qk = tf.matmul(query, key)
        w = tf.nn.softmax(qk)
        wv = tf.reshape(tf.transpose(tf.matmul(w, value), [0, 2, 1, 3]), [1, bsz, -1])
        x = self.c_proj(wv)
        return tf.squeeze(x, 0)
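
# Reading aid: AttentionPool2d forms a single query from the mean of the spatial tokens and
# pools the (N, H, W, C) feature map as c_proj(softmax(q k^T / sqrt(d_head)) v) per head;
# this is the "QKV attention instead of an average pool" that the ModifiedResNet docstring
# below refers to.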


class ModifiedResNet:
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        self.output_dim = output_dim
        self.input_resolution = input_resolution

        # the 3-layer stem
        self.zeropadding2d = ZeroPadding2D(padding=1)
        self.conv1 = Conv2D(width // 2, kernel_size=3, strides=2, use_bias=False)
        self.bn1 = BatchNormalization()
        self.relu1 = tf.nn.relu
        self.conv2 = Conv2D(width // 2, kernel_size=3, use_bias=False)
        self.bn2 = BatchNormalization()
        self.relu2 = tf.nn.relu
        self.conv3 = Conv2D(width, kernel_size=3, use_bias=False)
        self.bn3 = BatchNormalization()
        self.relu3 = tf.nn.relu
        self.avgpool = AveragePooling2D(2, 2, 'VALID')

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        embed_dim = width * 32  # the ResNet feature dimension
        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        layers = tf.keras.Sequential()
        layers.add(Bottleneck(self._inplanes, planes, stride))

        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.add(Bottleneck(self._inplanes, planes))

        return layers

    def __call__(self, x):
        def stem(x):
            x = self.zeropadding2d(x)
            x = self.conv1(x)
            x = self.relu1(self.bn1(x))
            x = self.zeropadding2d(x)
            x = self.conv2(x)
            x = self.relu2(self.bn2(x))
            x = self.zeropadding2d(x)
            x = self.conv3(x)
            x = self.relu3(self.bn3(x))
            x = self.avgpool(x)
            return x

        x = stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.attnpool(x)

        return x
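
# Worked example of the geometry above (illustrative, using the default input_resolution=224):
# the stride-2 stem conv takes 224 -> 112, the stem's 2x2 average pool takes 112 -> 56, and
# layer2/3/4 each halve the grid again (56 -> 28 -> 14 -> 7), so AttentionPool2d is built for
# a 7x7 map, i.e. input_resolution // 32 = 7 and 7 * 7 + 1 = 50 attention tokens.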


class LayerNorm:
    """LayerNormalization wrapper that computes in fp32 so fp16 inputs are handled safely."""
    def __init__(self, input_size):
        self.layer_norm = LayerNormalization()

    def __call__(self, x):
        orig_type = x.dtype
        ret = self.layer_norm(tf.cast(x, tf.float32))
        return tf.cast(ret, orig_type)


class QuickGELU(tf.keras.layers.Layer):
    def __init__(self):
        super(QuickGELU, self).__init__()

    def __call__(self, x):
        return x * tf.nn.sigmoid(1.702 * x)


class ResidualAttentionBlock(tf.keras.layers.Layer):
    def __init__(self, d_model: int, n_head: int, attn_mask=None):
        super(ResidualAttentionBlock, self).__init__()
        self.attn = MultiHeadAttention(n_head, d_model // n_head)  # key_dim is the per-head width
        self.ln_1 = LayerNorm(d_model)
        self.mlp = tf.keras.Sequential()
        self.mlp.add(Dense(d_model * 4))
        self.mlp.add(QuickGELU())
        self.mlp.add(Dense(d_model))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x):
        attn_mask = tf.cast(self.attn_mask, x.dtype) if self.attn_mask is not None else None
        return self.attn(x, x, attention_mask=attn_mask)

    def __call__(self, x):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer:
    def __init__(self, width: int, layers: int, heads: int, attn_mask=None):
        self.width = width
        self.layers = layers
        self.resblocks = tf.keras.Sequential()
        for _ in range(layers):
            self.resblocks.add(ResidualAttentionBlock(width, heads, attn_mask))

    def __call__(self, x):
        return self.resblocks(x)


class VisionTransformer:
    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = Conv2D(width, kernel_size=patch_size, strides=patch_size, use_bias=False)

        scale = width ** -0.5
        self.class_embedding = tf.Variable(scale * tf.random.normal([width]))
        self.positional_embedding = tf.Variable(scale * tf.random.normal([(input_resolution // patch_size) ** 2 + 1, width]))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = tf.Variable(scale * tf.random.normal([width, output_dim]))

    def __call__(self, x, train_flag=True):
        x = self.conv1(x)  # shape = [*, grid, grid, width] (channels-last)
        x = tf.reshape(x, [x.shape[0], -1, x.shape[-1]])  # shape = [*, grid ** 2, width]
        x = tf.concat([tf.cast(self.class_embedding, x.dtype) + tf.zeros([x.shape[0], 1, x.shape[-1]], dtype=x.dtype), x], axis=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + tf.cast(self.positional_embedding, x.dtype)
        x = self.ln_pre(x)

        # keras MultiHeadAttention expects batch-first inputs, so no NLD <-> LND permute is needed
        x = self.transformer(x)

        x = self.ln_post(x[:, 0, :])

        if self.proj is not None:
            x = tf.matmul(x, self.proj)

        return x
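
# Worked example for the patch embedding above (illustrative numbers, not fixed by this file):
# with input_resolution=224 and patch_size=32, conv1 produces a 7x7 grid of patch embeddings,
# so the transformer sees 7 * 7 + 1 = 50 tokens once the class embedding is prepended,
# matching the (input_resolution // patch_size) ** 2 + 1 rows of positional_embedding.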


class CLIP(Model):
    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int
                 ):
        super(CLIP, self).__init__()

        self.context_length = context_length

        if isinstance(vision_layers, (tuple, list)):
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_heads,
                input_resolution=image_resolution,
                width=vision_width
            )
        else:
            vision_heads = vision_width // 64
            self.visual = VisionTransformer(
                input_resolution=image_resolution,
                patch_size=vision_patch_size,
                width=vision_width,
                layers=vision_layers,
                heads=vision_heads,
                output_dim=embed_dim
            )

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask()
        )

        self.vocab_size = vocab_size
        self.token_embedding = tf.Variable(tf.random.normal((vocab_size, transformer_width), stddev=0.02))
        self.positional_embedding = tf.Variable(tf.random.normal((self.context_length, transformer_width), stddev=0.01))
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = tf.Variable(tf.random.normal((transformer_width, embed_dim), stddev=self.transformer.width ** -0.5))
        self.logit_scale = tf.Variable(tf.ones([]) * np.log(1 / 0.07))

    def build_attention_mask(self):
        # causal mask for keras MultiHeadAttention: 1 where a token may attend, 0 for future positions
        mask = tf.ones((self.context_length, self.context_length))
        mask = tf.linalg.band_part(mask, -1, 0)  # keep the lower triangle, including the diagonal
        return mask

    def encode_image(self, image):
        return self.visual(image)

    def encode_text(self, text):
        x = tf.gather(self.token_embedding, text)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding
        # keras MultiHeadAttention expects batch-first inputs, so no NLD <-> LND permute is needed
        x = self.transformer(x)
        x = self.ln_final(x)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = tf.matmul(tf.gather_nd(x, tf.stack([tf.range(x.shape[0], dtype='int32'),
                                                tf.argmax(text, axis=-1, output_type='int32')], axis=1)), self.text_projection)

        return x

    def __call__(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # normalized features
        image_features = image_features / tf.norm(image_features, axis=1, keepdims=True)
        text_features = text_features / tf.norm(text_features, axis=1, keepdims=True)

        # cosine similarity as logits
        logit_scale = tf.math.exp(self.logit_scale)
        logits_per_image = tf.matmul(logit_scale * image_features, tf.transpose(text_features))
        logits_per_text = tf.transpose(logits_per_image)

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text
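
# A minimal smoke-test sketch, assuming eager execution (the TF 2.x default); the hyperparameters
# below are illustrative ViT-B/32-style values chosen for the example, not prescribed by CLIP.py.
if __name__ == "__main__":
    model = CLIP(embed_dim=512,
                 image_resolution=224,
                 vision_layers=12,      # an int selects the VisionTransformer branch
                 vision_width=768,
                 vision_patch_size=32,
                 context_length=77,
                 vocab_size=49408,
                 transformer_width=512,
                 transformer_heads=8,
                 transformer_layers=12)
    images = tf.random.normal([4, 224, 224, 3])                       # channels-last image batch
    texts = tf.random.uniform([4, 77], maxval=49408, dtype=tf.int32)  # toy token ids
    logits_per_image, logits_per_text = model(images, texts)
    print(logits_per_image.shape, logits_per_text.shape)              # (4, 4) (4, 4)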