shunk031 committed on
Commit
960dfdb
1 Parent(s): 746caf4

Upload LayoutDmFIDNetV3

Browse files
config.json ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "LayoutDmFIDNetV3"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_fidnet_v3.LayoutDmFIDNetV3Config",
7
+ "AutoModel": "modeling_fidnet_v3.LayoutDmFIDNetV3"
8
+ },
9
+ "d_model": 256,
10
+ "id2label": {
11
+ "0": "Text",
12
+ "1": "Image",
13
+ "2": "Icon",
14
+ "3": "Text Button",
15
+ "4": "List Item",
16
+ "5": "Input",
17
+ "6": "Background Image",
18
+ "7": "Card",
19
+ "8": "Web View",
20
+ "9": "Radio Button",
21
+ "10": "Drawer",
22
+ "11": "Checkbox",
23
+ "12": "Advertisement",
24
+ "13": "Modal",
25
+ "14": "Pager Indicator",
26
+ "15": "Slider",
27
+ "16": "On/Off Switch",
28
+ "17": "Button Bar",
29
+ "18": "Toolbar",
30
+ "19": "Number Stepper",
31
+ "20": "Multi-Tab",
32
+ "21": "Date Picker",
33
+ "22": "Map View",
34
+ "23": "Video",
35
+ "24": "Bottom Navigation"
36
+ },
37
+ "label2id": {
38
+ "Advertisement": 12,
39
+ "Background Image": 6,
40
+ "Bottom Navigation": 24,
41
+ "Button Bar": 17,
42
+ "Card": 7,
43
+ "Checkbox": 11,
44
+ "Date Picker": 21,
45
+ "Drawer": 10,
46
+ "Icon": 2,
47
+ "Image": 1,
48
+ "Input": 5,
49
+ "List Item": 4,
50
+ "Map View": 22,
51
+ "Modal": 13,
52
+ "Multi-Tab": 20,
53
+ "Number Stepper": 19,
54
+ "On/Off Switch": 16,
55
+ "Pager Indicator": 14,
56
+ "Radio Button": 9,
57
+ "Slider": 15,
58
+ "Text": 0,
59
+ "Text Button": 3,
60
+ "Toolbar": 18,
61
+ "Video": 23,
62
+ "Web View": 8
63
+ },
64
+ "max_bbox": 25,
65
+ "model_type": "layoutdm_fidnet_v3",
66
+ "nhead": 4,
67
+ "num_layers": 4,
68
+ "torch_dtype": "float32",
69
+ "transformers_version": "4.36.2"
70
+ }
configuration_fidnet_v3.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers.configuration_utils import PretrainedConfig
2
+
3
+
4
class LayoutDmFIDNetV3Config(PretrainedConfig):
    """HF configuration for the LayoutDM FIDNet-V3 model.

    Args:
        d_model: Hidden width of the transformer encoder/decoder.
        nhead: Number of attention heads per transformer layer.
        num_layers: Number of transformer layers in encoder and decoder.
        max_bbox: Maximum number of layout elements (bounding boxes) per sample.
    """

    model_type = "layoutdm_fidnet_v3"

    def __init__(self, d_model: int = 256, nhead: int = 4, num_layers: int = 4, max_bbox: int = 50, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.nhead = nhead
        self.num_layers = num_layers
        self.max_bbox = max_bbox
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:10773c114017c5a7f038ecde3615476d9eb72bbf39bff6c1666193b519af4667
3
+ size 11714417
modeling_fidnet_v3.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers.modeling_utils import PreTrainedModel
6
+
7
+ from .configuration_fidnet_v3 import LayoutDmFIDNetV3Config
8
+
9
+
10
@dataclass
class LayoutDmFIDNetV3Output(object):
    """Outputs of ``LayoutDmFIDNetV3.forward``.

    The field was originally named ``logit_dict``, but ``forward`` constructs
    this object with ``logit_disc=...``, so every forward pass raised
    ``TypeError: unexpected keyword argument 'logit_disc'``. Renamed to match
    the constructor call (no caller could have read ``.logit_dict`` since
    construction always failed).
    """

    # Discriminator (real/fake) logit per layout; produced by
    # fc_out_disc(...).squeeze(-1) in forward.
    logit_disc: torch.Tensor
    # Per-element class logits: [B, N, num_labels] (see forward's comment).
    logit_cls: torch.Tensor
    # Per-element sigmoid bbox predictions: [B, N, 4] (see forward's comment).
    bbox_pred: torch.Tensor
15
+
16
+
17
class TransformerWithToken(nn.Module):
    """Transformer encoder with a learnable token prepended to the input.

    The extra token position is never masked, so its output can be read off
    as a pooled representation of the whole (variable-length) sequence.
    """

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int, num_layers: int):
        super().__init__()

        # Learnable "summary" token, broadcast across the batch in forward().
        self.token = nn.Parameter(torch.randn(1, 1, d_model))
        # Mask entry for the token position: always False (= always valid).
        self.register_buffer("token_mask", torch.zeros(1, 1, dtype=torch.bool))

        self.core = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
            ),
            num_layers=num_layers,
        )

    def forward(self, x, src_key_padding_mask):
        """Encode ``x`` with the summary token prepended.

        Args:
            x: sequence-first input of shape [N, B, E].
            src_key_padding_mask: [B, N] boolean mask — ``False`` for valid
                positions, ``True`` for padded ones.

        Returns:
            Encoded sequence of shape [N + 1, B, E]; index 0 is the token.
        """
        batch_size = x.size(1)

        seq = torch.cat([self.token.expand(-1, batch_size, -1), x], dim=0)
        mask = torch.cat(
            [self.token_mask.expand(batch_size, -1), src_key_padding_mask],
            dim=1,
        )

        return self.core(seq, src_key_padding_mask=mask)
