import numpy as np
import torch
from torch import nn
from torch.nn.init import kaiming_normal_, ones_, trunc_normal_, zeros_

from openrec.modeling.common import Block, PatchEmbed
from openrec.modeling.encoders.svtrv2_lnconv import Feat2D, LastStage


class ViT(nn.Module):
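    """Plain ViT encoder.

    The input image is split into patches, a learnable position embedding
    (and optionally a [CLS] token) is added, and the sequence is processed
    by a stack of transformer Blocks. The normalized tokens can optionally
    be passed through LastStage or Feat2D from svtrv2_lnconv.
    """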

    def __init__(
        self,
        img_size=[32, 128],
        patch_size=[4, 8],
        in_channels=3,
        out_channels=256,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        act_layer=nn.GELU,
        last_stage=False,
        feat2d=False,
        use_cls_token=False,
        **kwargs,
    ):
        super().__init__()
        self.img_size = img_size
        self.embed_dim = embed_dim
        self.out_channels = embed_dim
        self.use_cls_token = use_cls_token
        self.feat_sz = [
            img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        ]

        self.patch_embed = PatchEmbed(img_size, patch_size, in_channels,
                                      embed_dim)
        num_patches = self.patch_embed.num_patches
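        # Learnable position embedding; one extra slot is reserved when a
        # [CLS] token is prepended to the patch sequence.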
        if use_cls_token:
            self.cls_token = nn.Parameter(
                torch.zeros([1, 1, embed_dim], dtype=torch.float32),
                requires_grad=True,
            )
            trunc_normal_(self.cls_token, mean=0, std=0.02)
            self.pos_embed = nn.Parameter(
                torch.zeros([1, num_patches + 1, embed_dim],
                            dtype=torch.float32),
                requires_grad=True,
            )
        else:
            self.pos_embed = nn.Parameter(
                torch.zeros([1, num_patches, embed_dim], dtype=torch.float32),
                requires_grad=True,
            )
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = np.linspace(0, drop_path_rate, depth)
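        # Stochastic depth: per-block drop-path rates increase linearly
        # from 0 to drop_path_rate.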
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                act_layer=act_layer,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
            ) for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)
        self.last_stage = last_stage
        self.feat2d = feat2d
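        # Optional output stages from svtrv2_lnconv: LastStage projects tokens
        # to out_channels; Feat2D reshapes the token sequence back to a 2D
        # layout using feat_sz. If both flags are set, Feat2D is the one kept.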
        if last_stage:
            self.out_channels = out_channels
            self.stages = LastStage(embed_dim, out_channels, last_drop=0.1)
        if feat2d:
            self.stages = Feat2D()
        trunc_normal_(self.pos_embed, mean=0, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
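        # Truncated-normal Linear weights, unit/zero LayerNorm affine
        # parameters, Kaiming-normal Conv2d kernels.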
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, mean=0, std=0.02)
            if m.bias is not None:
                zeros_(m.bias)
        if isinstance(m, nn.LayerNorm):
            zeros_(m.bias)
            ones_(m.weight)
        if isinstance(m, nn.Conv2d):
            kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    @torch.jit.ignore
    def no_weight_decay(self):
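        # Parameters named here are conventionally excluded from weight decay
        # by the optimizer setup.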
        return {'pos_embed'}

    def forward(self, x):
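        # x: [B, C, H, W] images -> [B, num_patches, embed_dim] patch tokens.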
        x = self.patch_embed(x)
        if self.use_cls_token:
            x = torch.concat([self.cls_token.tile([x.shape[0], 1, 1]), x], 1)
        x = x + self.pos_embed
        x = self.pos_drop(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
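        # Discard the [CLS] token so only patch tokens are returned.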
        if self.use_cls_token:
            x = x[:, 1:, :]
        if self.last_stage:
            x, sz = self.stages(x, self.feat_sz)
        if self.feat2d:
            x, sz = self.stages(x, self.feat_sz)
        return x
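
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the upstream module). It assumes the
# default PatchEmbed yields (H / 4) * (W / 8) tokens for the default
# patch_size, so a 32x128 input gives 8 * 16 = 128 tokens of width
# embed_dim = 384.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    encoder = ViT(img_size=[32, 128], patch_size=[4, 8], embed_dim=384)
    images = torch.randn(2, 3, 32, 128)  # [B, C, H, W] dummy batch
    tokens = encoder(images)
    print(tokens.shape)  # expected: torch.Size([2, 128, 384])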