Upload CLIP.py
CLIP.py
ADDED
@@ -0,0 +1,350 @@
import tensorflow as tf
from tensorflow.keras.layers import Dense, Conv2D, BatchNormalization, LayerNormalization, MultiHeadAttention
from tensorflow.keras.layers import ZeroPadding2D, AveragePooling2D, Identity
from tensorflow.keras import Model
import numpy as np
from typing import Tuple, Union

class Bottleneck(tf.keras.layers.Layer):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
        super(Bottleneck, self).__init__()
        self.conv1 = Conv2D(planes, 1, use_bias=False)
        self.bn1 = BatchNormalization()
        self.relu1 = tf.nn.relu

        self.zeropadding2d = ZeroPadding2D(padding=1)
        self.conv2 = Conv2D(planes, 3, use_bias=False)
        self.bn2 = BatchNormalization()
        self.relu2 = tf.nn.relu

        self.avgpool = AveragePooling2D(stride, stride, 'valid') if stride > 1 else Identity()

        self.conv3 = Conv2D(planes * self.expansion, 1, use_bias=False)
        self.bn3 = BatchNormalization()
        self.relu3 = tf.nn.relu

        self.downsample = None
        self.stride = stride

        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
            self.downsample = tf.keras.Sequential()
            self.downsample.add(AveragePooling2D(stride, stride, 'valid'))
            self.downsample.add(Conv2D(planes * self.expansion, 1, strides=1, use_bias=False))
            self.downsample.add(BatchNormalization())

    def __call__(self, x):
        identity = x

        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.zeropadding2d(out)
        out = self.relu2(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu3(out)
        return out

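
# Usage sketch (illustrative only): the block expects NHWC feature maps and widens
# channels by `expansion`; the shapes below are assumed example values, not part of
# the original upload.
def _bottleneck_example():
    block = Bottleneck(inplanes=64, planes=64)   # 64 -> 64 * expansion = 256 channels
    features = tf.zeros([1, 56, 56, 64])         # dummy NHWC input
    return block(features)                       # shape (1, 56, 56, 256)
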
class AttentionPool2d:
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        self.positional_embedding = tf.Variable(tf.random.normal([spacial_dim ** 2 + 1, embed_dim]) / embed_dim ** 0.5)
        self.k_proj = Dense(embed_dim)
        self.q_proj = Dense(embed_dim)
        self.v_proj = Dense(embed_dim)
        self.c_proj = Dense(output_dim or embed_dim)
        self.num_heads = num_heads

    def __call__(self, x):
        shape = x.shape
        batch_size = shape[0]
        height = shape[1]
        width = shape[2]
        channels = shape[3]
        new_shape = (batch_size, height * width, channels)
        x = tf.transpose(tf.reshape(x, new_shape), (1, 0, 2))  # NHWC -> (HW)NC
        x = tf.concat([tf.reduce_mean(x, axis=0, keepdims=True), x], axis=0)  # (HW+1)NC
        x = x + tf.cast(self.positional_embedding[:, None, :], x.dtype)  # (HW+1)NC
        tgt_len, bsz, embed_dim = x.shape
        head_dim = embed_dim // self.num_heads
        # a single query token (the mean-pooled one) attends over all spatial tokens
        query = self.q_proj(x[:1])
        key = self.k_proj(x)
        value = self.v_proj(x)
        query = tf.reshape(query, [bsz, 1, self.num_heads, head_dim])
        query = tf.transpose(query, [0, 2, 1, 3])                         # (N, heads, 1, head_dim)
        query = tf.multiply(query, 1.0 / tf.math.sqrt(float(head_dim)))   # scale by the per-head dimension
        key = tf.transpose(key, [1, 0, 2])                                # (HW+1)NC -> N(HW+1)C before splitting heads
        key = tf.reshape(key, [bsz, tgt_len, self.num_heads, head_dim])
        key = tf.transpose(key, [0, 2, 3, 1])                             # (N, heads, head_dim, HW+1)
        value = tf.transpose(value, [1, 0, 2])
        value = tf.reshape(value, [bsz, tgt_len, self.num_heads, head_dim])
        value = tf.transpose(value, [0, 2, 1, 3])                         # (N, heads, HW+1, head_dim)
        qk = tf.matmul(query, key)
        w = tf.nn.softmax(qk)
        wv = tf.reshape(tf.transpose(tf.matmul(w, value), [0, 2, 1, 3]), [1, bsz, -1])
        x = self.c_proj(wv)
        return tf.squeeze(x, 0)

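
# Usage sketch (illustrative only): the attention pool collapses an NHWC grid to one
# vector per image. The 7x7x256 grid, 8 heads and 128-d output below are assumed toy
# values, not the numbers of any released CLIP checkpoint.
def _attention_pool_example():
    pool = AttentionPool2d(spacial_dim=7, embed_dim=256, num_heads=8, output_dim=128)
    grid = tf.zeros([2, 7, 7, 256])    # dummy feature map from the ResNet trunk
    return pool(grid)                  # shape (2, 128)
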
class ModifiedResNet:
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        self.output_dim = output_dim
        self.input_resolution = input_resolution

        # the 3-layer stem
        self.zeropadding2d = ZeroPadding2D(padding=1)
        self.conv1 = Conv2D(width // 2, kernel_size=3, strides=2, use_bias=False)
        self.bn1 = BatchNormalization()
        self.relu1 = tf.nn.relu
        self.conv2 = Conv2D(width // 2, kernel_size=3, use_bias=False)
        self.bn2 = BatchNormalization()
        self.relu2 = tf.nn.relu
        self.conv3 = Conv2D(width, kernel_size=3, use_bias=False)
        self.bn3 = BatchNormalization()
        self.relu3 = tf.nn.relu
        self.avgpool = AveragePooling2D(2, 2, 'valid')

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        embed_dim = width * 32  # the ResNet feature dimension
        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        layers = tf.keras.Sequential()
        layers.add(Bottleneck(self._inplanes, planes, stride))

        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.add(Bottleneck(self._inplanes, planes))

        return layers

    def __call__(self, x):
        def stem(x):
            x = self.zeropadding2d(x)
            x = self.conv1(x)
            x = self.relu1(self.bn1(x))
            x = self.zeropadding2d(x)
            x = self.conv2(x)
            x = self.relu2(self.bn2(x))
            x = self.zeropadding2d(x)
            x = self.conv3(x)
            x = self.relu3(self.bn3(x))
            x = self.avgpool(x)
            return x

        x = stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.attnpool(x)

        return x

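
# Construction sketch (illustrative only): a ResNet-50-style trunk in the spirit of the
# CLIP "RN50" layout, i.e. layers=(3, 4, 6, 3) and width=64, so the attention pool sees
# a 7x7 grid with width * 32 = 2048 channels. heads and output_dim are example values.
def _modified_resnet_example():
    trunk = ModifiedResNet(layers=(3, 4, 6, 3), output_dim=1024, heads=32,
                           input_resolution=224, width=64)
    images = tf.zeros([1, 224, 224, 3])   # dummy NHWC image batch
    return trunk(images)                  # shape (1, 1024)
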
class LayerNorm:
    """Wraps Keras LayerNormalization so fp16 inputs are normalized in float32 and cast back."""
    def __init__(self, input_size):
        # input_size is kept for call-site compatibility; LayerNormalization infers the size
        self.layer_norm = LayerNormalization()

    def __call__(self, x):
        orig_type = x.dtype
        ret = self.layer_norm(tf.cast(x, tf.float32))
        return tf.cast(ret, orig_type)

class QuickGELU(tf.keras.layers.Layer):
    def __init__(self):
        super(QuickGELU, self).__init__()

    def __call__(self, x):
        return x * tf.nn.sigmoid(1.702 * x)

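
# Sanity-check sketch (illustrative only): QuickGELU is the sigmoid approximation
# x * sigmoid(1.702 * x) of GELU, so it stays close to tf.nn.gelu on moderate inputs.
def _quick_gelu_example():
    x = tf.linspace(-3.0, 3.0, 7)
    return QuickGELU()(x), tf.nn.gelu(x)   # the two curves differ only slightly
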
class ResidualAttentionBlock(tf.keras.layers.Layer):
    def __init__(self, d_model: int, n_head: int, attn_mask=None):
        super(ResidualAttentionBlock, self).__init__()
        # Keras MultiHeadAttention takes the per-head size (key_dim), not the model width
        self.attn = MultiHeadAttention(num_heads=n_head, key_dim=d_model // n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = tf.keras.Sequential()
        self.mlp.add(Dense(d_model * 4))
        self.mlp.add(QuickGELU())
        self.mlp.add(Dense(d_model))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x):
        # Keras expects a mask with 1 where attention is allowed; it broadcasts over batch and heads
        attn_mask = tf.cast(self.attn_mask, x.dtype) if self.attn_mask is not None else None
        return self.attn(x, x, attention_mask=attn_mask)

    def __call__(self, x):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

class Transformer:
    def __init__(self, width: int, layers: int, heads: int, attn_mask=None):
        self.width = width
        self.layers = layers
        self.resblocks = tf.keras.Sequential()
        for _ in range(layers):
            self.resblocks.add(ResidualAttentionBlock(width, heads, attn_mask))

    def __call__(self, x):
        return self.resblocks(x)

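
# Usage sketch (illustrative only): the transformer operates on batch-major
# (batch, sequence, width) tensors; width=64, 2 layers and 4 heads are assumed toy values.
def _transformer_example():
    encoder = Transformer(width=64, layers=2, heads=4)
    tokens = tf.zeros([2, 10, 64])   # dummy sequence batch
    return encoder(tokens)           # shape (2, 10, 64)
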
class VisionTransformer:
    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = Conv2D(width, kernel_size=patch_size, strides=patch_size, use_bias=False)

        scale = width ** -0.5
        self.class_embedding = tf.Variable(scale * tf.random.normal([width]))
        self.positional_embedding = tf.Variable(scale * tf.random.normal([(input_resolution // patch_size) ** 2 + 1, width]))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = tf.Variable(scale * tf.random.normal([width, output_dim]))

    def __call__(self, x, train_flag=True):
        x = self.conv1(x)  # shape = [*, grid, grid, width] (channels-last)
        x = tf.reshape(x, [x.shape[0], -1, x.shape[-1]])  # shape = [*, grid ** 2, width]
        x = tf.concat([tf.cast(self.class_embedding, x.dtype) + tf.zeros([x.shape[0], 1, x.shape[-1]], dtype=x.dtype), x], axis=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + tf.cast(self.positional_embedding, x.dtype)
        x = self.ln_pre(x)

        x = self.transformer(x)  # Keras attention layers are batch-major, so no NLD <-> LND transposes

        x = self.ln_post(x[:, 0, :])

        if self.proj is not None:
            x = tf.matmul(x, self.proj)

        return x

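
# Usage sketch (illustrative only): roughly a ViT-B/32-style image tower (224x224 input,
# 32-pixel patches, width 768, 12 layers, 12 heads, 512-d output); treat these numbers
# as an assumed example configuration.
def _vision_transformer_example():
    vit = VisionTransformer(input_resolution=224, patch_size=32, width=768,
                            layers=12, heads=12, output_dim=512)
    images = tf.zeros([1, 224, 224, 3])   # dummy NHWC image batch
    return vit(images)                    # shape (1, 512)
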
class CLIP(Model):
    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int
                 ):
        super(CLIP, self).__init__()

        self.context_length = context_length

        if isinstance(vision_layers, (tuple, list)):
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_heads,
                input_resolution=image_resolution,
                width=vision_width
            )
        else:
            vision_heads = vision_width // 64
            self.visual = VisionTransformer(
                input_resolution=image_resolution,
                patch_size=vision_patch_size,
                width=vision_width,
                layers=vision_layers,
                heads=vision_heads,
                output_dim=embed_dim
            )

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask()
        )

        self.vocab_size = vocab_size
        self.token_embedding = tf.Variable(tf.random.normal((vocab_size, transformer_width), stddev=0.02))
        self.positional_embedding = tf.Variable(tf.random.normal((self.context_length, transformer_width), stddev=0.01))
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = tf.Variable(tf.random.normal((transformer_width, embed_dim), stddev=self.transformer.width ** -0.5))
        self.logit_scale = tf.Variable(tf.ones([]) * np.log(1 / 0.07))

    def build_attention_mask(self):
        # causal mask for Keras MultiHeadAttention: 1 where a token may attend (itself and
        # earlier positions), 0 above the diagonal; Keras converts the zeros to -inf logits internally
        mask = tf.ones((self.context_length, self.context_length))
        mask = tf.linalg.band_part(mask, -1, 0)  # keep the diagonal and everything below it
        return mask

    def encode_image(self, image):
        return self.visual(image)

    def encode_text(self, text):
        x = tf.gather(self.token_embedding, text)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding
        x = self.transformer(x)  # Keras attention layers are batch-major, so no NLD <-> LND transposes
        x = self.ln_final(x)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        eot_indices = tf.stack([tf.range(x.shape[0], dtype='int32'),
                                tf.argmax(text, axis=-1, output_type='int32')], axis=1)
        x = tf.matmul(tf.gather_nd(x, eot_indices), self.text_projection)

        return x

    def __call__(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # normalized features
        image_features = image_features / tf.norm(image_features, axis=1, keepdims=True)
        text_features = text_features / tf.norm(text_features, axis=1, keepdims=True)

        # cosine similarity as logits
        logit_scale = tf.math.exp(self.logit_scale)
        logits_per_image = tf.matmul(logit_scale * image_features, tf.transpose(text_features))
        logits_per_text = tf.transpose(logits_per_image)

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text
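

# End-to-end sketch (illustrative only): builds a small ViT-style CLIP on random data just
# to exercise the forward pass; every hyper-parameter and the dummy token ids below are
# assumptions, not a released configuration.
if __name__ == '__main__':
    model = CLIP(embed_dim=128,
                 image_resolution=64, vision_layers=4, vision_width=128, vision_patch_size=16,
                 context_length=16, vocab_size=1000,
                 transformer_width=128, transformer_heads=4, transformer_layers=2)
    images = tf.zeros([2, 64, 64, 3])                                  # dummy image batch
    texts = tf.random.uniform([2, 16], maxval=1000, dtype=tf.int32)    # dummy token ids
    logits_per_image, logits_per_text = model(images, texts)
    print(logits_per_image.shape, logits_per_text.shape)               # (2, 2) (2, 2)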