update
Browse files- README.md +1 -1
- vitamin.py +49 -1
README.md
CHANGED
@@ -81,7 +81,7 @@ Note: Panoptic dataset (ADE, CityScapes, MV) are with the metric of PQ. Semantic
|
|
81 |
|
82 |
| image encoder | image size | VQAv2 | GQA | VizWiz | SQA | T-VQA | POPE | MME | MM-Bench | MM-B-CN | SEED | LLaVA-Wild | MM-Vet |
|
83 |
|---------------|----------|-------|------|--------|------|-------|------|------|----------|---------|------|------------|--------|
|
84 |
-
| ViTamin-L |
|
85 |
| ViTamin-L | 384 | 78.9 | 61.6 | 55.4 | 67.6 | 59.8 | 85.5 | 1447 | 64.5 | 58.3 | 57.9 | 66.1 | 33.6 |
|
86 |
|
87 |
|
|
|
81 |
|
82 |
| image encoder | image size | VQAv2 | GQA | VizWiz | SQA | T-VQA | POPE | MME | MM-Bench | MM-B-CN | SEED | LLaVA-Wild | MM-Vet |
|
83 |
|---------------|----------|-------|------|--------|------|-------|------|------|----------|---------|------|------------|--------|
|
84 |
+
| ViTamin-L | 336 | 78.4 | 61.6 | 51.1 | 66.9 | 58.7 | 84.6 | 1421 | 65.4 | 58.4 | 57.7 | 64.5 | 33.6 |
|
85 |
| ViTamin-L | 384 | 78.9 | 61.6 | 55.4 | 67.6 | 59.8 | 85.5 | 1447 | 64.5 | 58.3 | 57.9 | 66.1 | 33.6 |
|
86 |
|
87 |
|
vitamin.py
CHANGED
@@ -31,7 +31,7 @@ from torch.utils.checkpoint import checkpoint
|
|
31 |
from timm.models.layers import create_attn, get_norm_layer, get_norm_act_layer, create_conv2d, make_divisible, trunc_normal_tf_
|
32 |
|
33 |
|
34 |
-
from timm.layers import to_2tuple, DropPath, Format
|
35 |
from timm.layers.norm_act import _create_act
|
36 |
from timm.models._registry import register_model
|
37 |
from timm.models._manipulate import named_apply, checkpoint_seq
|
@@ -335,6 +335,39 @@ class HybridEmbed(nn.Module):
|
|
335 |
x = x.flatten(2).transpose(1, 2)
|
336 |
return x
|
337 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
338 |
class ViTamin(nn.Module):
|
339 |
""" hack timm VisionTransformer
|
340 |
"""
|
@@ -563,6 +596,21 @@ class ViTamin(nn.Module):
|
|
563 |
self.global_pool = global_pool
|
564 |
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
565 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
566 |
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
567 |
x = self.patch_embed(x)
|
568 |
if self.is_pos_embed:
|
|
|
31 |
from timm.models.layers import create_attn, get_norm_layer, get_norm_act_layer, create_conv2d, make_divisible, trunc_normal_tf_
|
32 |
|
33 |
|
34 |
+
from timm.layers import to_2tuple, DropPath, Format # , trunc_normal_
|
35 |
from timm.layers.norm_act import _create_act
|
36 |
from timm.models._registry import register_model
|
37 |
from timm.models._manipulate import named_apply, checkpoint_seq
|
|
|
335 |
x = x.flatten(2).transpose(1, 2)
|
336 |
return x
|
337 |
|
338 |
+
def _trunc_normal_(tensor, mean, std, a, b):
|
339 |
+
# rewrite timm trunc normal
|
340 |
+
def norm_cdf(x):
|
341 |
+
# Computes standard normal cumulative distribution function
|
342 |
+
return (1. + math.erf(x / math.sqrt(2.))) / 2.
|
343 |
+
|
344 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
345 |
+
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
346 |
+
"The distribution of values may be incorrect.",
|
347 |
+
stacklevel=2)
|
348 |
+
|
349 |
+
l = norm_cdf((a - mean) / std)
|
350 |
+
u = norm_cdf((b - mean) / std)
|
351 |
+
|
352 |
+
# Uniformly fill tensor with values from [l, u], then translate to
|
353 |
+
# [2l-1, 2u-1].
|
354 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
355 |
+
|
356 |
+
# Use inverse cdf transform for normal distribution to get truncated standard normal
|
357 |
+
# tensor.erfinv_() # NOTE: deleted as "erfinv_cuda" not implemented for 'BFloat16'
|
358 |
+
|
359 |
+
# Transform to proper mean, std
|
360 |
+
tensor.mul_(std * math.sqrt(2.))
|
361 |
+
tensor.add_(mean)
|
362 |
+
|
363 |
+
# Clamp to ensure it's in the proper range
|
364 |
+
tensor.clamp_(min=a, max=b)
|
365 |
+
return tensor
|
366 |
+
|
367 |
+
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
|
368 |
+
with torch.no_grad():
|
369 |
+
return _trunc_normal_(tensor, mean, std, a, b)
|
370 |
+
|
371 |
class ViTamin(nn.Module):
|
372 |
""" hack timm VisionTransformer
|
373 |
"""
|
|
|
596 |
self.global_pool = global_pool
|
597 |
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
598 |
|
599 |
+
def _pos_embed(self, x):
|
600 |
+
if self.no_embed_class:
|
601 |
+
# deit-3, updated JAX (big vision)
|
602 |
+
# position embedding does not overlap with class token, add then concat
|
603 |
+
x = x + self.pos_embed
|
604 |
+
if self.cls_token is not None:
|
605 |
+
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
606 |
+
else:
|
607 |
+
# original timm, JAX, and deit vit impl
|
608 |
+
# pos_embed has entry for class token, concat then add
|
609 |
+
if self.cls_token is not None:
|
610 |
+
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
611 |
+
x = x + self.pos_embed
|
612 |
+
return self.pos_drop(x)
|
613 |
+
|
614 |
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
615 |
x = self.patch_embed(x)
|
616 |
if self.is_pos_embed:
|