bbexx committed on
Commit
45808ea
1 Parent(s): a8642cd
Files changed (2)
  1. README.md +1 -1
  2. vitamin.py +49 -1
README.md CHANGED
@@ -81,7 +81,7 @@ Note: Panoptic dataset (ADE, CityScapes, MV) are with the metric of PQ. Semantic
 
 | image encoder | image size | VQAv2 | GQA | VizWiz | SQA | T-VQA | POPE | MME | MM-Bench | MM-B-CN | SEED | LLaVA-Wild | MM-Vet |
 |---------------|----------|-------|------|--------|------|-------|------|------|----------|---------|------|------------|--------|
-| ViTamin-L | 224 | 78.4 | 61.6 | 51.1 | 66.9 | 58.7 | 84.6 | 1421 | 65.4 | 58.4 | 57.7 | 64.5 | 33.6 |
+| ViTamin-L | 336 | 78.4 | 61.6 | 51.1 | 66.9 | 58.7 | 84.6 | 1421 | 65.4 | 58.4 | 57.7 | 64.5 | 33.6 |
 | ViTamin-L | 384 | 78.9 | 61.6 | 55.4 | 67.6 | 59.8 | 85.5 | 1447 | 64.5 | 58.3 | 57.9 | 66.1 | 33.6 |
 
vitamin.py CHANGED
@@ -31,7 +31,7 @@ from torch.utils.checkpoint import checkpoint
 from timm.models.layers import create_attn, get_norm_layer, get_norm_act_layer, create_conv2d, make_divisible, trunc_normal_tf_
 
 
-from timm.layers import to_2tuple, DropPath, Format
+from timm.layers import to_2tuple, DropPath, Format  # , trunc_normal_
 from timm.layers.norm_act import _create_act
 from timm.models._registry import register_model
 from timm.models._manipulate import named_apply, checkpoint_seq
@@ -335,6 +335,39 @@ class HybridEmbed(nn.Module):
         x = x.flatten(2).transpose(1, 2)
         return x
 
+def _trunc_normal_(tensor, mean, std, a, b):
+    # rewrite of timm's trunc_normal_
+    def norm_cdf(x):
+        # Computes standard normal cumulative distribution function
+        return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+    if (mean < a - 2 * std) or (mean > b + 2 * std):
+        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+                      "The distribution of values may be incorrect.",
+                      stacklevel=2)
+
+    l = norm_cdf((a - mean) / std)
+    u = norm_cdf((b - mean) / std)
+
+    # Uniformly fill tensor with values from [l, u], then translate to
+    # [2l-1, 2u-1].
+    tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+    # Use inverse cdf transform for normal distribution to get truncated standard normal
+    # tensor.erfinv_()  # NOTE: deleted as "erfinv_cuda" not implemented for 'BFloat16'
+
+    # Transform to proper mean, std
+    tensor.mul_(std * math.sqrt(2.))
+    tensor.add_(mean)
+
+    # Clamp to ensure it's in the proper range
+    tensor.clamp_(min=a, max=b)
+    return tensor
+
+def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+    with torch.no_grad():
+        return _trunc_normal_(tensor, mean, std, a, b)
+
 class ViTamin(nn.Module):
     """ hack timm VisionTransformer
     """
 
@@ -563,6 +596,21 @@ class ViTamin(nn.Module):
         self.global_pool = global_pool
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
 
+    def _pos_embed(self, x):
+        if self.no_embed_class:
+            # deit-3, updated JAX (big vision)
+            # position embedding does not overlap with class token, add then concat
+            x = x + self.pos_embed
+            if self.cls_token is not None:
+                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+        else:
+            # original timm, JAX, and deit vit impl
+            # pos_embed has entry for class token, concat then add
+            if self.cls_token is not None:
+                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+            x = x + self.pos_embed
+        return self.pos_drop(x)
+
     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.patch_embed(x)
         if self.is_pos_embed:
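
For reference, a shape-level sketch of the two `_pos_embed` branches (the batch size, token count, and embed dim are hypothetical, chosen only for illustration): with `no_embed_class`, the position embedding covers only the patch tokens, so it is added before the class token is concatenated; otherwise it carries a class-token entry and is added after the concat.

```python
import torch

B, N, D = 2, 196, 384      # hypothetical batch, patch tokens, embed dim
x = torch.randn(B, N, D)   # patch embeddings
cls_token = torch.zeros(1, 1, D)

# no_embed_class=True (deit-3 / updated JAX "big vision"):
# pos_embed has N entries, so add it first, then concat the class token.
pos_embed = torch.zeros(1, N, D)
out = torch.cat((cls_token.expand(B, -1, -1), x + pos_embed), dim=1)
assert out.shape == (B, N + 1, D)

# no_embed_class=False (original timm / JAX / deit):
# pos_embed has N + 1 entries (one for the class token), so concat first.
pos_embed = torch.zeros(1, N + 1, D)
out = torch.cat((cls_token.expand(B, -1, -1), x), dim=1) + pos_embed
assert out.shape == (B, N + 1, D)
```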