File size: 4,663 Bytes
aa7fb02
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
'''
   author: yulong-XJTU
'''
import torch
from torch import nn
import torch.nn.functional as F
import copy
from AttDes.models.transformer import Transformer, subsequent_mask, ModelOne, Model005, Model006
from axial_positional_embedding import AxialPositionalEmbedding
from AttDes.models.resblock import BottleneckBlock
from random import randint
from einops import rearrange

def clone(module, N):
    """Return an nn.ModuleList containing N independent deep copies of *module*."""
    replicas = (copy.deepcopy(module) for _ in range(N))
    return nn.ModuleList(replicas)

class PrefixLM(nn.Module):
    """Prefix language model conditioning target-text generation on an image,
    a description sequence and an object sequence.

    The image is embedded by a small conv/ResNet stem into d_model-channel
    patch tokens; description / object / target token ids share one text
    embedding table. Decoding is delegated to the ``Model005`` module
    (a project-defined transformer variant; exact contract not visible here).
    """

    def __init__(
            self, des_len, obj_len, tgt_len,
            d_model=512,
            input_resolution=224,
            patch_size=16,
            num_text_tokens=10000,
            txt_seq_len=256,
            prefix_txt_len=25,
            target_txt_len=52,
            max_trunc_txt_len=15,
            heads=8,
            enc_depth=12,
            dec_depth=12,
            d_ff=1024,
            dropout=0.,
    ):
        """Build embedding tables, the conv image stem and the decoder stack.

        Args:
            des_len, obj_len, tgt_len: fixed token lengths of the description,
                object and target sequences fed to ``forward``.
            d_model: shared embedding / transformer width.
            input_resolution: square input image side; must be divisible by
                ``patch_size``.
            patch_size: conv-stem patch size (stride of the first conv).
            num_text_tokens: text vocabulary size (index 0 is padding).
            Remaining args parameterize the transformer sub-modules.
        """
        super(PrefixLM, self).__init__()
        assert input_resolution % patch_size == 0 and max_trunc_txt_len <= prefix_txt_len and max_trunc_txt_len < txt_seq_len
        # Conv stem: patchify (patch_size stride) then two bottleneck blocks
        # projecting 3 -> 64 -> 256 -> d_model channels.
        self.ResNet = nn.Sequential(*[nn.Conv2d(in_channels=3, out_channels=64, kernel_size=patch_size, stride=patch_size, bias=True),
                                    BottleneckBlock(in_channels=64, out_channels=256, bottleneck_channels=64,),
                                    BottleneckBlock(in_channels=256, out_channels=d_model, bottleneck_channels=128)])
        self.des_len = des_len
        self.obj_len = obj_len
        self.tgt_len = tgt_len
        # padding_idx=0: token id 0 embeds to a frozen zero vector (pad token).
        self.txt_embed = nn.Embedding(num_text_tokens, d_model, padding_idx=0)
        self.txt_pos_embed = nn.Embedding(self.des_len, d_model)
        image_fmap_size = input_resolution // patch_size
        self.img_tokens_len = image_fmap_size ** 2
        # Axial positional embedding over the (h, w) patch grid.
        self.img_pos_embed = AxialPositionalEmbedding(d_model, axial_shape=(image_fmap_size, image_fmap_size))
        self.txt_seq_len = txt_seq_len
        self.target_txt_len = target_txt_len
        self.prefix_txt_len = prefix_txt_len

        self.max_trunc_txt_len = max_trunc_txt_len
        self.num_text_tokens = num_text_tokens
        self.dim_embed = d_model
        self.input_resolution = input_resolution
        self.patch_size = patch_size
        # self.temperature = nn.Parameter(torch.tensor(1.))       # not mentioned in the paper
        # ``transformer`` is the baseline path (currently unused in forward);
        # ``ModelOne`` holds the Model005 variant actually used.
        self.transformer = Transformer(d_model, heads, enc_depth, dec_depth, d_ff, dropout=dropout)
        self.ModelOne = Model005(d_model, heads, enc_depth, dec_depth, d_ff, dropout=dropout)
        self.to_logits = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, self.num_text_tokens)
        )

    def forward(self, img, des, obj, tgt, return_loss=False):
        """Run one teacher-forced decoding step.

        Args:
            img: image batch, (b, 3, input_resolution, input_resolution).
            des: description token ids, (b, des_len).
            obj: object token ids, (b, obj_len).
            tgt: target token ids, (b, tgt_len); shifted internally to build
                decoder input and labels.
            return_loss: accepted for API compatibility but currently unused —
                the method always returns (logits, labels); loss computation
                is left to the caller.

        Returns:
            (logits, labels): logits of shape (b, tgt_len, num_text_tokens)
            and the right-shifted label ids of shape (b, tgt_len).
        """
        device = des.device
        n = des.shape[0]
        img_emed = self.ResNet(img)
        img_emed = rearrange(img_emed, 'b c h w -> b (h w) c')
        img_emed = img_emed + self.img_pos_embed(img_emed)
        del img  # free the raw image tensor early
        # add <CLS>: if you change the tokenizer, don't forget to change the token ID.
        # Another [SEP] token is added at the ending (in tokenizer.py, please check).
        tgt = F.pad(tgt, (1, 0), value=4)
        # Teacher forcing: decoder sees tgt[:, :-1], predicts tgt[:, 1:].
        labels = tgt[:, 1:]
        tgt = tgt[:, :-1]
        des_embed = self.txt_embed(des)
        des_embed = des_embed + self.txt_pos_embed(torch.arange(self.des_len, device=device))

        obj_embed = self.txt_embed(obj)
        obj_embed = obj_embed + self.txt_pos_embed(torch.arange(self.obj_len, device=device))

        tgt_embed = self.txt_embed(tgt)
        tgt_embed = tgt_embed + self.txt_pos_embed(torch.arange(self.tgt_len, device=device))
        # Causal mask so each target position only attends to earlier ones.
        tgt_mask = subsequent_mask(self.tgt_len).to(device)

        # baseline (kept for reference):
        # prefix = torch.cat((img_emed, des_embed, obj_embed), dim=1)
        # out = self.transformer(prefix, tgt_embed, tgt_mask=tgt_mask)

        # BUG FIX: the original called the *class* ``Model005(...)`` here,
        # constructing a fresh, untrained module every forward pass instead of
        # using the instance created in __init__. Route through self.ModelOne.
        out = self.ModelOne(q=obj_embed, k=img_emed, v=img_emed,
                            tgt_embeded=tgt_embed, des_embed=des_embed, obj_embed=obj_embed, img_embed=img_emed,
                            tgt_mask=tgt_mask)

        logits = self.to_logits(out)
        return logits, labels
        # if not return_loss:
        #     return logits
        # # temp = self.temperature.exp()
        # logits = rearrange(logits, 'b n c -> b c n')
        # # logits=logits*temp #带温度参数
        # loss=F.cross_entropy(logits,labels,ignore_index=0)
        # return loss