simonJJJ committed
Commit 4de9cce
1 Parent(s): b3bea9b
Files changed (3):
  1. modeling_qwen.py +10 -153
  2. tokenization_qwen.py +27 -7
  3. visual.py +70 -19
modeling_qwen.py CHANGED
@@ -69,44 +69,7 @@ Pass argument `stream` to model.chat() is buggy, deprecated, and marked for remo
 
 apply_rotary_emb_func = None
 rms_norm = None
-flash_attn_unpadded_func = None
-
-
-def _import_flash_attn():
-    global apply_rotary_emb_func, rms_norm, flash_attn_unpadded_func
-    try:
-        from flash_attn.layers.rotary import apply_rotary_emb_func as __apply_rotary_emb_func
-        apply_rotary_emb_func = __apply_rotary_emb_func
-    except ImportError:
-        logger.warn(
-            "Warning: import flash_attn rotary fail, please install FlashAttention rotary to get higher efficiency "
-            "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary"
-        )
 
-    try:
-        from flash_attn.ops.rms_norm import rms_norm as __rms_norm
-        rms_norm = __rms_norm
-    except ImportError:
-        logger.warn(
-            "Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get higher efficiency "
-            "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm"
-        )
-
-    try:
-        import flash_attn
-        if not hasattr(flash_attn, '__version__'):
-            from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func
-        else:
-            if int(flash_attn.__version__.split(".")[0]) >= 2:
-                from flash_attn.flash_attn_interface import flash_attn_varlen_func as __flash_attn_unpadded_func
-            else:
-                from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func
-        flash_attn_unpadded_func = __flash_attn_unpadded_func
-    except ImportError:
-        logger.warn(
-            "Warning: import flash_attn fail, please install FlashAttention to get higher efficiency "
-            "https://github.com/Dao-AILab/flash-attention"
-        )
 
 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
 def _make_causal_mask(
@@ -141,70 +104,6 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
     return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
 
 
-class FlashSelfAttention(torch.nn.Module):
-    def __init__(
-        self,
-        causal=False,
-        softmax_scale=None,
-        attention_dropout=0.0,
-    ):
-        super().__init__()
-        assert flash_attn_unpadded_func is not None, (
-            "Please install FlashAttention first, " "e.g., with pip install flash-attn"
-        )
-        assert (
-            rearrange is not None
-        ), "Please install einops first, e.g., with pip install einops"
-        self.causal = causal
-        self.softmax_scale = softmax_scale
-        self.dropout_p = attention_dropout
-
-    def forward(self, q, k, v):
-        assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v)))
-        assert all((i.is_cuda for i in (q, k, v)))
-        batch_size, seqlen_q = q.shape[0], q.shape[1]
-        seqlen_k = k.shape[1]
-        q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
-        cu_seqlens_q = torch.arange(
-            0,
-            (batch_size + 1) * seqlen_q,
-            step=seqlen_q,
-            dtype=torch.int32,
-            device=q.device,
-        )
-
-        if self.training:
-            assert seqlen_k == seqlen_q
-
-            is_causal = self.causal
-            cu_seqlens_k = cu_seqlens_q
-        else:
-            is_causal = seqlen_q == seqlen_k
-            cu_seqlens_k = torch.arange(
-                0,
-                (batch_size + 1) * seqlen_k,
-                step=seqlen_k,
-                dtype=torch.int32,
-                device=q.device,
-            )
-            self.dropout_p = 0
-        output = flash_attn_unpadded_func(
-            q,
-            k,
-            v,
-            cu_seqlens_q,
-            cu_seqlens_k,
-            seqlen_q,
-            seqlen_k,
-            self.dropout_p,
-            softmax_scale=self.softmax_scale,
-            causal=is_causal,
-        )
-
-        output = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
-        return output
-
-
 class QWenAttention(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -225,7 +124,6 @@ class QWenAttention(nn.Module):
         self.num_heads = config.num_attention_heads
         self.head_dim = self.hidden_size // self.num_heads
 
-        self.use_flash_attn = config.use_flash_attn
        self.scale_attn_weights = True
 
        self.projection_size = config.kv_channels * config.num_attention_heads
@@ -242,15 +140,6 @@ class QWenAttention(nn.Module):
        )
 
        self.is_fp32 = not (config.bf16 or config.fp16)
-        if (
-            self.use_flash_attn
-            and flash_attn_unpadded_func is not None
-            and not self.is_fp32
-        ):
-            self.core_attention_flash = FlashSelfAttention(
-                causal=True, attention_dropout=config.attn_dropout_prob
-            )
-
        self.bf16 = config.bf16
 
        if config.rotary_pct == 1.0:
@@ -453,40 +342,20 @@ class QWenAttention(nn.Module):
            logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :]
            query = query * logn_tensor.expand_as(query)
 
-        if (
-            self.use_flash_attn
-            and flash_attn_unpadded_func is not None
-            and not self.is_fp32
-            and query.is_cuda
-        ):
-            q, k, v = query, key, value
-            context_layer = self.core_attention_flash(q, k, v)
-
-            context_layer = rearrange(
-                context_layer, "b s h d -> b s (h d)"
-            ).contiguous()
-        else:
-            query = query.permute(0, 2, 1, 3)
-            key = key.permute(0, 2, 1, 3)
-            value = value.permute(0, 2, 1, 3)
-            attn_output, attn_weight = self._attn(
-                query, key, value, attention_mask, head_mask
-            )
-            context_layer = self._merge_heads(
-                attn_output, self.num_heads, self.head_dim
-            )
+        query = query.permute(0, 2, 1, 3)
+        key = key.permute(0, 2, 1, 3)
+        value = value.permute(0, 2, 1, 3)
+        attn_output, attn_weight = self._attn(
+            query, key, value, attention_mask, head_mask
+        )
+        context_layer = self._merge_heads(
+            attn_output, self.num_heads, self.head_dim
+        )
 
        attn_output = self.c_proj(context_layer)
        outputs = (attn_output, present)
        if output_attentions:
-            if (
-                self.use_flash_attn
-                and flash_attn_unpadded_func is not None
-                and not self.is_fp32
-            ):
-                raise ValueError("Cannot output attentions while using flash-attn")
-            else:
-                outputs += (attn_weight,)
+            outputs += (attn_weight,)
 
        return outputs
 
@@ -882,18 +751,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
            logger.warn("Your device support faster inference by passing bf16=True in \"AutoModelForCausalLM.from_pretrained\".")
        elif SUPPORT_FP16:
            logger.warn("Your device support faster inference by passing fp16=True in \"AutoModelForCausalLM.from_pretrained\".")
-
-        if config.use_flash_attn == "auto":
-            if config.bf16 or config.fp16:
-                logger.warn("Try importing flash-attention for faster inference...")
-                config.use_flash_attn = True
-            else:
-                config.use_flash_attn = False
-        if config.use_flash_attn and config.fp32:
-            logger.warn("Flash attention will be disabled because it does NOT support fp32.")
-
-        if config.use_flash_attn:
-            _import_flash_attn()
 
        self.transformer = QWenModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
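
Note on the remaining code path: with the FlashAttention branch removed, QWenAttention always permutes to (batch, heads, seq, head_dim), runs the standard attention kernel, and merges the heads back before c_proj. Below is a minimal sketch of that shape flow with illustrative sizes; the causal scaled-dot-product and head-merge steps are simplified stand-ins for the model's own _attn and _merge_heads methods, not the actual implementations.

import torch

b, s, h, d = 2, 8, 4, 16                      # illustrative batch, seq, heads, head_dim
query = torch.randn(b, s, h, d)
key = torch.randn(b, s, h, d)
value = torch.randn(b, s, h, d)

# (b, s, h, d) -> (b, h, s, d), matching the permutes kept in the diff above
q, k, v = (t.permute(0, 2, 1, 3) for t in (query, key, value))

# simplified stand-in for QWenAttention._attn: causal scaled dot-product attention
scores = q @ k.transpose(-1, -2) / (d ** 0.5)
causal = torch.ones(s, s, dtype=torch.bool).tril()
scores = scores.masked_fill(~causal, torch.finfo(scores.dtype).min)
attn_weight = scores.softmax(dim=-1)
attn_output = attn_weight @ v                 # (b, h, s, d)

# simplified stand-in for _merge_heads: back to (b, s, h * d) for c_proj
context_layer = attn_output.permute(0, 2, 1, 3).reshape(b, s, h * d)
print(context_layer.shape)                    # torch.Size([2, 8, 64])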
tokenization_qwen.py CHANGED
@@ -10,7 +10,7 @@ import logging
 import os
 import requests
 import unicodedata
-from typing import Collection, Dict, List, Set, Tuple, Union, Any, Callable
+from typing import Collection, Dict, List, Set, Tuple, Union, Any, Callable, Optional
 
 import tiktoken
 import numpy as np
@@ -359,6 +359,22 @@ class QWenTokenizer(PreTrainedTokenizer):
        _encode_vl_info,
    )
 
+    def from_list_format(self, list_format: List[Dict]):
+        text = ''
+        for ele in list_format:
+            if 'image' in ele:
+                text += self.image_start_tag + ele['image'] + self.image_end_tag
+            elif 'text' in ele:
+                text += ele['text']
+            elif 'box' in ele:
+                if 'ref' in ele:
+                    text += self.ref_start_tag + ele['ref'] + self.ref_end_tag
+                for box in ele['box']:
+                    text += self.box_start_tag + '(%d,%d),(%d,%d)' % (box[0], box[1], box[2], box[3]) + self.box_end_tag
+            else:
+                raise ValueError("Unsupport element: " + str(ele))
+        return text
+
    def _fetch_latest_picture(self, response, history):
        if history is None:
            history = []
@@ -377,15 +393,19 @@ class QWenTokenizer(PreTrainedTokenizer):
                bbox = tuple(map(int, ele['box'].replace('(', '').replace(')', '').split(',')))
                assert len(bbox) == 4
                output.append({'box': bbox})
-                if i > 0 and 'ref' in list_format[i-1]:
-                    output[-1]['ref'] = list_format[i-1]['ref'].strip()
+
+                ref_idx = i - 1
+                while ref_idx >= 0 and 'box' in list_format[ref_idx]:
+                    ref_idx -= 1
+                if ref_idx >= 0 and 'ref' in list_format[ref_idx]:
+                    output[-1]['ref'] = list_format[ref_idx]['ref'].strip()
        return output
 
    def draw_bbox_on_latest_picture(
        self,
        response,
        history=None,
-    ):
+    ) -> Optional[Image.Image]:
        image = self._fetch_latest_picture(response, history)
        if image is None:
            return None
@@ -399,14 +419,14 @@ class QWenTokenizer(PreTrainedTokenizer):
        boxes = self._fetch_all_box_with_ref(response)
        if not boxes:
            return None
-        fnt = ImageFont.truetype("SimSun.ttf", 20)
+        fnt = ImageFont.truetype("SimSun.ttf", 50)
        draw = ImageDraw.Draw(image)
        for box in boxes:
            x1, y1, x2, y2 = box['box']
            x1, y1, x2, y2 = (int(x1 / 1000 * w), int(y1 / 1000 * h), int(x2 / 1000 * w), int(y2 / 1000 * h))
-            draw.rectangle((x1, y1, x2, y2), outline='red', width=2)
+            draw.rectangle((x1, y1, x2, y2), outline='red', width=4)
            if 'ref' in box:
-                draw.text((x1, y1), box['ref'], fill='red', font=fnt)
+                draw.text((x1, y1), box['ref'], fill='yellow', font=fnt)
        return image
 
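
For reference, a minimal usage sketch of the new from_list_format helper and the retyped draw_bbox_on_latest_picture. The checkpoint name, the exact tag strings shown in the comment, and the <ref>/<box> syntax in the response string are assumptions about the tokenizer's configured special tokens; only the method calls themselves come from the diff above.

from transformers import AutoTokenizer

# illustrative checkpoint; any tokenizer exposing these helpers is driven the same way
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-VL", trust_remote_code=True)

query = tokenizer.from_list_format([
    {'image': 'https://example.com/demo.jpeg'},   # wrapped in image_start_tag / image_end_tag
    {'text': 'What is in the picture?'},          # appended verbatim
])
print(query)   # e.g. <img>https://example.com/demo.jpeg</img>What is in the picture?

# draw_bbox_on_latest_picture now returns Optional[Image.Image]:
# a PIL image with red boxes and yellow labels, or None if no picture/box is found.
image = tokenizer.draw_bbox_on_latest_picture(
    response="<ref>dog</ref><box>(100,200),(500,800)</box>",  # assumed tag syntax
    history=[],
)
if image is not None:
    image.save("boxed.jpg")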
 
visual.py CHANGED
@@ -1,3 +1,8 @@
+# Copyright (c) Alibaba Cloud.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
 from collections import OrderedDict
 import math
 import requests
@@ -5,11 +10,11 @@ from io import BytesIO
 from functools import partial
 from PIL import Image
 from typing import Callable, Optional, Sequence, Tuple, List
+import numpy as np
 
 import torch
 from torch import nn
 from torch.nn import functional as F
-from torch.utils.checkpoint import checkpoint
 from torch.nn.init import trunc_normal_
 from torchvision import transforms
 from torchvision.transforms import InterpolationMode
@@ -33,8 +38,64 @@ def get_abs_pos(abs_pos, tgt_size):
    else:
        return abs_pos
 
+# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
+def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
+    """
+    grid_size: int of the grid height and width
+    return:
+    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+    """
+    grid_h = np.arange(grid_size, dtype=np.float32)
+    grid_w = np.arange(grid_size, dtype=np.float32)
+    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+    grid = np.stack(grid, axis=0)
+
+    grid = grid.reshape([2, 1, grid_size, grid_size])
+    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+    if cls_token:
+        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+    return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+    assert embed_dim % 2 == 0
+
+    # use half of dimensions to encode grid_h
+    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
+    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
+
+    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+    return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+    """
+    embed_dim: output dimension for each position
+    pos: a list of positions to be encoded: size (M,)
+    out: (M, D)
+    """
+    assert embed_dim % 2 == 0
+    omega = np.arange(embed_dim // 2, dtype=np.float32)
+    omega /= embed_dim / 2.
+    omega = 1. / 10000**omega  # (D/2,)
+
+    pos = pos.reshape(-1)  # (M,)
+    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
+
+    emb_sin = np.sin(out)  # (M, D/2)
+    emb_cos = np.cos(out)  # (M, D/2)
+
+    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+    return emb
+
 
 class Resampler(nn.Module):
+    """
+    A 2D perceiver-resampler network with one cross attention layers by
+        (grid_size**2) learnable queries and 2d sincos pos_emb
+    Outputs:
+        A tensor with the shape of (grid_size**2, embed_dim)
+    """
    def __init__(
        self,
        grid_size,
@@ -48,7 +109,9 @@ class Resampler(nn.Module):
        self.embed_dim = embed_dim
        self.num_heads = num_heads
 
-        self.pos_embed = nn.Parameter(torch.randn(embed_dim, grid_size)).requires_grad_(False)
+        self.pos_embed = nn.Parameter(
+            torch.from_numpy(get_2d_sincos_pos_embed(embed_dim, grid_size)).float()
+        ).requires_grad_(False)
 
        self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
        trunc_normal_(self.query, std=.02)
@@ -234,7 +297,7 @@ class VisualAttentionBlock(nn.Module):
        return x
 
 
-class Transformer(nn.Module):
+class TransformerBlock(nn.Module):
    def __init__(
        self,
        width: int,
@@ -247,7 +310,6 @@ class Transformer(nn.Module):
        super().__init__()
        self.width = width
        self.layers = layers
-        self.grad_checkpointing = False
 
        self.resblocks = nn.ModuleList([
            VisualAttentionBlock(
@@ -263,11 +325,7 @@ class Transformer(nn.Module):
 
    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        for r in self.resblocks:
-            if self.grad_checkpointing and not torch.jit.is_scripting():
-                # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
-                x = checkpoint(r, x, None, None, attn_mask)
-            else:
-                x = r(x, attn_mask=attn_mask)
+            x = r(x, attn_mask=attn_mask)
        return x
 
 
@@ -306,13 +364,13 @@ class VisionTransformer(nn.Module):
 
        # class embeddings and positional embeddings
        scale = width ** -0.5
-        self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1], width))
+        self.positional_embedding = nn.Parameter(scale * torch.randn(256, width))
 
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
        act_layer = nn.GELU
 
        self.ln_pre = norm_layer(width)
-        self.transformer = Transformer(
+        self.transformer = TransformerBlock(
            width,
            layers,
            heads,
@@ -331,10 +389,6 @@ class VisionTransformer(nn.Module):
        self.ln_post = norm_layer(output_dim)
        self.proj = nn.Parameter((output_dim** -0.5) * torch.randn(output_dim, output_dim))
 
-    @torch.jit.ignore
-    def set_grad_checkpointing(self, enable=True):
-        self.transformer.grad_checkpointing = enable
-
    def forward(self, x: torch.Tensor):
        x = x.to(
            dtype=self.transformer.get_cast_dtype(),
@@ -353,8 +407,7 @@ class VisionTransformer(nn.Module):
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
 
-        if self.attn_pool:
-            x = self.attn_pool(x)
+        x = self.attn_pool(x)
        x = self.ln_post(x)
        x = x @ self.proj
 
@@ -365,8 +418,6 @@ class VisionTransformer(nn.Module):
        for image_path in image_paths:
            if image_path.startswith("http://") or image_path.startswith("https://"):
                image = Image.open(requests.get(image_path, stream=True).raw)
-            elif image_path.startswith("oss://"):
-                raise NotImplementedError
            else:
                image = Image.open(image_path)
            image = image.convert("RGB")
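
As a quick sanity check on the Resampler change, a short sketch of what get_2d_sincos_pos_embed produces and how it becomes the frozen pos_embed parameter in place of the old random init. The sizes are illustrative, and it assumes visual.py is importable from the working directory.

import torch
from visual import get_2d_sincos_pos_embed   # helper added in the diff above

embed_dim, grid_size = 128, 16                # illustrative values
pos_embed = get_2d_sincos_pos_embed(embed_dim, grid_size)
print(pos_embed.shape)                        # (256, 128) == (grid_size**2, embed_dim)

# As in Resampler.__init__: a fixed, non-trainable parameter built from the sincos table
pos_param = torch.nn.Parameter(
    torch.from_numpy(pos_embed).float()
).requires_grad_(False)
print(pos_param.requires_grad, pos_param.shape)   # False torch.Size([256, 128])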