File size: 3,753 Bytes
960dfdb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from dataclasses import dataclass

import torch
import torch.nn as nn
from transformers.modeling_utils import PreTrainedModel

from .configuration_fidnet_v3 import LayoutDmFIDNetV3Config


@dataclass
class LayoutDmFIDNetV3Output:
    """Outputs returned by ``LayoutDmFIDNetV3.forward``.

    Attributes:
        logit_disc: One logit per layout from the ``fc_out_disc`` head, ``[B]``.
        logit_cls: Per-element label logits, ``[B, N, num_labels]``.
        bbox_pred: Per-element box predictions squashed by a sigmoid, ``[B, N, 4]``.
    """

    # NOTE: this field used to be declared as ``logit_dict`` while the model
    # constructed it with ``logit_disc=``, which raised TypeError on every
    # forward pass. Positional order is unchanged.
    logit_disc: torch.Tensor
    logit_cls: torch.Tensor
    bbox_pred: torch.Tensor

    @property
    def logit_dict(self) -> torch.Tensor:
        """Backward-compatible alias for :attr:`logit_disc` (the field's former name)."""
        return self.logit_disc

    @logit_dict.setter
    def logit_dict(self, value: torch.Tensor) -> None:
        self.logit_disc = value


class TransformerWithToken(nn.Module):
    """Transformer encoder that prepends one learnable token to its input.

    The prepended token is never marked as padding, so its output slot
    (index 0 along the sequence axis) can attend to every non-padded element.

    Args:
        d_model: Embedding size ``E`` of inputs and outputs.
        nhead: Number of attention heads per encoder layer.
        dim_feedforward: Hidden width of each layer's feed-forward block.
        num_layers: Number of stacked encoder layers.
    """

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int, num_layers: int):
        super().__init__()

        # Learnable token of shape [1, 1, E], broadcast over the batch in forward().
        self.token = nn.Parameter(torch.randn(1, 1, d_model))
        # Padding flag for the token slot: always False ("valid").
        self.register_buffer("token_mask", torch.zeros(1, 1, dtype=torch.bool))

        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
        )
        self.core = nn.TransformerEncoder(layer, num_layers=num_layers)

    def forward(self, x, src_key_padding_mask):
        """Encode ``x`` with the learnable token prepended.

        Args:
            x: Sequence-first input of shape ``[N, B, E]``.
            src_key_padding_mask: ``[B, N]`` mask; ``True`` marks padded
                positions and ``False`` marks valid ones.

        Returns:
            Encoder output of shape ``[N + 1, B, E]``; position 0 holds the
            prepended token.
        """
        batch = x.size(1)
        tokens = self.token.expand(-1, batch, -1)
        masks = self.token_mask.expand(batch, -1)
        return self.core(
            torch.cat((tokens, x), dim=0),
            src_key_padding_mask=torch.cat((masks, src_key_padding_mask), dim=1),
        )


class LayoutDmFIDNetV3(PreTrainedModel):
    """Layout encoder with a per-layout logit head and a reconstruction decoder.

    The encoder embeds each element's box (4 values) and label, runs a
    :class:`TransformerWithToken`, and uses the prepended token's output as a
    layout-level feature. A linear head turns that feature into one logit per
    layout. The decoder broadcasts the feature to every element slot, pairs it
    with a learned per-slot query, and predicts label logits and boxes for
    each element.
    """

    config_class = LayoutDmFIDNetV3Config

    def __init__(self, config: LayoutDmFIDNetV3Config):
        super().__init__(config)
        self.config = config

        # encoder
        self.emb_label = nn.Embedding(config.num_labels, config.d_model)
        self.fc_bbox = nn.Linear(4, config.d_model)
        self.enc_fc_in = nn.Linear(config.d_model * 2, config.d_model)

        self.enc_transformer = TransformerWithToken(
            d_model=config.d_model,
            dim_feedforward=config.d_model // 2,
            nhead=config.nhead,
            num_layers=config.num_layers,
        )

        self.fc_out_disc = nn.Linear(config.d_model, 1)

        # decoder
        # One learned query per element slot; forward() therefore supports at
        # most config.max_bbox elements per layout.
        self.pos_token = nn.Parameter(torch.rand(config.max_bbox, 1, config.d_model))
        self.dec_fc_in = nn.Linear(config.d_model * 2, config.d_model)

        te = nn.TransformerEncoderLayer(
            d_model=config.d_model,
            nhead=config.nhead,
            dim_feedforward=config.d_model // 2,
        )
        self.dec_transformer = nn.TransformerEncoder(te, num_layers=config.num_layers)

        self.fc_out_cls = nn.Linear(config.d_model, config.num_labels)
        self.fc_out_bbox = nn.Linear(config.d_model, 4)

    def extract_features(self, bbox, label, padding_mask):
        """Return the layout-level feature ``[B, d_model]``.

        Args:
            bbox: Boxes of shape ``[B, N, 4]``.
            label: Label indices of shape ``[B, N]`` (fed to ``nn.Embedding``).
            padding_mask: ``[B, N]`` mask, ``True`` for padded elements.
        """
        b = self.fc_bbox(bbox)
        l = self.emb_label(label)
        x = self.enc_fc_in(torch.cat([b, l], dim=-1))
        # Transformer layers here are sequence-first: [B, N, E] -> [N, B, E].
        x = torch.relu(x).permute(1, 0, 2)
        x = self.enc_transformer(x, padding_mask)
        # Position 0 is the token prepended by TransformerWithToken.
        return x[0]

    def forward(self, bbox, label, padding_mask):
        """Score layouts and reconstruct their elements.

        Args:
            bbox: Boxes of shape ``[B, N, 4]``.
            label: Label indices of shape ``[B, N]``.
            padding_mask: ``[B, N]`` mask, ``True`` for padded elements.

        Returns:
            ``LayoutDmFIDNetV3Output`` holding the per-layout logit ``[B]``,
            label logits ``[B, N, num_labels]`` and sigmoid boxes ``[B, N, 4]``.

        Raises:
            ValueError: If ``N`` exceeds the number of learned slot queries
                (``config.max_bbox``).
        """
        B, N, _ = bbox.size()
        max_bbox = self.pos_token.size(0)
        if N > max_bbox:
            # Without this check, pos_token[:N] silently returns fewer rows and
            # the failure surfaces later as an opaque torch.cat shape error.
            raise ValueError(
                f"got {N} elements per layout, but the model supports at most "
                f"{max_bbox} (config.max_bbox)"
            )

        x = self.extract_features(bbox, label, padding_mask)

        logit_disc = self.fc_out_disc(x).squeeze(-1)

        # Broadcast the layout feature to every slot and pair it with that
        # slot's learned query: both become [N, B, d_model].
        x = x.unsqueeze(0).expand(N, -1, -1)
        t = self.pos_token[:N].expand(-1, B, -1)
        x = torch.cat([x, t], dim=-1)
        x = torch.relu(self.dec_fc_in(x))

        x = self.dec_transformer(x, src_key_padding_mask=padding_mask)
        x = x.permute(1, 0, 2)

        # logit_cls: [B, N, L]    bbox_pred: [B, N, 4]
        logit_cls = self.fc_out_cls(x)
        bbox_pred = torch.sigmoid(self.fc_out_bbox(x))

        # Constructed positionally, in the output dataclass's declared field
        # order (discriminator logit, class logits, bbox prediction).
        return LayoutDmFIDNetV3Output(logit_disc, logit_cls, bbox_pred)