51
+
52
+
53
class LayoutDmFIDNetV3(PreTrainedModel):
    """FIDNet v3 model over LayoutDM-style layouts.

    Encodes a set of (bbox, label) elements into a single pooled feature via
    a token-prepended transformer, then produces a real/fake discriminator
    logit plus per-slot class logits and bbox reconstructions.
    """

    config_class = LayoutDmFIDNetV3Config

    def __init__(self, config: LayoutDmFIDNetV3Config):
        super().__init__(config)
        self.config = config

        # ----- encoder -----
        self.emb_label = nn.Embedding(config.num_labels, config.d_model)
        self.fc_bbox = nn.Linear(4, config.d_model)
        self.enc_fc_in = nn.Linear(config.d_model * 2, config.d_model)

        self.enc_transformer = TransformerWithToken(
            d_model=config.d_model,
            dim_feedforward=config.d_model // 2,
            nhead=config.nhead,
            num_layers=config.num_layers,
        )

        # Real/fake head applied to the pooled token feature.
        self.fc_out_disc = nn.Linear(config.d_model, 1)

        # ----- decoder -----
        # One learned positional token per possible bbox slot.
        self.pos_token = nn.Parameter(torch.rand(config.max_bbox, 1, config.d_model))
        self.dec_fc_in = nn.Linear(config.d_model * 2, config.d_model)

        decoder_layer = nn.TransformerEncoderLayer(
            d_model=config.d_model,
            nhead=config.nhead,
            dim_feedforward=config.d_model // 2,
        )
        self.dec_transformer = nn.TransformerEncoder(
            decoder_layer, num_layers=config.num_layers
        )

        self.fc_out_cls = nn.Linear(config.d_model, config.num_labels)
        self.fc_out_bbox = nn.Linear(config.d_model, 4)

    def extract_features(self, bbox, label, padding_mask):
        """Pool one layout into a single feature vector per batch element."""
        bbox_emb = self.fc_bbox(bbox)
        label_emb = self.emb_label(label)
        h = self.enc_fc_in(torch.cat([bbox_emb, label_emb], dim=-1))
        # Switch to sequence-first layout expected by the encoder.
        h = torch.relu(h).permute(1, 0, 2)
        h = self.enc_transformer(h, padding_mask)
        # Position 0 holds the prepended summary token.
        return h[0]

    def forward(self, bbox, label, padding_mask):
        B, N, _ = bbox.size()
        pooled = self.extract_features(bbox, label, padding_mask)

        logit_disc = self.fc_out_disc(pooled).squeeze(-1)

        # Broadcast the pooled feature to every slot and pair it with that
        # slot's learned positional token.
        h = pooled.unsqueeze(0).expand(N, -1, -1)
        pos = self.pos_token[:N].expand(-1, B, -1)
        h = torch.relu(self.dec_fc_in(torch.cat([h, pos], dim=-1)))

        h = self.dec_transformer(h, src_key_padding_mask=padding_mask)
        # Back to batch-first for the output heads.
        h = h.permute(1, 0, 2)

        # logit_cls: [B, N, num_labels]; bbox_pred: [B, N, 4] in [0, 1].
        logit_cls = self.fc_out_cls(h)
        bbox_pred = torch.sigmoid(self.fc_out_bbox(h))

        return LayoutDmFIDNetV3Output(
            logit_disc=logit_disc, logit_cls=logit_cls, bbox_pred=bbox_pred
        )