hvaldez committed on
Commit a17aefb
1 Parent(s): d9fc6d8

first commit

Files changed (44)
  1. app.py +86 -0
  2. ckpt/svitt-ego.pth +3 -0
  3. configs/base.yml +21 -0
  4. configs/config_bert.json +21 -0
  5. configs/ego_mcq/multiple-choice-question.yaml +37 -0
  6. configs/ego_mcq/svitt.yml +9 -0
  7. data/svitt-ego-demo/0/meta.json +1 -0
  8. data/svitt-ego-demo/0/tensors.pt +3 -0
  9. data/svitt-ego-demo/0/video/014b473f-aec0-49c7-b394-abc7309ca3c7-converted.mp4 +0 -0
  10. data/svitt-ego-demo/0/video/0a3097fc-baed-4d11-a4c9-30f07eb91af6-converted.mp4 +0 -0
  11. data/svitt-ego-demo/0/video/1a870d5d-5787-4098-ad8d-fe7343c43698-converted.mp4 +0 -0
  12. data/svitt-ego-demo/0/video/2d560d56-dc47-4c76-8d41-889c8aa55d66-converted.mp4 +0 -0
  13. data/svitt-ego-demo/0/video/eb5cb2b0-59e6-45da-af1b-ba86c7ab0b54-converted.mp4 +0 -0
  14. data/svitt-ego-demo/1/meta.json +1 -0
  15. data/svitt-ego-demo/1/tensors.pt +3 -0
  16. data/svitt-ego-demo/1/video/029eeb9a-8853-48a4-a1dc-e8868b58adf3-converted.mp4 +0 -0
  17. data/svitt-ego-demo/1/video/060e07d8-e818-4f9c-9d6b-6504f5fd42a3-converted.mp4 +0 -0
  18. data/svitt-ego-demo/1/video/53da674a-089d-428a-a719-e322b2de002b-converted.mp4 +0 -0
  19. data/svitt-ego-demo/1/video/968139e2-987e-4615-a2d4-fa2e683bae8a-converted.mp4 +0 -0
  20. data/svitt-ego-demo/1/video/fb9fda68-f264-465d-9208-19876f5ef90f-converted.mp4 +0 -0
  21. data/svitt-ego-demo/2/meta.json +1 -0
  22. data/svitt-ego-demo/2/tensors.pt +3 -0
  23. data/svitt-ego-demo/2/video/5f6f87ea-e1c3-4868-bb60-22c9e874d056-converted.mp4 +0 -0
  24. data/svitt-ego-demo/2/video/77718528-2de9-48b4-b6b8-e7c602032afb-converted.mp4 +0 -0
  25. data/svitt-ego-demo/2/video/8d83478f-c5d2-4ac3-a823-e1b2ac7594d7-converted.mp4 +0 -0
  26. data/svitt-ego-demo/2/video/9abbf7f4-68f0-4f52-812f-df2a3df48f7b-converted.mp4 +0 -0
  27. data/svitt-ego-demo/2/video/fa2f1291-3796-41a6-8f7b-6e7c1491b9b2-converted.mp4 +0 -0
  28. data/svitt-ego-demo/3/meta.json +1 -0
  29. data/svitt-ego-demo/3/tensors.pt +3 -0
  30. data/svitt-ego-demo/3/video/2a6b3d10-8da9-4f0e-a681-59ba48a55dbf-converted.mp4 +0 -0
  31. data/svitt-ego-demo/3/video/5afd7421-fb6b-4c65-a09a-716f79a7a935-converted.mp4 +0 -0
  32. data/svitt-ego-demo/3/video/81fff27c-97c0-483a-ad42-47fa947977a9-converted.mp4 +0 -0
  33. data/svitt-ego-demo/3/video/84d6855a-242b-44a6-b48d-2db302b5ea7a-converted.mp4 +0 -0
  34. data/svitt-ego-demo/3/video/f7aec252-bd4f-4696-8de5-ef7b871e2194-converted.mp4 +0 -0
  35. demo.py +165 -0
  36. requirements.txt +14 -0
  37. svitt/base_dataset.py +56 -0
  38. svitt/config.py +36 -0
  39. svitt/model.py +340 -0
  40. svitt/sparse_config.py +351 -0
  41. svitt/sparse_xbeit.py +1585 -0
  42. svitt/sparse_xbert.py +2039 -0
  43. svitt/tokenization_bert.py +546 -0
  44. svitt/utils.py +235 -0
app.py ADDED
@@ -0,0 +1,86 @@
+ import gradio as gr
+
+ from demo import VideoCLSModel
+
+
+ sample_videos = [
+     [
+         "data/svitt-ego-demo/0/video/2d560d56-dc47-4c76-8d41-889c8aa55d66-converted.mp4",
+         "data/svitt-ego-demo/0/video/eb5cb2b0-59e6-45da-af1b-ba86c7ab0b54-converted.mp4",
+         "data/svitt-ego-demo/0/video/0a3097fc-baed-4d11-a4c9-30f07eb91af6-converted.mp4",
+         "data/svitt-ego-demo/0/video/1a870d5d-5787-4098-ad8d-fe7343c43698-converted.mp4",
+         "data/svitt-ego-demo/0/video/014b473f-aec0-49c7-b394-abc7309ca3c7-converted.mp4",
+     ],
+     [
+         "data/svitt-ego-demo/1/video/029eeb9a-8853-48a4-a1dc-e8868b58adf3-converted.mp4",
+         "data/svitt-ego-demo/1/video/968139e2-987e-4615-a2d4-fa2e683bae8a-converted.mp4",
+         "data/svitt-ego-demo/1/video/fb9fda68-f264-465d-9208-19876f5ef90f-converted.mp4",
+         "data/svitt-ego-demo/1/video/53da674a-089d-428a-a719-e322b2de002b-converted.mp4",
+         "data/svitt-ego-demo/1/video/060e07d8-e818-4f9c-9d6b-6504f5fd42a3-converted.mp4",
+     ],
+     [
+         "data/svitt-ego-demo/2/video/fa2f1291-3796-41a6-8f7b-6e7c1491b9b2-converted.mp4",
+         "data/svitt-ego-demo/2/video/8d83478f-c5d2-4ac3-a823-e1b2ac7594d7-converted.mp4",
+         "data/svitt-ego-demo/2/video/5f6f87ea-e1c3-4868-bb60-22c9e874d056-converted.mp4",
+         "data/svitt-ego-demo/2/video/77718528-2de9-48b4-b6b8-e7c602032afb-converted.mp4",
+         "data/svitt-ego-demo/2/video/9abbf7f4-68f0-4f52-812f-df2a3df48f7b-converted.mp4",
+     ],
+     [
+         "data/svitt-ego-demo/3/video/2a6b3d10-8da9-4f0e-a681-59ba48a55dbf-converted.mp4",
+         "data/svitt-ego-demo/3/video/5afd7421-fb6b-4c65-a09a-716f79a7a935-converted.mp4",
+         "data/svitt-ego-demo/3/video/f7aec252-bd4f-4696-8de5-ef7b871e2194-converted.mp4",
+         "data/svitt-ego-demo/3/video/84d6855a-242b-44a6-b48d-2db302b5ea7a-converted.mp4",
+         "data/svitt-ego-demo/3/video/81fff27c-97c0-483a-ad42-47fa947977a9-converted.mp4",
+     ],
+ ]
+ sample_text = [
+     "drops the palm fronds on the ground",
+     "stands up",
+     "throws nuts in a bowl",
+     "puts the speaker and notepad in both hands on a seat",
+ ]
+ sample_text_dict = {
+     "drops the palm fronds on the ground": 0,
+     "stands up": 1,
+     "throws nuts in a bowl": 2,
+     "puts the speaker and notepad in both hands on a seat": 3,
+ }
+ num_samples = len(sample_videos[0])
+ labels = [f"video-{i}" for i in range(num_samples)]
+
+ def main():
+     svitt = VideoCLSModel(
+         "configs/ego_mcq/svitt.yml",
+         sample_videos,
+     )
+     def predict(text):
+         idx = sample_text_dict[text]
+         ft_action, gt_action = svitt.predict(idx, text)
+         return labels[gt_action], labels[ft_action]
+
+     with gr.Blocks() as demo:
+         gr.Markdown(
+             """
+             # SViTT-Ego for Multiple Choice Question
+             Choose a sample query and click predict to view the results.
+             """
+         )
+         with gr.Row():
+             with gr.Column():
+                 videos = [gr.Video(label=labels[i], format='mp4', height=256, min_width=340) for i in range(num_samples)]
+             with gr.Column():
+                 text = gr.Text(label="Query", visible=False)
+                 label = gr.Text(label="Ground Truth")
+                 ours = gr.Text(label="SViTT-Ego prediction")
+                 btn = gr.Button("Predict", variant="primary")
+                 btn.click(predict, inputs=[text], outputs=[label, ours])
+         inputs = [text]
+         inputs.extend(videos)
+         gr.Examples(examples=[[sample_text[i], x[0], x[1], x[2], x[3], x[4]] for i, x in enumerate(sample_videos)], inputs=inputs)
+
+     demo.launch(share=True)
+
+
+ if __name__ == "__main__":
+     main()
+
ckpt/svitt-ego.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5d66ad807fb618b1e99da476d54238555eb51925afa65e444ba43dc5c235db1e
+ size 2500535422
configs/base.yml ADDED
@@ -0,0 +1,21 @@
+ model:
+   pretrain: ""
+   resume: ""
+   timesformer_freeze_space: false
+   drop_path_rate: 0.1
+   dropout_ratio: 0.5
+   freeze_vis_backbone: false
+   freeze_txt_backbone: false
+   use_vn_classifier: false
+
+ data:
+   dataset: ek100_mir
+   root: datasets/EK100/video_ht256px
+   metadata: datasets/EK100/epic-kitchens-100-annotations/retrieval_annotations/EPIC_100_retrieval_train.csv
+   metadata_val: datasets/EK100/epic-kitchens-100-annotations/retrieval_annotations/EPIC_100_retrieval_test.csv
+   relevancy_path: datasets/EK100/epic-kitchens-100-annotations/retrieval_annotations/relevancy/caption_relevancy_EPIC_100_retrieval_test.pkl
+   clip_length: 16
+   clip_stride: 4
+   sparse_sample: false
+   num_crops: 1
+   num_clips: 1
configs/config_bert.json ADDED
@@ -0,0 +1,21 @@
+ {
+   "architectures": [
+     "BertForMaskedLM"
+   ],
+   "attention_probs_dropout_prob": 0.1,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 768,
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "layer_norm_eps": 1e-12,
+   "max_position_embeddings": 512,
+   "model_type": "bert",
+   "num_attention_heads": 12,
+   "num_hidden_layers": 12,
+   "pad_token_id": 0,
+   "type_vocab_size": 2,
+   "vocab_size": 30522,
+   "fusion_layer": 9,
+   "encoder_width": 768
+ }
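Editor's note (not part of the commit): this JSON is read by svitt/model.py through the sparse BertConfig in svitt/sparse_config.py. A minimal sketch of loading it, assuming the repository root as the working directory:

```python
# Minimal sketch, assuming the file layout of this commit.
from svitt.sparse_config import BertConfig

bert_config = BertConfig.from_json_file("configs/config_bert.json")
print(bert_config.hidden_size)   # 768
print(bert_config.fusion_layer)  # 9, kept as an extra attribute by PretrainedConfig
```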
configs/ego_mcq/multiple-choice-question.yaml ADDED
@@ -0,0 +1,37 @@
+
+ text_encoder: bert-base-uncased
+ bert_config: configs/config_bert.json
+ vit_type: beit # items in ${vit_zoo}
+ vit_zoo: # from huggingface
+   beit: microsoft/beit-base-patch16-224-pt22k-ft22k
+ vit_name_or_pretrained_path: ${vit_zoo[${vit_type}]}
+
+ vision_encoder_args:
+   token_keep_rate: 0.7
+   token_keep_strategy: cls_attn
+   token_drop_loc: [3, 6, 9]
+   sparse_local_attn: 1
+   sparse_random_attn: 5
+   attn_block_size: 56
+
+ image_res: 224
+ embed_dim: 256
+ video_input:
+   num_frames: 4
+   reader: decord # one of [decord, av]
+   sample_type: rand
+   num_frames_test: 16 # num_frames during inference/test
+   sample_type_test: middle
+ max_txt_l:
+   image: 32
+   video: 32
+
+ batch_size:
+   image: 8
+   video: 8
+ batch_size_test:
+   image: 8
+   video: 8
+ k_test: 128
+ temp: 0.18
+ mlm_prob: 0.5
configs/ego_mcq/svitt.yml ADDED
@@ -0,0 +1,9 @@
+ model:
+   pretrain: ckpt/svitt-ego.pth
+   freeze_vis_backbone: true
+   freeze_txt_backbone: true
+   num_frames: 4
+   config: configs/ego_mcq/multiple-choice-question.yaml
+
+ data:
+   root: data/svitt-ego-demo
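Editor's sketch (not part of the commit): this experiment file is merged on top of configs/base.yml by svitt.config.load_cfg, so the demo overrides only the checkpoint and data root while inheriting everything else:

```python
# Minimal sketch, assuming the layout of this commit.
from svitt.config import load_cfg

cfg = load_cfg("configs/ego_mcq/svitt.yml")
print(cfg["model"]["pretrain"])    # ckpt/svitt-ego.pth  (from svitt.yml)
print(cfg["data"]["root"])         # data/svitt-ego-demo (from svitt.yml)
print(cfg["data"]["clip_length"])  # 16, inherited from configs/base.yml
```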
data/svitt-ego-demo/0/meta.json ADDED
@@ -0,0 +1 @@
+ {"text": "#C C drops the palm fronds on the ground", "text_ops": ["#C C picks a bowl with fruit from person O", "#C C turns to the woman X on his right", "#C C picks the pocket knife ", "#C C removes the paint from the wall.", "#C C drops the palm fronds on the ground"], "correct": 4, "type": 1, "meta": {"raw_captions": "#C C drops the palm fronds on the ground", "paths": [["/datasets/ego4d/egovlp/full_scale_256_chunked/2d560d56-dc47-4c76-8d41-889c8aa55d66/4.mp4", "/datasets/ego4d/egovlp/full_scale_256_chunked/2d560d56-dc47-4c76-8d41-889c8aa55d66/4.mp4"], ["/datasets/ego4d/egovlp/full_scale_256_chunked/eb5cb2b0-59e6-45da-af1b-ba86c7ab0b54/0.mp4", "/datasets/ego4d/egovlp/full_scale_256_chunked/eb5cb2b0-59e6-45da-af1b-ba86c7ab0b54/0.mp4"], ["/datasets/ego4d/egovlp/full_scale_256_chunked/0a3097fc-baed-4d11-a4c9-30f07eb91af6/0.mp4", "/datasets/ego4d/egovlp/full_scale_256_chunked/0a3097fc-baed-4d11-a4c9-30f07eb91af6/0.mp4"], ["/datasets/ego4d/egovlp/full_scale_256_chunked/1a870d5d-5787-4098-ad8d-fe7343c43698/0.mp4", "/datasets/ego4d/egovlp/full_scale_256_chunked/1a870d5d-5787-4098-ad8d-fe7343c43698/0.mp4"], ["/datasets/ego4d/egovlp/full_scale_256_chunked/014b473f-aec0-49c7-b394-abc7309ca3c7/0.mp4", "/datasets/ego4d/egovlp/full_scale_256_chunked/014b473f-aec0-49c7-b394-abc7309ca3c7/0.mp4"]]}}
data/svitt-ego-demo/0/tensors.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3a572098fa6fb25b4a465dc138a3341f6b22cfc525f7e772b53bc17bd171e56d
+ size 12042987
data/svitt-ego-demo/0/video/014b473f-aec0-49c7-b394-abc7309ca3c7-converted.mp4 ADDED
Binary file (300 kB).
data/svitt-ego-demo/0/video/0a3097fc-baed-4d11-a4c9-30f07eb91af6-converted.mp4 ADDED
Binary file (42.5 kB).
data/svitt-ego-demo/0/video/1a870d5d-5787-4098-ad8d-fe7343c43698-converted.mp4 ADDED
Binary file (39.1 kB).
data/svitt-ego-demo/0/video/2d560d56-dc47-4c76-8d41-889c8aa55d66-converted.mp4 ADDED
Binary file (164 kB).
data/svitt-ego-demo/0/video/eb5cb2b0-59e6-45da-af1b-ba86c7ab0b54-converted.mp4 ADDED
Binary file (96.8 kB).
data/svitt-ego-demo/1/meta.json ADDED
@@ -0,0 +1 @@
+ {"text": "#C C stands up", "text_ops": ["#C C walks a round", "#C C stands up", "#O person Y pushes the door ", "#C C picks pastry cloth", "#C C holds the wire"], "correct": 1, "type": 1, "meta": {"raw_captions": "#C C holds the wire", "paths": [["/datasets/ego4d/egovlp/full_scale_256_chunked/029eeb9a-8853-48a4-a1dc-e8868b58adf3/0.mp4", "/datasets/ego4d/egovlp/full_scale_256_chunked/029eeb9a-8853-48a4-a1dc-e8868b58adf3/0.mp4"], ["/datasets/ego4d/egovlp/full_scale_256_chunked/968139e2-987e-4615-a2d4-fa2e683bae8a/4.mp4", "/datasets/ego4d/egovlp/full_scale_256_chunked/968139e2-987e-4615-a2d4-fa2e683bae8a/4.mp4"], ["/datasets/ego4d/egovlp/full_scale_256_chunked/fb9fda68-f264-465d-9208-19876f5ef90f/0.mp4", "/datasets/ego4d/egovlp/full_scale_256_chunked/fb9fda68-f264-465d-9208-19876f5ef90f/0.mp4"], ["/datasets/ego4d/egovlp/full_scale_256_chunked/53da674a-089d-428a-a719-e322b2de002b/1.mp4", "/datasets/ego4d/egovlp/full_scale_256_chunked/53da674a-089d-428a-a719-e322b2de002b/1.mp4"], ["/datasets/ego4d/egovlp/full_scale_256_chunked/060e07d8-e818-4f9c-9d6b-6504f5fd42a3/0.mp4", "/datasets/ego4d/egovlp/full_scale_256_chunked/060e07d8-e818-4f9c-9d6b-6504f5fd42a3/0.mp4"]]}}
data/svitt-ego-demo/1/tensors.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:053afd65ae1250bfc609d34c561c51fdde4812251b233f2f246f68b66da84533
+ size 12042987
data/svitt-ego-demo/1/video/029eeb9a-8853-48a4-a1dc-e8868b58adf3-converted.mp4 ADDED
Binary file (60.1 kB).
data/svitt-ego-demo/1/video/060e07d8-e818-4f9c-9d6b-6504f5fd42a3-converted.mp4 ADDED
Binary file (42.8 kB).
data/svitt-ego-demo/1/video/53da674a-089d-428a-a719-e322b2de002b-converted.mp4 ADDED
Binary file (29.3 kB).
data/svitt-ego-demo/1/video/968139e2-987e-4615-a2d4-fa2e683bae8a-converted.mp4 ADDED
Binary file (39.6 kB).
data/svitt-ego-demo/1/video/fb9fda68-f264-465d-9208-19876f5ef90f-converted.mp4 ADDED
Binary file (161 kB).
data/svitt-ego-demo/2/meta.json ADDED
@@ -0,0 +1 @@
+ {"text": "#C C throws nuts in a bowl.", "text_ops": ["#C C throws nuts in a bowl.", "#O The woman Z places her right hand on the table.", "#O The woman T touches a card on the table with her left hand.", "#C C joins the pieces of dough together on the tray.", "#O A woman X walks forward"], "correct": 0, "type": 1, "meta": {"raw_captions": "#O A woman X walks forward", "paths": [["/datasets/ego4d/egovlp/full_scale_256_chunked/fa2f1291-3796-41a6-8f7b-6e7c1491b9b2/0.mp4", "/datasets/ego4d/egovlp/full_scale_256_chunked/fa2f1291-3796-41a6-8f7b-6e7c1491b9b2/0.mp4"], ["/datasets/ego4d/egovlp/full_scale_256_chunked/8d83478f-c5d2-4ac3-a823-e1b2ac7594d7/1.mp4", "/datasets/ego4d/egovlp/full_scale_256_chunked/8d83478f-c5d2-4ac3-a823-e1b2ac7594d7/1.mp4"], ["/datasets/ego4d/egovlp/full_scale_256_chunked/5f6f87ea-e1c3-4868-bb60-22c9e874d056/0.mp4", "/datasets/ego4d/egovlp/full_scale_256_chunked/5f6f87ea-e1c3-4868-bb60-22c9e874d056/0.mp4"], ["/datasets/ego4d/egovlp/full_scale_256_chunked/77718528-2de9-48b4-b6b8-e7c602032afb/4.mp4", "/datasets/ego4d/egovlp/full_scale_256_chunked/77718528-2de9-48b4-b6b8-e7c602032afb/4.mp4"], ["/datasets/ego4d/egovlp/full_scale_256_chunked/9abbf7f4-68f0-4f52-812f-df2a3df48f7b/1.mp4", "/datasets/ego4d/egovlp/full_scale_256_chunked/9abbf7f4-68f0-4f52-812f-df2a3df48f7b/1.mp4"]]}}
data/svitt-ego-demo/2/tensors.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:381ef1a03e5685229295a3a6a327ef39b0489b0da5fe9ba3754c6101f6c39ec6
+ size 12042987
data/svitt-ego-demo/2/video/5f6f87ea-e1c3-4868-bb60-22c9e874d056-converted.mp4 ADDED
Binary file (20.6 kB).
data/svitt-ego-demo/2/video/77718528-2de9-48b4-b6b8-e7c602032afb-converted.mp4 ADDED
Binary file (27.2 kB).
data/svitt-ego-demo/2/video/8d83478f-c5d2-4ac3-a823-e1b2ac7594d7-converted.mp4 ADDED
Binary file (47.6 kB).
data/svitt-ego-demo/2/video/9abbf7f4-68f0-4f52-812f-df2a3df48f7b-converted.mp4 ADDED
Binary file (52.7 kB).
data/svitt-ego-demo/2/video/fa2f1291-3796-41a6-8f7b-6e7c1491b9b2-converted.mp4 ADDED
Binary file (26.1 kB).
data/svitt-ego-demo/3/meta.json ADDED
@@ -0,0 +1 @@
+ {"text": "#C C puts the speaker and notepad in both hands on a seat.", "text_ops": ["#C C picks a chaff from the pan of ingredients", "#C C switches his left hand grip on the broom", "#C C cuts the dough on the tray with the scraper in his right hand.", "#C C pulls the wire mesh.", "#C C puts the speaker and notepad in both hands on a seat."], "correct": 4, "type": 1, "meta": {"raw_captions": "#C C puts the speaker and notepad in both hands on a seat.", "paths": [["/datasets/ego4d/egovlp/full_scale_256_chunked/2a6b3d10-8da9-4f0e-a681-59ba48a55dbf/2.mp4", "/datasets/ego4d/egovlp/full_scale_256_chunked/2a6b3d10-8da9-4f0e-a681-59ba48a55dbf/2.mp4"], ["/datasets/ego4d/egovlp/full_scale_256_chunked/5afd7421-fb6b-4c65-a09a-716f79a7a935/1.mp4", "/datasets/ego4d/egovlp/full_scale_256_chunked/5afd7421-fb6b-4c65-a09a-716f79a7a935/1.mp4"], ["/datasets/ego4d/egovlp/full_scale_256_chunked/f7aec252-bd4f-4696-8de5-ef7b871e2194/1.mp4", "/datasets/ego4d/egovlp/full_scale_256_chunked/f7aec252-bd4f-4696-8de5-ef7b871e2194/1.mp4"], ["/datasets/ego4d/egovlp/full_scale_256_chunked/84d6855a-242b-44a6-b48d-2db302b5ea7a/0.mp4", "/datasets/ego4d/egovlp/full_scale_256_chunked/84d6855a-242b-44a6-b48d-2db302b5ea7a/0.mp4"], ["/datasets/ego4d/egovlp/full_scale_256_chunked/81fff27c-97c0-483a-ad42-47fa947977a9/9.mp4", "/datasets/ego4d/egovlp/full_scale_256_chunked/81fff27c-97c0-483a-ad42-47fa947977a9/9.mp4"]]}}
data/svitt-ego-demo/3/tensors.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3ace4bcc2911dc78fc7ef6802c9c2b392207e42811e7449e628adf39ffa5f84c
+ size 12042987
data/svitt-ego-demo/3/video/2a6b3d10-8da9-4f0e-a681-59ba48a55dbf-converted.mp4 ADDED
Binary file (40.5 kB).
data/svitt-ego-demo/3/video/5afd7421-fb6b-4c65-a09a-716f79a7a935-converted.mp4 ADDED
Binary file (40.2 kB).
data/svitt-ego-demo/3/video/81fff27c-97c0-483a-ad42-47fa947977a9-converted.mp4 ADDED
Binary file (135 kB).
data/svitt-ego-demo/3/video/84d6855a-242b-44a6-b48d-2db302b5ea7a-converted.mp4 ADDED
Binary file (38.7 kB).
data/svitt-ego-demo/3/video/f7aec252-bd4f-4696-8de5-ef7b871e2194-converted.mp4 ADDED
Binary file (16.1 kB).
demo.py ADDED
@@ -0,0 +1,165 @@
+ ### demo.py
+ # Define model classes for inference.
+ ###
+ import json
+ import torch
+ import torch.nn as nn
+ import torch.backends.cudnn as cudnn
+ from einops import rearrange
+ from transformers import BertTokenizer
+ from torchvision import transforms
+ from torchvision.transforms._transforms_video import (
+     NormalizeVideo,
+ )
+
+ from svitt.model import SViTT
+ from svitt.config import load_cfg, setup_config
+ from svitt.base_dataset import read_frames_cv2_egoclip
+
+
+ class VideoModel(nn.Module):
+     """ Base model for video understanding based on SViTT architecture. """
+     def __init__(self, config):
+         """ Initializes the model.
+         Parameters:
+             config: config file
+         """
+         super(VideoModel, self).__init__()
+         self.cfg = load_cfg(config)
+         self.model = self.build_model()
+         self.templates = ['{}']
+         self.dataset = self.cfg['data']['dataset']
+         self.eval()
+
+     def build_model(self):
+         cfg = self.cfg
+         if cfg['model'].get('pretrain', False):
+             ckpt_path = cfg['model']['pretrain']
+         else:
+             raise Exception('no checkpoint found')
+
+         if cfg['model'].get('config', False):
+             config_path = cfg['model']['config']
+         else:
+             raise Exception('no model config found')
+
+         self.model_cfg = setup_config(config_path)
+         self.tokenizer = BertTokenizer.from_pretrained(self.model_cfg.text_encoder)
+         model = SViTT(config=self.model_cfg, tokenizer=self.tokenizer)
+
+         print(f"Loading checkpoint from {ckpt_path}")
+         checkpoint = torch.load(ckpt_path, map_location="cpu")
+         state_dict = checkpoint["model"]
+
+         # fix for zero-shot evaluation
+         for key in list(state_dict.keys()):
+             if "bert" in key:
+                 encoder_key = key.replace("bert.", "")
+                 state_dict[encoder_key] = state_dict[key]
+
+         if torch.cuda.is_available():
+             model.cuda()
+
+         model.load_state_dict(state_dict, strict=False)
+
+         return model
+
+     def eval(self):
+         cudnn.benchmark = True
+         for p in self.model.parameters():
+             p.requires_grad = False
+         self.model.eval()
+
+
+ class VideoCLSModel(VideoModel):
+     """ Video model for video classification tasks (Charades-Ego, EGTEA). """
+     def __init__(self, config, sample_videos):
+         super(VideoCLSModel, self).__init__(config)
+         self.sample_videos = sample_videos
+         self.video_transform = self.init_video_transform()
+
+     #def load_data(self, idx=None):
+     #    filename = f"{self.cfg['data']['root']}/{idx}/tensors.pt"
+     #    return torch.load(filename)
+     def init_video_transform(self,
+                              input_res=224,
+                              center_crop=256,
+                              norm_mean=(0.485, 0.456, 0.406),
+                              norm_std=(0.229, 0.224, 0.225),
+                              ):
+         print('Video Transform is used!')
+         normalize = NormalizeVideo(mean=norm_mean, std=norm_std)
+         return transforms.Compose(
+             [
+                 transforms.Resize(center_crop),
+                 transforms.CenterCrop(center_crop),
+                 transforms.Resize(input_res),
+                 normalize,
+             ]
+         )
+
+     def load_data(self, idx):
+         num_frames = self.model_cfg.video_input.num_frames
+         video_paths = self.sample_videos[idx]
+         clips = [None] * len(video_paths)
+         for i, path in enumerate(video_paths):
+             imgs = read_frames_cv2_egoclip(path, num_frames, 'uniform')
+             imgs = imgs.transpose(0, 1)
+             imgs = self.video_transform(imgs)
+             imgs = imgs.transpose(0, 1)
+             clips[i] = imgs
+         return torch.stack(clips)
+
+     def load_meta(self, idx=None):
+         filename = f"{self.cfg['data']['root']}/{idx}/meta.json"
+         with open(filename, "r") as f:
+             meta = json.load(f)
+         return meta
+
+     @torch.no_grad()
+     def get_text_features(self, text):
+         print('=> Extracting text features')
+         embeddings = self.tokenizer(
+             text,
+             padding="max_length",
+             truncation=True,
+             max_length=self.model_cfg.max_txt_l.video,
+             return_tensors="pt",
+         )
+         _, class_embeddings = self.model.encode_text(embeddings)
+         return class_embeddings
+
+     @torch.no_grad()
+     def forward(self, idx, text=None):
+         print('=> Start forwarding')
+         meta = self.load_meta(idx)
+         clips = self.load_data(idx)
+         if text is None:
+             text = meta["text"][4:]
+         text_features = self.get_text_features(text)
+         target = meta["correct"]
+
+         # encode images
+         pooled_image_feat_all = []
+         for i in range(clips.shape[0]):
+
+             images = clips[i,:].unsqueeze(0)
+             bsz = images.shape[0]
+
+             _, pooled_image_feat, *outputs = self.model.encode_image(images)
+             if pooled_image_feat.ndim == 3:
+                 pooled_image_feat = rearrange(pooled_image_feat, '(b k) n d -> b (k n) d', b=bsz)
+             else:
+                 pooled_image_feat = rearrange(pooled_image_feat, '(b k) d -> b k d', b=bsz)
+
+             pooled_image_feat_all.append(pooled_image_feat)
+
+         pooled_image_feat_all = torch.cat(pooled_image_feat_all, dim=0)
+         similarity = self.model.get_sim(pooled_image_feat_all, text_features)[0]
+         return similarity.argmax(), target
+
+     @torch.no_grad()
+     def predict(self, idx, text=None):
+         output, target = self.forward(idx, text)
+         return output.numpy(), target
+
requirements.txt ADDED
@@ -0,0 +1,14 @@
+ gradio
+ torch
+ torchvision
+ scikit-learn
+ eva-decord
+ timm
+ einops
+ ftfy
+ regex
+ transformers
+ omegaconf
+ zCurve
+ numpy-hilbert-curve
+ opencv-python-headless
svitt/base_dataset.py ADDED
@@ -0,0 +1,56 @@
+ import random
+
+ import cv2
+ import torch
+ import numpy as np
+
+
+ def sample_frames_start_end(num_frames, start, end, sample='rand', fix_start=None):
+     acc_samples = min(num_frames, end)
+     intervals = np.linspace(start=start, stop=end, num=acc_samples + 1).astype(int)
+     ranges = [(interv, intervals[idx + 1] - 1) for idx, interv in enumerate(intervals[:-1])]
+     if sample == 'rand':
+         frame_idxs = [random.choice(range(x[0], x[1])) for x in ranges]
+     elif fix_start is not None:
+         frame_idxs = [x[0] + fix_start for x in ranges]
+     elif sample == 'uniform':
+         frame_idxs = [(x[0] + x[1]) // 2 for x in ranges]
+     else:
+         raise NotImplementedError
+     return frame_idxs
+
+ def read_frames_cv2_egoclip(
+     video_path,
+     num_frames,
+     sample,
+ ):
+
+     cap = cv2.VideoCapture(video_path)
+     vlen = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+     assert (cap.isOpened())
+
+     # get indexes of sampled frames
+     start_f = 0
+     end_f = vlen
+     frame_idxs = sample_frames_start_end(num_frames, start_f, end_f, sample=sample)
+
+     frames = []
+     for index in frame_idxs:
+         _index = index % (600 * 30)
+         _index = min(_index, vlen)
+         cap.set(cv2.CAP_PROP_POS_FRAMES, _index - 1)
+         ret, frame = cap.read()
+
+         if ret:
+             frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+             frame = torch.from_numpy(frame)
+             # (H x W x C) to (C x H x W)
+             frame = frame.permute(2, 0, 1)
+             frames.append(frame)
+
+     while len(frames) < num_frames: # complete the frame
+         frames.append(frames[-1])
+
+     frames = torch.stack(frames).float() / 255
+     cap.release()
+     return frames
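Editor's usage sketch (not part of the commit): read_frames_cv2_egoclip returns a float tensor of uniformly sampled RGB frames scaled to [0, 1], which demo.py then transposes and normalizes.

```python
# Minimal sketch using one of the bundled demo clips from this commit.
from svitt.base_dataset import read_frames_cv2_egoclip

path = "data/svitt-ego-demo/0/video/2d560d56-dc47-4c76-8d41-889c8aa55d66-converted.mp4"
frames = read_frames_cv2_egoclip(path, num_frames=4, sample="uniform")
print(frames.shape)  # torch.Size([4, 3, H, W]), values in [0, 1]
```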
svitt/config.py ADDED
@@ -0,0 +1,36 @@
+
+ import yaml
+ from omegaconf import OmegaConf, DictConfig
+
+ def load_base_cfg():
+     with open('configs/base.yml', 'r') as fp:
+         cfg = yaml.load(fp, Loader=yaml.SafeLoader)
+     return cfg
+
+ def load_cfg(cfg_file):
+     cfg = load_base_cfg()
+     with open(cfg_file, 'r') as fp:
+         exp_cfg = yaml.load(fp, Loader=yaml.SafeLoader)
+
+     cfg['model'].update(exp_cfg.get('model', {}))
+     cfg['data'].update(exp_cfg.get('data', {}))
+     return cfg
+
+ def convert_types(config):
+     """Convert `'None'` (str) --> `None` (None). Only supports top-level"""
+     for k, v in config.items():
+         if isinstance(v, DictConfig):
+             setattr(config, k, convert_types(v))
+
+         # TODO convert types in ListConfig, right now they are ignored
+         # if isinstance(v, ListConfig):
+         #     new_v = ListConfig()
+
+         if v in ["None", "none"]:
+             setattr(config, k, None)
+     return config
+
+ def setup_config(config_path):
+     yaml_config = OmegaConf.load(config_path)
+     config = convert_types(yaml_config)
+     return config
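Editor's usage sketch (not part of the commit): setup_config loads the model YAML with OmegaConf and maps the strings "None"/"none" to Python None, so nested fields can be read with attribute access.

```python
# Minimal sketch, assuming the config files of this commit are on disk.
from svitt.config import setup_config

model_cfg = setup_config("configs/ego_mcq/multiple-choice-question.yaml")
print(model_cfg.vit_type)                             # beit
print(model_cfg.video_input.num_frames)               # 4
print(model_cfg.vision_encoder_args.token_keep_rate)  # 0.7
```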
svitt/model.py ADDED
@@ -0,0 +1,340 @@
1
+ from svitt.utils import (
2
+ interpolate_pos_embed,
3
+ interpolate_pos_relative_bias_beit_3d,
4
+ )
5
+ from omegaconf import OmegaConf
6
+ from transformers import ViTModel, ViTConfig
7
+ from svitt.sparse_config import BertConfig, BeitConfig
8
+ from svitt.sparse_xbeit import BeitModel
9
+ from svitt.sparse_xbert import BertModel, BertForMaskedLM
10
+
11
+ import torch
12
+ from torch import nn
13
+ import torch.nn.functional as F
14
+
15
+
16
+ class SViTT(nn.Module):
17
+ """Common utils shared by pretraining and downstream retrieval"""
18
+ def __init__(self, config=None, tokenizer=None, pretrain=True, **kwargs):
19
+ super().__init__()
20
+ self.config = config
21
+ self.tokenizer = tokenizer
22
+ self.embed_dim = config.embed_dim
23
+ self.vision_width = 768
24
+ self.text_width = 768
25
+ self.pretrain = pretrain
26
+
27
+ self.vision_encoder, self.vision_layernorm = self.build_vision_encoder()
28
+ self.text_encoder = self.build_text_encoder()
29
+
30
+ self.vision_proj = nn.Linear(self.vision_width, self.embed_dim)
31
+ self.text_proj = nn.Linear(self.text_width, self.embed_dim)
32
+
33
+ self.temp = nn.Parameter(torch.ones([]) * config.temp)
34
+ self.itm_head = nn.Linear(self.text_width, 2)
35
+
36
+
37
+ def build_text_encoder(self):
38
+
39
+ bert_config = BertConfig.from_json_file(self.config.bert_config)
40
+
41
+ # Override params for sparse vision encoder
42
+ model_args = getattr(self.config, 'text_encoder_args', {})
43
+ if model_args:
44
+ model_args = OmegaConf.to_object(model_args)
45
+ bert_config.update(model_args)
46
+
47
+ if self.pretrain:
48
+ text_encoder, _ = BertForMaskedLM.from_pretrained(
49
+ self.config.text_encoder, config=bert_config,
50
+ output_loading_info=True
51
+ )
52
+ else:
53
+ text_encoder, _ = BertModel.from_pretrained(
54
+ self.config.text_encoder, config=bert_config,
55
+ add_pooling_layer=False, output_loading_info=True
56
+ )
57
+ return text_encoder
58
+
59
+ def build_vision_encoder(self):
60
+ # if self.config.vit_type in ["beit", "deit", "vit", "vit32"]:
61
+ if self.config.vit_type in ["beit"]:
62
+ vision_encoder = self.build_huggingface_vit_with_image_size(
63
+ self.config.vit_name_or_pretrained_path, self.config.image_res,)
64
+ else:
65
+ raise ValueError(f"Unknown vit type {self.config.vit_type}")
66
+
67
+ # add layernorm for normalizing BEiT outputs hidden states
68
+ vision_layernorm = None
69
+ if self.config.vit_type == "beit":
70
+ vision_layernorm = nn.LayerNorm(self.vision_width, eps=1e-12)
71
+ return vision_encoder, vision_layernorm
72
+
73
+ # @classmethod
74
+ # def build_huggingface_vit_with_image_size(cls, model_card: str, image_size: int):
75
+ def build_huggingface_vit_with_image_size(self, model_card: str, image_size: int):
76
+ """Build a vit model from huggingface hub, also interpolate pos_embed when needed.
77
+
78
+ Args:
79
+ model_card: name in huggingface hub, e.g., `facebook/deit-base-patch16-224`
80
+ image_size: new image size, may be different from pre-training image_size of `model_card`
81
+
82
+ ref: https://github.com/huggingface/transformers/issues/12167#issuecomment-861356232
83
+ """
84
+ is_beit = "beit" in model_card
85
+ if "beit" in model_card:
86
+ model_cls, config_cls = BeitModel, BeitConfig
87
+ elif "deit" in model_card or "vit" in model_card:
88
+ # the deit model we use is loaded in vit arch,
89
+ # see https://huggingface.co/facebook/deit-base-patch16-224#how-to-use
90
+ model_cls, config_cls = ViTModel, ViTConfig
91
+ else:
92
+ raise ValueError(f"Unexpected model_card: {model_card}")
93
+
94
+ # BEiT uses average pooled tokens instead of [CLS] used by other models
95
+ tmp_model = model_cls.from_pretrained(model_card, add_pooling_layer=is_beit)
96
+ state_dict = tmp_model.state_dict()
97
+ del tmp_model
98
+
99
+ # Override params for sparse vision encoder
100
+ model_args = getattr(self.config, 'vision_encoder_args', {})
101
+ if model_args:
102
+ model_args = OmegaConf.to_object(model_args)
103
+ model_config = config_cls.from_pretrained(
104
+ model_card,
105
+ image_size=image_size,
106
+ **model_args,
107
+ )
108
+ model = model_cls(config=model_config, add_pooling_layer=is_beit, num_frames=self.config.video_input.num_frames)
109
+ if is_beit:
110
+ # interpolate relative pos bias
111
+ state_dict = interpolate_pos_relative_bias_beit_3d(
112
+ state_dict_old=state_dict,
113
+ state_dict_new=model.state_dict(),
114
+ patch_shape_new=model.window_size
115
+ )
116
+ else:
117
+ # interpolate pos_embed and load weights to new model
118
+ state_dict["embeddings.position_embeddings"] = interpolate_pos_embed(
119
+ pos_embed_old=state_dict["embeddings.position_embeddings"],
120
+ pos_embed_new=model.embeddings.position_embeddings,
121
+ num_patches_new=model.embeddings.patch_embeddings.num_patches
122
+ )
123
+ msg = model.load_state_dict(state_dict, strict=False)
124
+ return model
125
+
126
+ def get_text_encoder(self):
127
+ """get text encoder, used for text and cross-modal encoding"""
128
+ encoder = self.text_encoder
129
+ return encoder.bert if hasattr(encoder, "bert") else encoder
130
+
131
+ def encode_image(self, video, output_token_idx=False, output_attentions=False):
132
+ video_embeds = self.vision_encoder(video, output_token_idx=output_token_idx, output_attentions=output_attentions) # (bsz, seq_len, d)
133
+ if self.vision_layernorm is not None: # only for BEiT mean-pooling
134
+ video_embeds.last_hidden_state = self.vision_layernorm(video_embeds.last_hidden_state)
135
+ if output_token_idx:
136
+ token_idx = video_embeds.token_idx
137
+
138
+ if output_attentions:
139
+ attentions = video_embeds.attentions
140
+
141
+ if self.config.vit_type == "beit":
142
+ pooled_video_embeds = video_embeds.pooler_output # (bsz*num_frms, d)
143
+ video_embeds = video_embeds.last_hidden_state # (bsz*num_frms, L, d)
144
+ else:
145
+ video_embeds = video_embeds.last_hidden_state
146
+ pooled_video_embeds = video_embeds[:, 0]
147
+
148
+ outputs = (video_embeds, pooled_video_embeds)
149
+
150
+ if output_token_idx:
151
+ outputs += (token_idx,)
152
+
153
+ if output_attentions:
154
+ outputs += (attentions,)
155
+
156
+ return outputs
157
+
158
+ def _encode_image(self, image):
159
+ bsz, num_frms, c, h, w = image.shape # `num_frms` could be changing for image (=1) or video (e.g., =4)
160
+ image = image.view(bsz*num_frms, c, h, w)
161
+ image_embeds = self.vision_encoder(image)
162
+ if self.vision_layernorm is not None: # only for BEiT mean-pooling
163
+ image_embeds.last_hidden_state = self.vision_layernorm(image_embeds.last_hidden_state)
164
+
165
+ if self.config.vit_type == "beit":
166
+ pooled_image_embeds = image_embeds.pooler_output # (bsz*num_frms, d)
167
+ image_embeds = image_embeds.last_hidden_state # (bsz*num_frms, L, d)
168
+ else:
169
+ image_embeds = image_embeds.last_hidden_state
170
+ pooled_image_embeds = image_embeds[:, 0]
171
+
172
+ image_embeds = image_embeds.view(bsz, num_frms, -1, self.vision_width) # (bsz, num_frms, L, d)
173
+ pooled_image_embeds = pooled_image_embeds.view(bsz, num_frms, self.vision_width) \
174
+ if pooled_image_embeds is not None else None # (bsz, num_frms, d)
175
+ return image_embeds, pooled_image_embeds
176
+
177
+ def encode_text(self, text):
178
+ text_output = self.get_text_encoder()(
179
+ text.input_ids,
180
+ attention_mask=text.attention_mask,
181
+ return_dict=True,
182
+ mode='text'
183
+ )
184
+ text_embeds = text_output.last_hidden_state
185
+ pooled_text_embeds = text_embeds[:, 0]
186
+ return text_embeds, pooled_text_embeds
187
+
188
+ @torch.no_grad()
189
+ def clip_contrastive_temperature(self, min_val=0.001, max_val=0.5):
190
+ """Seems only used during pre-training"""
191
+ self.temp.clamp_(min_val, max_val)
192
+
193
+ @torch.no_grad()
194
+ def get_mask(self, sim, idx=None, normalize=False):
195
+ """
196
+ sim: (N, N)
197
+ idx: (N, )
198
+ normalize: bool, make row sum equal to 1
199
+ """
200
+ if idx is not None:
201
+ idx = idx.view(-1, 1)
202
+ mask = torch.eq(idx, idx.T).to(sim.dtype)
203
+ if normalize:
204
+ mask = mask / mask.sum(1, keepdim=True)
205
+ else:
206
+ mask = torch.zeros_like(sim)
207
+ mask.fill_diagonal_(1)
208
+ return mask # `1` mark valid/matched location
209
+
210
+ def get_contrastive_loss(self, pooled_image_embeds, pooled_text_embeds, idx=None):
211
+ sim_i2t, sim_t2i = self.get_sim(
212
+ pooled_image_embeds, pooled_text_embeds, t=self.temp)
213
+
214
+ with torch.no_grad():
215
+ sim_i2t_targets = self.get_mask(sim_i2t, idx=idx, normalize=True)
216
+ sim_t2i_targets = sim_i2t_targets
217
+
218
+ loss_i2t = -torch.sum(
219
+ F.log_softmax(sim_i2t, dim=1) * sim_i2t_targets, dim=1).mean()
220
+ loss_t2i = -torch.sum(
221
+ F.log_softmax(sim_t2i, dim=1) * sim_t2i_targets, dim=1).mean()
222
+
223
+ loss_ita = (loss_i2t + loss_t2i) / 2
224
+ return loss_ita, sim_i2t, sim_t2i
225
+
226
+ def get_sim(self, pooled_image_embeds, pooled_text_embeds, t=1):
227
+ """
228
+ Args:
229
+ pooled_image_embeds: (bsz, num_frms, d)
230
+ pooled_text_embeds: (bsz, d)
231
+ t: temperature
232
+ """
233
+ image_proj = self.vision_proj
234
+ text_proj = self.text_proj
235
+
236
+ image_feat = F.normalize(image_proj(pooled_image_embeds), dim=-1)
237
+ text_feat = F.normalize(text_proj(pooled_text_embeds), dim=-1)
238
+
239
+ if image_feat.ndim == 3:
240
+ sim_i2t = torch.einsum("mld,nd->mln", image_feat, text_feat).mean(1) / t # (N, N)
241
+ else:
242
+ sim_i2t = torch.einsum("md,nd ->mn", image_feat, text_feat) / t # (N, N)
243
+ sim_t2i = sim_i2t.T
244
+ return sim_i2t, sim_t2i
245
+
246
+ def get_itm_loss(self,
247
+ sim_i2t,
248
+ sim_t2i,
249
+ text_embeds,
250
+ text_atts,
251
+ image_embeds,
252
+ image_atts,
253
+ idx=None,
254
+ ):
255
+ """
256
+ sim_i2t, sim_t2i: (N, N)
257
+ text_embeds, text_atts, image_embeds, image_atts: (N, *)
258
+ idx: (N, )
259
+ """
260
+ bsz = len(sim_i2t)
261
+
262
+ with torch.no_grad():
263
+ weights_i2t = F.softmax(sim_i2t+1e-4, dim=1) # (N, N)
264
+ weights_t2i = F.softmax(sim_t2i+1e-4, dim=1)
265
+
266
+ mask = self.get_mask(sim_i2t, idx=idx).bool()
267
+ weights_i2t.masked_fill_(mask, 0)
268
+ weights_t2i.masked_fill_(mask, 0)
269
+
270
+ # select a negative image for each text
271
+ if self.config.itm_hard_neg:
272
+ img_neg_indices = torch.multinomial(weights_t2i, 1).squeeze() #RuntimeError: invalid multinomial distribution (sum of probabilities <= 0)
273
+ else:
274
+ img_neg_indices = self.get_rand_indices(mask, 1).squeeze()
275
+
276
+ image_embeds_neg = image_embeds[img_neg_indices]
277
+
278
+ # select a negative text for each image
279
+ if self.config.itm_hard_neg:
280
+ txt_neg_indices = torch.multinomial(weights_i2t, 1).squeeze()
281
+ else:
282
+ txt_neg_indices = self.get_rand_indices(mask, 1).squeeze()
283
+
284
+ text_embeds_neg = text_embeds[txt_neg_indices]
285
+ text_atts_neg = text_atts[txt_neg_indices] # (N, L, d)
286
+
287
+ # embedding on local gpu
288
+ _text_embeds = text_embeds
289
+ _text_atts = text_atts
290
+ _image_embeds = image_embeds
291
+ _image_atts = image_atts
292
+ # concat embeddings
293
+ text_embeds_all = torch.cat([_text_embeds, _text_embeds, text_embeds_neg], dim=0)
294
+ text_atts_all = torch.cat([_text_atts, _text_atts, text_atts_neg], dim=0)
295
+ image_embeds_all = torch.cat([_image_embeds, image_embeds_neg, _image_embeds], dim=0)
296
+ image_atts_all = torch.cat([_image_atts, _image_atts, _image_atts], dim=0)
297
+
298
+ text_encoder = self.get_text_encoder()
299
+ output = text_encoder(
300
+ encoder_embeds=text_embeds_all,
301
+ attention_mask=text_atts_all,
302
+ encoder_hidden_states=image_embeds_all,
303
+ encoder_attention_mask=image_atts_all,
304
+ return_dict=True,
305
+ mode='fusion',
306
+ )
307
+
308
+ itm_embeds = output.last_hidden_state[:, 0] # pos (N, d) + neg (2N, d)
309
+
310
+ loss_itm = self._get_itm_loss(itm_embeds, enc=self.itm_head)
311
+ itm_embeds_pos = itm_embeds[:bsz] # (N, d)
312
+
313
+ return loss_itm, itm_embeds_pos
314
+
315
+ def _get_itm_loss(self, itm_embeds, enc):
316
+ """
317
+ itm_embeds: (3*N, D)
318
+ enc: nn.Module that projects cls_embeds
319
+ """
320
+ itm_scores = enc(itm_embeds) # (3*N, 2)
321
+ bs = itm_scores.size(0) // 3
322
+ itm_labels = itm_scores.new_ones(3*bs, dtype=torch.long)
323
+ itm_labels[bs:] = 0
324
+ loss_itm = F.cross_entropy(itm_scores, itm_labels)
325
+ return loss_itm
326
+
327
+ def get_rand_indices(self, mask, k):
328
+ """
329
+ Args:
330
+ mask: (N, L) 0 indicates the positions that we can sample, 1 otherwise
331
+ k: #indices to sample at each row
332
+ Returns:
333
+ (N, k) indices
334
+ """
335
+ mask = mask.float()
336
+ mask = mask - 10000 * mask
337
+ mask += torch.randn_like(mask)
338
+ _, indices = torch.sort(mask, dim=1, descending=True)
339
+ indices = indices[:, :k].contiguous()
340
+ return indices
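Editor's sketch (not part of the commit): to make the retrieval step in demo.py concrete, here is a self-contained version of the similarity that get_sim computes, with random tensors standing in for real SViTT features: project, L2-normalize, compare each frame embedding against the text embedding via the einsum "mld,nd->mln", then average over frames.

```python
# Standalone sketch of SViTT.get_sim with dummy inputs; shapes follow the code above.
import torch
import torch.nn.functional as F

num_videos, num_frms, d, embed_dim = 5, 4, 768, 256
vision_proj = torch.nn.Linear(d, embed_dim)
text_proj = torch.nn.Linear(d, embed_dim)

pooled_image_embeds = torch.randn(num_videos, num_frms, d)  # one row per candidate video
pooled_text_embeds = torch.randn(1, d)                      # one query caption

image_feat = F.normalize(vision_proj(pooled_image_embeds), dim=-1)
text_feat = F.normalize(text_proj(pooled_text_embeds), dim=-1)
sim_i2t = torch.einsum("mld,nd->mln", image_feat, text_feat).mean(1)  # (num_videos, 1)
print(sim_i2t.argmax())  # index of the best-matching video, as used in demo.py
```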
svitt/sparse_config.py ADDED
@@ -0,0 +1,351 @@
1
+ # coding=utf-8
2
+ # Copyright Microsoft Research and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from collections import OrderedDict
17
+ from typing import Mapping
18
+
19
+ from transformers.configuration_utils import PretrainedConfig
20
+ from transformers.onnx import OnnxConfig
21
+
22
+
23
+ BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
24
+ "bert-base-uncased": "https://huggingface.co/bert-base-uncased/resolve/main/config.json",
25
+ "bert-large-uncased": "https://huggingface.co/bert-large-uncased/resolve/main/config.json",
26
+ "bert-base-cased": "https://huggingface.co/bert-base-cased/resolve/main/config.json",
27
+ "bert-large-cased": "https://huggingface.co/bert-large-cased/resolve/main/config.json",
28
+ "bert-base-multilingual-uncased": "https://huggingface.co/bert-base-multilingual-uncased/resolve/main/config.json",
29
+ "bert-base-multilingual-cased": "https://huggingface.co/bert-base-multilingual-cased/resolve/main/config.json",
30
+ "bert-base-chinese": "https://huggingface.co/bert-base-chinese/resolve/main/config.json",
31
+ "bert-base-german-cased": "https://huggingface.co/bert-base-german-cased/resolve/main/config.json",
32
+ "bert-large-uncased-whole-word-masking": "https://huggingface.co/bert-large-uncased-whole-word-masking/resolve/main/config.json",
33
+ "bert-large-cased-whole-word-masking": "https://huggingface.co/bert-large-cased-whole-word-masking/resolve/main/config.json",
34
+ "bert-large-uncased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/config.json",
35
+ "bert-large-cased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/config.json",
36
+ "bert-base-cased-finetuned-mrpc": "https://huggingface.co/bert-base-cased-finetuned-mrpc/resolve/main/config.json",
37
+ "bert-base-german-dbmdz-cased": "https://huggingface.co/bert-base-german-dbmdz-cased/resolve/main/config.json",
38
+ "bert-base-german-dbmdz-uncased": "https://huggingface.co/bert-base-german-dbmdz-uncased/resolve/main/config.json",
39
+ "cl-tohoku/bert-base-japanese": "https://huggingface.co/cl-tohoku/bert-base-japanese/resolve/main/config.json",
40
+ "cl-tohoku/bert-base-japanese-whole-word-masking": "https://huggingface.co/cl-tohoku/bert-base-japanese-whole-word-masking/resolve/main/config.json",
41
+ "cl-tohoku/bert-base-japanese-char": "https://huggingface.co/cl-tohoku/bert-base-japanese-char/resolve/main/config.json",
42
+ "cl-tohoku/bert-base-japanese-char-whole-word-masking": "https://huggingface.co/cl-tohoku/bert-base-japanese-char-whole-word-masking/resolve/main/config.json",
43
+ "TurkuNLP/bert-base-finnish-cased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-cased-v1/resolve/main/config.json",
44
+ "TurkuNLP/bert-base-finnish-uncased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-uncased-v1/resolve/main/config.json",
45
+ "wietsedv/bert-base-dutch-cased": "https://huggingface.co/wietsedv/bert-base-dutch-cased/resolve/main/config.json",
46
+ # See all BERT models at https://huggingface.co/models?filter=bert
47
+ }
48
+
49
+
50
+ class BertConfig(PretrainedConfig):
51
+ r"""
52
+ This is the configuration class to store the configuration of a [`BertModel`] or a
53
+ [`TFBertModel`]. It is used to instantiate a BERT model according to the specified arguments,
54
+ defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration
55
+ to that of the BERT [bert-base-uncased](https://huggingface.co/bert-base-uncased) architecture.
56
+
57
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model
58
+ outputs. Read the documentation from [`PretrainedConfig`] for more information.
59
+
60
+
61
+ Args:
62
+ vocab_size (`int`, *optional*, defaults to 30522):
63
+ Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the
64
+ `inputs_ids` passed when calling [`BertModel`] or
65
+ [`TFBertModel`].
66
+ hidden_size (`int`, *optional*, defaults to 768):
67
+ Dimensionality of the encoder layers and the pooler layer.
68
+ num_hidden_layers (`int`, *optional*, defaults to 12):
69
+ Number of hidden layers in the Transformer encoder.
70
+ num_attention_heads (`int`, *optional*, defaults to 12):
71
+ Number of attention heads for each attention layer in the Transformer encoder.
72
+ intermediate_size (`int`, *optional*, defaults to 3072):
73
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
74
+ hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
75
+ The non-linear activation function (function or string) in the encoder and pooler. If string,
76
+ `"gelu"`, `"relu"`, `"silu"` and `"gelu_new"` are supported.
77
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
78
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
79
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
80
+ The dropout ratio for the attention probabilities.
81
+ max_position_embeddings (`int`, *optional*, defaults to 512):
82
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
83
+ just in case (e.g., 512 or 1024 or 2048).
84
+ type_vocab_size (`int`, *optional*, defaults to 2):
85
+ The vocabulary size of the `token_type_ids` passed when calling [`BertModel`] or
86
+ [`TFBertModel`].
87
+ initializer_range (`float`, *optional*, defaults to 0.02):
88
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
89
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
90
+ The epsilon used by the layer normalization layers.
91
+ position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
92
+ Type of position embedding. Choose one of `"absolute"`, `"relative_key"`,
93
+ `"relative_key_query"`. For positional embeddings use `"absolute"`. For more information on
94
+ `"relative_key"`, please refer to [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155). For more information on `"relative_key_query"`, please refer to
95
+ *Method 4* in [Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
96
+ use_cache (`bool`, *optional*, defaults to `True`):
97
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
98
+ relevant if `config.is_decoder=True`.
99
+ classifier_dropout (`float`, *optional*):
100
+ The dropout ratio for the classification head.
101
+
102
+ Examples:
103
+
104
+ ```python
105
+ >>> from transformers import BertModel, BertConfig
106
+
107
+ >>> # Initializing a BERT bert-base-uncased style configuration
108
+ >>> configuration = BertConfig()
109
+
110
+ >>> # Initializing a model from the bert-base-uncased style configuration
111
+ >>> model = BertModel(configuration)
112
+
113
+ >>> # Accessing the model configuration
114
+ >>> configuration = model.config
115
+ ```"""
116
+ model_type = "bert"
117
+
118
+ def __init__(
119
+ self,
120
+ vocab_size=30522,
121
+ hidden_size=768,
122
+ num_hidden_layers=12,
123
+ num_attention_heads=12,
124
+ intermediate_size=3072,
125
+ hidden_act="gelu",
126
+ hidden_dropout_prob=0.1,
127
+ attention_probs_dropout_prob=0.1,
128
+ max_position_embeddings=512,
129
+ type_vocab_size=2,
130
+ initializer_range=0.02,
131
+ layer_norm_eps=1e-12,
132
+ pad_token_id=0,
133
+ position_embedding_type="absolute",
134
+ use_cache=True,
135
+ classifier_dropout=None,
136
+ token_keep_rate=1,
137
+ token_keep_strategy='cls_attn',
138
+ token_drop_loc=[9],
139
+ **kwargs
140
+ ):
141
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
142
+
143
+ self.vocab_size = vocab_size
144
+ self.hidden_size = hidden_size
145
+ self.num_hidden_layers = num_hidden_layers
146
+ self.num_attention_heads = num_attention_heads
147
+ self.hidden_act = hidden_act
148
+ self.intermediate_size = intermediate_size
149
+ self.hidden_dropout_prob = hidden_dropout_prob
150
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
151
+ self.max_position_embeddings = max_position_embeddings
152
+ self.type_vocab_size = type_vocab_size
153
+ self.initializer_range = initializer_range
154
+ self.layer_norm_eps = layer_norm_eps
155
+ self.position_embedding_type = position_embedding_type
156
+ self.use_cache = use_cache
157
+ self.classifier_dropout = classifier_dropout
158
+ self.token_keep_rate = token_keep_rate
159
+ self.token_keep_strategy = token_keep_strategy
160
+ self.token_drop_loc = token_drop_loc
161
+
162
+
163
+ class BertOnnxConfig(OnnxConfig):
164
+ @property
165
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
166
+ return OrderedDict(
167
+ [
168
+ ("input_ids", {0: "batch", 1: "sequence"}),
169
+ ("attention_mask", {0: "batch", 1: "sequence"}),
170
+ ("token_type_ids", {0: "batch", 1: "sequence"}),
171
+ ]
172
+ )
173
+
174
+
175
+ BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
176
+ "microsoft/beit-base-patch16-224-in22k": "https://huggingface.co/microsoft/beit-base-patch16-224-in22k/resolve/main/config.json",
177
+ # See all BEiT models at https://huggingface.co/models?filter=beit
178
+ }
179
+
180
+
181
+ class BeitConfig(PretrainedConfig):
182
+ r"""
183
+ This is the configuration class to store the configuration of a [`BeitModel`]. It is used to
184
+ instantiate an BEiT model according to the specified arguments, defining the model architecture. Instantiating a
185
+ configuration with the defaults will yield a similar configuration to that of the BEiT
186
+ [microsoft/beit-base-patch16-224-in22k](https://huggingface.co/microsoft/beit-base-patch16-224-in22k)
187
+ architecture.
188
+
189
+ Args:
190
+ vocab_size (`int`, *optional*, defaults to 8092):
191
+ Vocabulary size of the BEiT model. Defines the number of different image tokens that can be used during
192
+ pre-training.
193
+ hidden_size (`int`, *optional*, defaults to 768):
194
+ Dimensionality of the encoder layers and the pooler layer.
195
+ num_hidden_layers (`int`, *optional*, defaults to 12):
196
+ Number of hidden layers in the Transformer encoder.
197
+ num_attention_heads (`int`, *optional*, defaults to 12):
198
+ Number of attention heads for each attention layer in the Transformer encoder.
199
+ intermediate_size (`int`, *optional*, defaults to 3072):
200
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
201
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
202
+ The non-linear activation function (function or string) in the encoder and pooler. If string,
203
+ `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
204
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
205
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
206
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
207
+ The dropout ratio for the attention probabilities.
208
+ initializer_range (`float`, *optional*, defaults to 0.02):
209
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
210
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
211
+ The epsilon used by the layer normalization layers.
212
+ image_size (`int`, *optional*, defaults to `224`):
213
+ The size (resolution) of each image.
214
+ patch_size (`int`, *optional*, defaults to `16`):
215
+ The size (resolution) of each patch.
216
+ num_channels (`int`, *optional*, defaults to `3`):
217
+ The number of input channels.
218
+ use_mask_token (`bool`, *optional*, defaults to `False`):
219
+ Whether to use a mask token for masked image modeling.
220
+ use_absolute_position_embeddings (`bool`, *optional*, defaults to `False`):
221
+ Whether to use BERT-style absolute position embeddings.
222
+ use_relative_position_bias (`bool`, *optional*, defaults to `False`):
223
+ Whether to use T5-style relative position embeddings in the self-attention layers.
224
+ use_shared_relative_position_bias (`bool`, *optional*, defaults to `False`):
225
+ Whether to use the same relative position embeddings across all self-attention layers of the Transformer.
226
+ layer_scale_init_value (`float`, *optional*, defaults to 0.1):
227
+ Scale to use in the self-attention layers. 0.1 for base, 1e-5 for large. Set 0 to disable layer scale.
228
+ drop_path_rate (`float`, *optional*, defaults to 0.1):
229
+ Stochastic depth rate per sample (when applied in the main path of residual layers).
230
+ use_mean_pooling (`bool`, *optional*, defaults to `True`):
231
+ Whether to mean pool the final hidden states of the patches instead of using the final hidden state of the
232
+ CLS token, before applying the classification head.
233
+ out_indices (`List[int]`, *optional*, defaults to `[3, 5, 7, 11]`):
234
+ Indices of the feature maps to use for semantic segmentation.
235
+ pool_scales (`Tuple[int]`, *optional*, defaults to `[1, 2, 3, 6]`):
236
+ Pooling scales used in Pooling Pyramid Module applied on the last feature map.
237
+ use_auxiliary_head (`bool`, *optional*, defaults to `True`):
238
+ Whether to use an auxiliary head during training.
239
+ auxiliary_loss_weight (`float`, *optional*, defaults to 0.4):
240
+ Weight of the cross-entropy loss of the auxiliary head.
241
+ auxiliary_channels (`int`, *optional*, defaults to 256):
242
+ Number of channels to use in the auxiliary head.
243
+ auxiliary_num_convs (`int`, *optional*, defaults to 1):
244
+ Number of convolutional layers to use in the auxiliary head.
245
+ auxiliary_concat_input (`bool`, *optional*, defaults to `False`):
246
+ Whether to concatenate the output of the auxiliary head with the input before the classification layer.
247
+ semantic_loss_ignore_index (`int`, *optional*, defaults to 255):
248
+ The index that is ignored by the loss function of the semantic segmentation model.
249
+
250
+ Example:
251
+
252
+ ```python
253
+ >>> from transformers import BeitModel, BeitConfig
254
+
255
+ >>> # Initializing a BEiT beit-base-patch16-224-in22k style configuration
256
+ >>> configuration = BeitConfig()
257
+
258
+ >>> # Initializing a model from the beit-base-patch16-224-in22k style configuration
259
+ >>> model = BeitModel(configuration)
260
+
261
+ >>> # Accessing the model configuration
262
+ >>> configuration = model.config
263
+ ```"""
264
+ model_type = "beit"
265
+
266
+ def __init__(
267
+ self,
268
+ vocab_size=8192,
269
+ hidden_size=768,
270
+ num_hidden_layers=12,
271
+ num_attention_heads=12,
272
+ intermediate_size=3072,
273
+ hidden_act="gelu",
274
+ hidden_dropout_prob=0.0,
275
+ attention_probs_dropout_prob=0.0,
276
+ initializer_range=0.02,
277
+ layer_norm_eps=1e-12,
278
+ is_encoder_decoder=False,
279
+ image_size=224,
280
+ patch_size=16,
281
+ num_channels=3,
282
+ use_mask_token=False,
283
+ use_absolute_position_embeddings=False,
284
+ use_relative_position_bias=False,
285
+ use_shared_relative_position_bias=False,
286
+ layer_scale_init_value=0.1,
287
+ drop_path_rate=0.1,
288
+ use_mean_pooling=True,
289
+ out_indices=[3, 5, 7, 11],
290
+ pool_scales=[1, 2, 3, 6],
291
+ use_auxiliary_head=True,
292
+ auxiliary_loss_weight=0.4,
293
+ auxiliary_channels=256,
294
+ auxiliary_num_convs=1,
295
+ auxiliary_concat_input=False,
296
+ semantic_loss_ignore_index=255,
297
+ token_keep_rate=1,
298
+ token_keep_strategy='cls_attn',
299
+ token_drop_loc=[3, 6, 9],
300
+ sparse_random_attn=None,
301
+ sparse_local_attn=1,
302
+ attn_block_size=1,
303
+ num_cls_tokens=1,
304
+ token_3d_order='none',
305
+ **kwargs
306
+ ):
307
+ super().__init__(**kwargs)
308
+
309
+ self.vocab_size = vocab_size
310
+ self.hidden_size = hidden_size
311
+ self.num_hidden_layers = num_hidden_layers
312
+ self.num_attention_heads = num_attention_heads
313
+ self.intermediate_size = intermediate_size
314
+ self.hidden_act = hidden_act
315
+ self.hidden_dropout_prob = hidden_dropout_prob
316
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
317
+ self.initializer_range = initializer_range
318
+ self.layer_norm_eps = layer_norm_eps
319
+
320
+ self.image_size = image_size
321
+ self.patch_size = patch_size
322
+ self.num_channels = num_channels
323
+ self.use_mask_token = use_mask_token
324
+ self.use_absolute_position_embeddings = use_absolute_position_embeddings
325
+ self.use_relative_position_bias = use_relative_position_bias
326
+ self.use_shared_relative_position_bias = use_shared_relative_position_bias
327
+ self.layer_scale_init_value = layer_scale_init_value
328
+ self.drop_path_rate = drop_path_rate
329
+ self.use_mean_pooling = use_mean_pooling
330
+ # decode head attributes (semantic segmentation)
331
+ self.out_indices = out_indices
332
+ self.pool_scales = pool_scales
333
+ # auxiliary head attributes (semantic segmentation)
334
+ self.use_auxiliary_head = use_auxiliary_head
335
+ self.auxiliary_loss_weight = auxiliary_loss_weight
336
+ self.auxiliary_channels = auxiliary_channels
337
+ self.auxiliary_num_convs = auxiliary_num_convs
338
+ self.auxiliary_concat_input = auxiliary_concat_input
339
+ self.semantic_loss_ignore_index = semantic_loss_ignore_index
340
+
341
+ # node sparsification
342
+ self.token_keep_rate = token_keep_rate
343
+ self.token_keep_strategy = token_keep_strategy
344
+ self.token_drop_loc = token_drop_loc
345
+ # edge sparsification
346
+ self.sparse_random_attn = sparse_random_attn
347
+ self.sparse_local_attn = sparse_local_attn
348
+ self.attn_block_size = attn_block_size
349
+ self.num_cls_tokens = num_cls_tokens
350
+ # token order
351
+ self.token_3d_order = token_3d_order
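The sparsity arguments above are ordinary keyword arguments, so a sparse configuration can be built the same way as a stock `BeitConfig`. A minimal sketch with illustrative values (assumptions for demonstration, not the settings of any released checkpoint):

```python
from svitt.sparse_config import BeitConfig

config = BeitConfig(
    image_size=224,
    patch_size=16,
    # node sparsification: keep 70% of patch tokens after layers 3, 6 and 9,
    # ranked by CLS attention
    token_keep_rate=0.7,
    token_keep_strategy="cls_attn",
    token_drop_loc=[3, 6, 9],
    # edge sparsification: each query block attends to an odd-sized local window
    # of 3 key blocks plus 2 random key blocks, with blocks of 2 tokens
    sparse_local_attn=3,
    sparse_random_attn=2,
    attn_block_size=2,
    num_cls_tokens=1,
    # reorder the (frames, height, width) patch grid along a Hilbert curve
    token_3d_order="hilbert",
)
```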
svitt/sparse_xbeit.py ADDED
@@ -0,0 +1,1585 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch BEiT model. """
16
+
17
+
18
+ import collections.abc
19
+ import math
20
+ import numpy as np
21
+ from dataclasses import dataclass
22
+ from typing import Optional, Tuple
23
+ import zCurve
24
+ import hilbert
25
+
26
+ import torch
27
+ import torch.utils.checkpoint
28
+ from torch import nn
29
+ from torch.nn import CrossEntropyLoss, MSELoss
30
+ from einops import rearrange, repeat
31
+
32
+ from transformers.activations import ACT2FN
33
+ from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
34
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, MaskedLMOutput, SequenceClassifierOutput
35
+ from transformers.modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
36
+ from svitt.sparse_config import BeitConfig
37
+
38
+
39
+ _CONFIG_FOR_DOC = "BeitConfig"
40
+ _CHECKPOINT_FOR_DOC = "microsoft/beit-base-patch16-224"
41
+
42
+ BEIT_PRETRAINED_MODEL_ARCHIVE_LIST = [
43
+ "microsoft/beit-base-patch16-224",
44
+ # See all BEiT models at https://huggingface.co/models?filter=beit
45
+ ]
46
+
47
+
48
+ @dataclass
49
+ class BeitModelOutputWithPooling(BaseModelOutputWithPooling):
50
+ """
51
+ Class for outputs of :class:`~transformers.BeitModel`.
52
+
53
+ Args:
54
+ last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
55
+ Sequence of hidden-states at the output of the last layer of the model.
56
+ pooler_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, hidden_size)`):
57
+ Average of the last layer hidden states of the patch tokens (excluding the `[CLS]` token) if
58
+ `config.use_mean_pooling` is set to True. If set to False, then the final hidden state of the `[CLS]` token
59
+ will be returned.
60
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
61
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
62
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
63
+
64
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
65
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
66
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
67
+ sequence_length, sequence_length)`.
68
+
69
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
70
+ heads.
71
+ """
72
+ token_idx: Optional[Tuple[torch.LongTensor]] = None
73
+
74
+
75
+ @dataclass
76
+ class BeitModelOutput(BaseModelOutput):
77
+ token_idx: Optional[Tuple[torch.LongTensor]] = None
78
+
79
+
80
+ # Inspired by
81
+ # https://github.com/rwightman/pytorch-image-models/blob/b9bd960a032c75ca6b808ddeed76bee5f3ed4972/timm/models/layers/helpers.py
82
+ # From PyTorch internals
83
+ def to_2tuple(x):
84
+ if isinstance(x, collections.abc.Iterable):
85
+ return x
86
+ return (x, x)
87
+
88
+
89
+ # Based on https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py
90
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
91
+ """
92
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
93
+
94
+ Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
95
+ however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
96
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
97
+ layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
98
+ argument.
99
+ """
100
+ if drop_prob == 0.0 or not training:
101
+ return x
102
+ keep_prob = 1 - drop_prob
103
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
104
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
105
+ random_tensor.floor_() # binarize
106
+ output = x.div(keep_prob) * random_tensor
107
+ return output
108
+
109
+
110
+ class DropPath(nn.Module):
111
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
112
+
113
+ def __init__(self, drop_prob=None):
114
+ super().__init__()
115
+ self.drop_prob = drop_prob
116
+
117
+ def forward(self, x):
118
+ return drop_path(x, self.drop_prob, self.training)
119
+
120
+ def extra_repr(self) -> str:
121
+ return "p={}".format(self.drop_prob)
122
+
123
+
124
+ # Based on timm implementation, which can be found here:
125
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
126
+ class BeitEmbeddings(nn.Module):
127
+ """
128
+ Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
129
+
130
+ """
131
+
132
+ def __init__(self, config):
133
+ super().__init__()
134
+
135
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
136
+ if config.use_mask_token:
137
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
138
+ else:
139
+ self.mask_token = None
140
+ self.patch_embeddings = PatchEmbeddings(
141
+ image_size=config.image_size,
142
+ patch_size=config.patch_size,
143
+ num_channels=config.num_channels,
144
+ embed_dim=config.hidden_size,
145
+ )
146
+ num_patches = self.patch_embeddings.num_patches
147
+ if config.use_absolute_position_embeddings:
148
+ self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
149
+ else:
150
+ self.position_embeddings = None
151
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
152
+
153
+ def forward(self, pixel_values, bool_masked_pos=None):
154
+
155
+ if pixel_values.ndim == 5: # video input
156
+ embeddings = self.patch_embeddings(pixel_values.flatten(0, 1))
157
+ embeddings = rearrange(embeddings, '(b m) n d -> b (m n) d', m=pixel_values.shape[1])
158
+ else: # image input
159
+ embeddings = self.patch_embeddings(pixel_values)
160
+
161
+ batch_size, seq_len, _ = embeddings.size()
162
+
163
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
164
+ if bool_masked_pos is not None:
165
+ mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
166
+ # replace the masked visual tokens by mask_tokens
167
+ w = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
168
+ embeddings = embeddings * (1 - w) + mask_tokens * w
169
+
170
+ embeddings = torch.cat((cls_tokens, embeddings), dim=1)
171
+ if self.position_embeddings is not None:
172
+ embeddings = embeddings + self.position_embeddings
173
+ embeddings = self.dropout(embeddings)
174
+
175
+ return embeddings
176
+
177
+
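+ # Shape sketch of the video path above, assuming the default 224x224 resolution,
+ # 16x16 patches and a single CLS token:
+ #
+ #     import torch
+ #     emb = BeitEmbeddings(BeitConfig())
+ #     video = torch.zeros(2, 4, 3, 224, 224)   # (batch, frames, channels, H, W)
+ #     out = emb(video)                          # (2, 4 * 196 + 1, 768)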
178
+ # Based on timm implementation, which can be found here:
179
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
180
+ class PatchEmbeddings(nn.Module):
181
+ """
182
+ Image to Patch Embedding.
183
+ """
184
+
185
+ def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768):
186
+ super().__init__()
187
+ image_size = to_2tuple(image_size)
188
+ patch_size = to_2tuple(patch_size)
189
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
190
+ patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
191
+ self.image_size = image_size
192
+ self.patch_size = patch_size
193
+ self.num_patches = num_patches
194
+ self.patch_shape = patch_shape
195
+
196
+ self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
197
+
198
+ def forward(self, pixel_values):
199
+ batch_size, num_channels, height, width = pixel_values.shape
200
+ # FIXME look at relaxing size constraints
201
+ if height != self.image_size[0] or width != self.image_size[1]:
202
+ raise ValueError(
203
+ f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
204
+ )
205
+ x = self.projection(pixel_values).flatten(2).transpose(1, 2)
206
+
207
+ return x
208
+
209
+
210
+ class BeitSelfAttention(nn.Module):
211
+ def __init__(self, config, window_size=None):
212
+ super().__init__()
213
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
214
+ raise ValueError(
215
+ f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
216
+ f"heads {config.num_attention_heads}."
217
+ )
218
+
219
+ self.num_attention_heads = config.num_attention_heads
220
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
221
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
222
+
223
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
224
+ self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
225
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
226
+
227
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
228
+
229
+ # sparse params
230
+ self.random_attn = config.sparse_random_attn
231
+ self.local_attn = config.sparse_local_attn
232
+ self.block_size = config.attn_block_size
233
+ self.num_cls_tokens = config.num_cls_tokens
234
+ if self.local_attn is not None and self.random_attn is not None:
235
+ self.num_kv_blocks = self.local_attn + self.random_attn
236
+
237
+ if window_size:
238
+ self.relative_position_bias = BeitRelativePositionBias3D(config, window_size=window_size)
239
+ else:
240
+ self.relative_position_bias = None
241
+
242
+ def split_heads(self, x):
243
+ return rearrange(x, 'b n (h d) -> b h n d', h=self.num_attention_heads)
244
+
245
+ def join_heads(self, x):
246
+ return rearrange(x, 'b h n d -> b n (h d)')
247
+
248
+ def blockify(self, x):
249
+ assert x.dim() == 4, f"Unsupported input shape {x.shape}"
250
+ seq_len = x.shape[2]
251
+ if seq_len % self.block_size > 0: # seq_len not divisible by block_size, zero pad
252
+ pad_len = self.block_size - seq_len % self.block_size
253
+ x = nn.functional.pad(x, (0, 0, 0, pad_len))
254
+ else:
255
+ pad_len = 0
256
+ x = rearrange(x, 'b h (m n) d -> b h m n d', n=self.block_size)
257
+ return x, pad_len
258
+
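+ # Shape sketch for blockify, under illustrative settings: with attn_block_size == 4
+ # and a 10-token sequence, the sequence is zero-padded by 2 and regrouped into
+ # 3 blocks of 4 tokens.
+ #
+ #     import torch
+ #     attn = BeitSelfAttention(BeitConfig(attn_block_size=4))
+ #     q = torch.zeros(2, 12, 10, 64)            # (bsz, heads, seq_len, head_dim)
+ #     q_blocks, pad_len = attn.blockify(q)      # q_blocks: (2, 12, 3, 4, 64), pad_len == 2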
259
+ def dense_attention(self, q, k, v, head_mask=None, relative_position_bias=None, q_idx=None, k_idx=None):
260
+ # q, k, v: (bsz, num_heads, seq_len, dims)
261
+ assert k.shape[2] == v.shape[2], "Key and value shapes mismatch"
262
+ sim = torch.einsum('b h i d, b h j d -> b h i j', q, k)
263
+ sim = sim / math.sqrt(self.attention_head_size)
264
+
265
+ # Add relative position bias if present.
266
+ if self.relative_position_bias is not None:
267
+ if q_idx is not None and q_idx.ndim == 2:
268
+ assert k_idx is not None and len(q_idx) == len(k_idx)
269
+ bias = torch.stack([
270
+ self.relative_position_bias(from_idx=q_idx_, to_idx=k_idx_)
271
+ for q_idx_, k_idx_ in zip(q_idx, k_idx)
272
+ ])
273
+ else:
274
+ bias = self.relative_position_bias(from_idx=q_idx, to_idx=k_idx).unsqueeze(0)
275
+ sim = sim + bias
276
+
277
+ # Add shared relative position bias if provided.
278
+ if relative_position_bias is not None:
279
+ sim = sim + relative_position_bias
280
+
281
+ # Normalize the attention scores to probabilities.
282
+ attn = sim.softmax(dim=-1)
283
+ attn = self.dropout(attn)
284
+ if head_mask is not None:
285
+ attn = attn * head_mask
286
+
287
+ out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
288
+ return out, attn
289
+
290
+ def _sparse_attn_relative_position_bias(self, q_idx, pad_q, attn_idx, group_len):
291
+ q_idx_blk = nn.functional.pad(q_idx, (0, pad_q)).view(-1, self.block_size)
292
+ attn_idx_flt = rearrange(q_idx_blk[attn_idx], 'm n j -> m (n j)') # (seq_len, num_kv_blocks * group_len)
293
+ cls_idx = torch.arange(self.num_cls_tokens, device=q_idx.device)
294
+ cls_idx = repeat(cls_idx, 'n -> m n', m=len(attn_idx_flt))
295
+ attn_idx_flt = torch.cat((cls_idx, attn_idx_flt), dim=1)
296
+ attn_idx_flt = repeat(attn_idx_flt, 'm n -> (m i) n', i=group_len)
297
+ if pad_q > 0:
298
+ attn_idx_flt = attn_idx_flt[:-pad_q]
299
+ bias_flt = self.relative_position_bias(from_idx=q_idx, to_idx=attn_idx_flt)
300
+ if pad_q > 0:
301
+ bias_flt = nn.functional.pad(bias_flt, (0, 0, 0, pad_q))
302
+ return rearrange(bias_flt, 'h (m i) n -> h m i n', i=group_len) # num_heads, seq_len, group_len, (num_kv_blocks * group_len + num_cls_tokens)
303
+
304
+ def sparse_attention(self, q, k, v, head_mask=None, relative_position_bias=None, q_idx=None, mimic_full=False):
305
+ assert self.local_attn == 0 or self.local_attn % 2 == 1, "Even local window size not supported"
306
+ assert k.shape[2] == v.shape[2], "Key and value shapes mismatch"
307
+
308
+
309
+ if not mimic_full:
310
+ cls_k, k = k[..., :self.num_cls_tokens, :], k[..., self.num_cls_tokens:, :] # cls_k: (bsz, num_heads, num_cls_tokens, dims)
311
+ cls_v, v = v[..., :self.num_cls_tokens, :], v[..., self.num_cls_tokens:, :]
312
+
313
+ # pad token sequence to multiples of block_size
314
+ if mimic_full:
315
+ bsz, num_heads, seq_len, dims = q.shape
316
+ else:
317
+ q, pad_q = self.blockify(q) # q: (bsz, num_heads, seq_len, group_len, dims)
318
+ k, pad_k = self.blockify(k)
319
+ v, pad_v = self.blockify(v)
320
+ bsz, num_heads, seq_len, group_len, dims = q.shape
321
+
322
+ # global attention
323
+ cls_sim = torch.einsum('b h n i d, b h j d -> b h n i j', q, cls_k) # (bsz, num_heads, seq_len, group_len, num_cls_tokens)
324
+
325
+ if mimic_full:
326
+ sim = torch.einsum('b h i d, b h j d -> b h i j', q, k)
327
+ sim = sim / math.sqrt(self.attention_head_size)
328
+ sim = sim + self.relative_position_bias(from_idx=q_idx).unsqueeze(0)
329
+
330
+ else:
331
+ # initialize empty sim matrix
332
+ sim = torch.empty((bsz, num_heads, seq_len, self.num_kv_blocks, group_len, group_len), device=q.device)
333
+ attn_idx = torch.zeros((seq_len, self.num_kv_blocks), dtype=torch.int64, device=q.device)
334
+
335
+ # local window attention
336
+ cnt = 0
337
+ if self.local_attn > 0:
338
+ num_rolls = self.local_attn // 2
339
+ for r in range(-num_rolls, num_rolls + 1):
340
+ sim[..., cnt, :, :] = torch.einsum('b h n i d, b h n j d -> b h n i j', q, k.roll(-r, dims=2))
341
+ attn_idx[:, cnt] = torch.arange(seq_len, device=q.device).roll(r)
342
+ cnt += 1
343
+
344
+ # random attention
345
+ if self.random_attn > 0:
346
+ # generate random attention pattern
347
+ rand = torch.rand((seq_len, seq_len), device=q.device)
348
+ if self.local_attn > 0:
349
+ # avoid overlap with local attention
350
+ for r in range(-num_rolls, num_rolls + 1):
351
+ tgt_idx = list(i % seq_len for i in range(r, seq_len + r))
352
+ rand[range(seq_len), tgt_idx] = 0
353
+ _, idx = rand.topk(self.random_attn, dim=-1) # seq_len, random_attn
354
+ idx, _ = torch.sort(idx, dim=1)
355
+ attn_idx[:, cnt:] = idx
356
+
357
+ idx_ = repeat(idx, 'n m -> b h n m i d', b=bsz, h=num_heads, i=group_len, d=dims)
358
+
359
+ for r in range(self.random_attn):
360
+ sim[..., cnt, :, :] = torch.einsum('b h n i d, b h n j d -> b h n i j', q, k.gather(2, idx_[..., r, :, :]))
361
+ cnt += 1
362
+
363
+ sim = rearrange(sim, 'b h m n i j -> b h m i (n j)') # (bsz, num_heads, seq_len, group_len, num_kv_blocks * group_len)
364
+ sim = torch.cat((cls_sim, sim), -1)
365
+ sim = sim / math.sqrt(self.attention_head_size)
366
+
367
+ # Add relative position bias if present.
368
+ # NOTE: we assume q and k (excluding cls) use the same token indexing for the relative position embedding
369
+ if self.relative_position_bias is not None:
370
+ assert q_idx is not None, "query index required for relative position bias"
371
+ if q_idx.ndim == 2:
372
+ # different indices for each sample
373
+ bias = torch.stack([
374
+ self._sparse_attn_relative_position_bias(q_idx_, pad_q, attn_idx, group_len)
375
+ for q_idx_ in q_idx
376
+ ])
377
+ else:
378
+ bias = self._sparse_attn_relative_position_bias(q_idx, pad_q, attn_idx, group_len).unsqueeze(0)
379
+ sim = sim + bias
380
+
381
+ # Add shared relative position bias if provided.
382
+ if relative_position_bias is not None:
383
+ raise NotImplementedError
384
+ sim = sim + relative_position_bias
385
+
386
+ attn = sim.softmax(dim=-1)
387
+ attn = self.dropout(attn)
388
+ if head_mask is not None:
389
+ attn = attn * head_mask
390
+
391
+ # block attention
392
+ if mimic_full:
393
+ out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
394
+
395
+ else:
396
+ out = torch.empty((bsz, num_heads, seq_len, group_len, dims), device=q.device)
397
+ for m in range(seq_len):
398
+ v_row = torch.index_select(v, 2, attn_idx[m])
399
+ v_row = rearrange(v_row, 'b h n j d -> b h (n j) d') # (bsz, num_heads, num_kv_blocks * group_len, dims)
400
+ v_row = torch.cat((cls_v, v_row), 2)
401
+ out[..., m, :, :] = torch.einsum('b h i j, b h j d -> b h i d', attn[..., m, :, :], v_row)
402
+ out = rearrange(out, 'b h n i d -> b h (n i) d')
403
+ if pad_q > 0:
404
+ out = out[..., :-pad_q, :]
405
+
406
+ return out, attn
407
+
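+ # Worked example of the per-query attention width under the sparse pattern above
+ # (illustrative settings, not tied to a released config):
+ #
+ #     sparse_local_attn, sparse_random_attn = 3, 2    # key blocks per query block
+ #     attn_block_size, num_cls_tokens = 2, 1
+ #     num_kv_blocks = sparse_local_attn + sparse_random_attn         # 5
+ #     keys_per_query = num_kv_blocks * attn_block_size + num_cls_tokens
+ #     # -> 11 keys per query regardless of sequence length; the CLS tokens
+ #     #    themselves still attend densely via dense_attention in forward() below.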
408
+ def forward(self, hidden_states, head_mask=None, output_attentions=False, relative_position_bias=None, token_idx=None):
409
+ # compute qkv
410
+ q = self.split_heads(self.query(hidden_states))
411
+ k = self.split_heads(self.key(hidden_states))
412
+ v = self.split_heads(self.value(hidden_states))
413
+
414
+ # combine local token_idx with cls tokens
415
+ # NOTE: assume token_idx starts from 0
416
+ cls_q_idx = torch.arange(self.num_cls_tokens, device=q.device)
417
+ if token_idx is not None:
418
+ if token_idx.ndim == 2:
419
+ cls_q_idx = repeat(cls_q_idx, 'n -> b n', b=q.shape[0])
420
+ all_token_idx = torch.cat((cls_q_idx, token_idx + self.num_cls_tokens), dim=-1)
421
+ else:
422
+ all_token_idx = None
423
+
424
+ if self.random_attn is None:
425
+ outputs, attention_probs = self.dense_attention(q, k, v, head_mask=head_mask,
426
+ relative_position_bias=relative_position_bias,
427
+ q_idx=all_token_idx,
428
+ k_idx=all_token_idx)
429
+ cls_attention_probs = attention_probs[..., :self.num_cls_tokens, :]
430
+
431
+ else:
432
+ cls_q, q = q[..., :self.num_cls_tokens, :], q[..., self.num_cls_tokens:, :]
433
+
434
+ # dense global attention (num_cls_tokens, seq_len)
435
+ cls_outputs, cls_attention_probs = self.dense_attention(cls_q, k, v, head_mask=head_mask,
436
+ relative_position_bias=relative_position_bias,
437
+ q_idx=cls_q_idx,
438
+ k_idx=all_token_idx)
439
+
440
+ # sparse local attention (local_seq_len, seq_len)
441
+ if token_idx is None:
442
+ token_idx = torch.arange(q.shape[-2], device=q.device)
443
+ outputs, attention_probs = self.sparse_attention(q, k, v, head_mask=head_mask,
444
+ relative_position_bias=relative_position_bias,
445
+ q_idx=token_idx + self.num_cls_tokens)
446
+
447
+ outputs = torch.cat((cls_outputs, outputs), dim=2)
448
+
449
+ outputs = self.join_heads(outputs)
450
+
451
+ outputs = (outputs, cls_attention_probs) if output_attentions else (outputs,)
452
+
453
+ return outputs
454
+
455
+
456
+ class BeitSelfOutput(nn.Module):
457
+ """
458
+ The residual connection is defined in BeitLayer instead of here (as is the case with other models), due to the
459
+ layernorm applied before each block.
460
+ """
461
+
462
+ def __init__(self, config):
463
+ super().__init__()
464
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
465
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
466
+
467
+ def forward(self, hidden_states, input_tensor, gamma=None):
468
+ hidden_states = self.dense(hidden_states)
469
+ hidden_states = self.dropout(hidden_states)
470
+
471
+ return hidden_states
472
+
473
+
474
+ class BeitAttention(nn.Module):
475
+ def __init__(self, config, window_size=None):
476
+ super().__init__()
477
+ self.attention = BeitSelfAttention(config, window_size=window_size)
478
+ self.output = BeitSelfOutput(config)
479
+ self.pruned_heads = set()
480
+
481
+ def prune_heads(self, heads):
482
+ if len(heads) == 0:
483
+ return
484
+ heads, index = find_pruneable_heads_and_indices(
485
+ heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
486
+ )
487
+
488
+ # Prune linear layers
489
+ self.attention.query = prune_linear_layer(self.attention.query, index)
490
+ self.attention.key = prune_linear_layer(self.attention.key, index)
491
+ self.attention.value = prune_linear_layer(self.attention.value, index)
492
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
493
+
494
+ # Update hyper params and store pruned heads
495
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
496
+ self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
497
+ self.pruned_heads = self.pruned_heads.union(heads)
498
+
499
+ def forward(self, hidden_states, head_mask=None, output_attentions=False, relative_position_bias=None, token_idx=None):
500
+ self_outputs = self.attention(hidden_states, head_mask, output_attentions, relative_position_bias, token_idx)
501
+
502
+ attention_output = self.output(self_outputs[0], hidden_states)
503
+
504
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
505
+ return outputs
506
+
507
+
508
+ class BeitIntermediate(nn.Module):
509
+ def __init__(self, config):
510
+ super().__init__()
511
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
512
+ if isinstance(config.hidden_act, str):
513
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
514
+ else:
515
+ self.intermediate_act_fn = config.hidden_act
516
+
517
+ def forward(self, hidden_states):
518
+ hidden_states = self.dense(hidden_states)
519
+ hidden_states = self.intermediate_act_fn(hidden_states)
520
+
521
+ return hidden_states
522
+
523
+
524
+ class BeitOutput(nn.Module):
525
+ def __init__(self, config):
526
+ super().__init__()
527
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
528
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
529
+
530
+ def forward(self, hidden_states):
531
+ hidden_states = self.dense(hidden_states)
532
+ hidden_states = self.dropout(hidden_states)
533
+
534
+ return hidden_states
535
+
536
+
537
+ class BeitLayer(nn.Module):
538
+ """This corresponds to the Block class in the timm implementation."""
539
+
540
+ def __init__(self, config, window_size=None, drop_path_rate=0.0,
541
+ token_keep_rate=1.0):
542
+ super().__init__()
543
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
544
+ self.seq_len_dim = 1
545
+ self.attention = BeitAttention(config, window_size=window_size)
546
+ self.intermediate = BeitIntermediate(config)
547
+ self.output = BeitOutput(config)
548
+ self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
549
+ self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
550
+ self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
551
+
552
+ # sparse params
553
+ self.token_keep_rate = token_keep_rate
554
+ self.token_keep_strategy = config.token_keep_strategy
555
+ self.num_cls_tokens = config.num_cls_tokens
556
+
557
+ init_values = config.layer_scale_init_value
558
+ if init_values > 0:
559
+ self.lambda_1 = nn.Parameter(init_values * torch.ones((config.hidden_size)), requires_grad=True)
560
+ self.lambda_2 = nn.Parameter(init_values * torch.ones((config.hidden_size)), requires_grad=True)
561
+ else:
562
+ self.lambda_1, self.lambda_2 = None, None
563
+
564
+ def sparsify(self, x, attn):
565
+ x_cls, x_ = x[:, :self.num_cls_tokens], x[:, self.num_cls_tokens:]
566
+ assert 0 < self.token_keep_rate <= 1, "Expected keep rate in range (0, 1]"
567
+ left_tokens = math.ceil(self.token_keep_rate * x_.size(1))
568
+
569
+ if self.token_keep_strategy == 'cls_attn':
570
+ if len(attn.shape) == 4:
571
+ attn = attn.mean(1) # pool over attention heads
572
+ cls_attn = attn[:, 0, self.num_cls_tokens:]
573
+ _, idx = torch.topk(cls_attn, left_tokens, dim=1) # [B, left_tokens]
574
+
575
+ elif self.token_keep_strategy == 'random':
576
+ rand = torch.rand(x_.shape[:2], device=x_.device)
577
+ _, idx = torch.topk(rand, left_tokens, dim=1) # [B, left_tokens]
578
+
579
+ else:
580
+ raise NotImplementedError(f"Sparse strategy {self.token_keep_strategy} is not implemented")
581
+
582
+ idx, _ = torch.sort(idx, dim=1)
583
+ index = idx.unsqueeze(-1).expand(-1, -1, x_.size(-1)) # [B, left_tokens, C]
584
+ outputs = torch.cat((x_cls, x_.gather(1, index)), dim=1).contiguous()
585
+ return outputs, idx
586
+
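+ # How token_keep_rate compounds across token_drop_loc, assuming an 8-frame
+ # 224x224 clip with 16x16 patches (8 * 196 = 1568 patch tokens; CLS tokens are kept):
+ #
+ #     import math
+ #     tokens = 8 * (224 // 16) ** 2              # 1568
+ #     for layer in [3, 6, 9]:                    # token_drop_loc
+ #         tokens = math.ceil(0.7 * tokens)       # token_keep_rate = 0.7
+ #     # tokens -> 1098, 769, 539 after layers 3, 6, 9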
587
+ def forward(self, hidden_states, head_mask=None, output_attentions=False, relative_position_bias=None, token_idx=None):
588
+ self_attention_outputs = self.attention(
589
+ self.layernorm_before(hidden_states), # in BEiT, layernorm is applied before self-attention
590
+ head_mask,
591
+ output_attentions=(output_attentions or self.token_keep_rate < 1),
592
+ relative_position_bias=relative_position_bias,
593
+ token_idx=token_idx
594
+ )
595
+ attention_output = self_attention_outputs[0]
596
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
597
+
598
+ # apply lambda_1 if present
599
+ if self.lambda_1 is not None:
600
+ attention_output = self.lambda_1 * attention_output
601
+
602
+ # first residual connection
603
+ hidden_states = self.drop_path(attention_output) + hidden_states
604
+
605
+ # in BEiT, layernorm is also applied after self-attention
606
+ layer_output = self.layernorm_after(hidden_states)
607
+
608
+ layer_output = self.intermediate(layer_output)
609
+ layer_output = self.output(layer_output)
610
+
611
+ if self.lambda_2 is not None:
612
+ layer_output = self.lambda_2 * layer_output
613
+
614
+ # second residual connection
615
+ layer_output = self.drop_path(layer_output) + hidden_states
616
+
617
+ # node sparsification
618
+ if self.token_keep_rate < 1:
619
+ layer_output, token_keep_idx = self.sparsify(layer_output, outputs[0])
620
+ if token_idx is not None:
621
+ if token_idx.ndim == 1:
622
+ token_idx = repeat(token_idx, 'n -> b n', b=len(token_keep_idx))
623
+ token_keep_idx = token_idx.gather(1, token_keep_idx)
624
+ outputs = outputs + (token_keep_idx,)
625
+
626
+ outputs = (layer_output,) + outputs
627
+
628
+ return outputs
629
+
630
+
631
+ class BeitRelativePositionBias(nn.Module):
632
+ def __init__(self, config, window_size):
633
+ super().__init__()
634
+ self.window_size = window_size
635
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
636
+ self.relative_position_bias_table = nn.Parameter(
637
+ torch.zeros(self.num_relative_distance, config.num_attention_heads)
638
+ ) # 2*Wh-1 * 2*Ww-1, nH
639
+ # cls to token & token to cls & cls to cls
640
+
641
+ # get pair-wise relative position index for each token inside the window
642
+ coords_h = torch.arange(window_size[0])
643
+ coords_w = torch.arange(window_size[1])
644
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
645
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
646
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
647
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
648
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
649
+ relative_coords[:, :, 1] += window_size[1] - 1
650
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
651
+ relative_position_index = torch.zeros(
652
+ size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
653
+ )
654
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
655
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
656
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
657
+ relative_position_index[0, 0] = self.num_relative_distance - 1
658
+
659
+ self.register_buffer("relative_position_index", relative_position_index, persistent=False)
660
+
661
+ def forward(self):
662
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
663
+ self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1
664
+ ) # Wh*Ww,Wh*Ww,nH
665
+
666
+ return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
667
+
668
+
669
+ class BeitRelativePositionBias3D(nn.Module):
670
+ """
671
+ 3D relative position bias
672
+ """
673
+ def __init__(self, config, window_size, num_cls_tokens=1):
674
+ super().__init__()
675
+ self.window_size = window_size
676
+ self.num_cls_tokens = num_cls_tokens
677
+
678
+ relative_size = [w * 2 - 1 for w in window_size]
679
+ self.num_relative_distance = np.prod(relative_size) + 2 * num_cls_tokens + num_cls_tokens ** 2
680
+
681
+ self.relative_position_bias_table = nn.Parameter(
682
+ torch.zeros(self.num_relative_distance, config.num_attention_heads)
683
+ )
684
+
685
+ # get pair-wise relative position index for each token inside the window
686
+ coords_range = [torch.arange(w) for w in window_size]
687
+ coords_flatten = torch.stack(torch.meshgrid(coords_range)).flatten(1)
688
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
689
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous()
690
+
691
+ for i, w in enumerate(window_size):
692
+ relative_coords[:, :, i] += w - 1 # shift to start from 0
693
+
694
+ for i, r in enumerate(relative_size[1:]):
695
+ relative_coords[:, :, :i + 1] *= r
696
+
697
+ self.seq_len = np.prod(window_size) + num_cls_tokens
698
+ relative_position_index = torch.zeros((self.seq_len, self.seq_len), dtype=relative_coords.dtype)
699
+ relative_position_index[num_cls_tokens:, num_cls_tokens:] = relative_coords.sum(-1)
700
+
701
+ start = np.prod(relative_size)
702
+ cls2loc = torch.arange(num_cls_tokens).unsqueeze(1) + start
703
+ relative_position_index[:num_cls_tokens, num_cls_tokens:] = cls2loc
704
+ start += num_cls_tokens
705
+
706
+ loc2cls = torch.arange(num_cls_tokens).unsqueeze(0) + start
707
+ relative_position_index[num_cls_tokens:, :num_cls_tokens] = loc2cls
708
+ start += num_cls_tokens
709
+
710
+ cls2cls = torch.arange(num_cls_tokens ** 2).view(num_cls_tokens, num_cls_tokens) + start
711
+ relative_position_index[:num_cls_tokens, :num_cls_tokens] = cls2cls
712
+
713
+ self.register_buffer("relative_position_index", relative_position_index)
714
+
715
+ def forward(self, from_idx=None, to_idx=None):
716
+ """
717
+ from_idx: indices of query tokens (1-dim)
718
+ to_idx: indices of key/value tokens (1-dim, or 2-dim w/ one row per query)
719
+ """
720
+ attn_idx = self.relative_position_index
721
+
722
+ # query indices
723
+ if from_idx is not None:
724
+ attn_idx = attn_idx[from_idx]
725
+
726
+ # key indices
727
+ if to_idx is not None:
728
+ assert to_idx.ndim in (1, 2), "to_idx must be a 1- or 2-dimensional tensor"
729
+ if to_idx.ndim == 1:
730
+ attn_idx = attn_idx[:, to_idx]
731
+ else:
732
+ attn_idx = attn_idx.gather(1, to_idx)
733
+
734
+ rows, cols = attn_idx.shape
735
+ relative_position_bias = self.relative_position_bias_table[attn_idx.flatten()]
736
+ relative_position_bias = rearrange(relative_position_bias, '(i j) h -> h i j', i=rows, j=cols)
737
+ return relative_position_bias.contiguous()
738
+
739
+
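+ # Size of the 3D bias table above, assuming a 4-frame, 14x14 patch window and one
+ # CLS token (illustrative values):
+ #
+ #     window_size, num_cls_tokens = (4, 14, 14), 1
+ #     relative_size = [2 * w - 1 for w in window_size]       # [7, 27, 27]
+ #     num_relative_distance = 7 * 27 * 27 + 2 * 1 + 1 ** 2   # 5106 rows, one learned
+ #     # bias vector (num_attention_heads wide) per relative offset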
740
+ class BeitEncoder(nn.Module):
741
+ def __init__(self, config, window_size=None):
742
+ super().__init__()
743
+ self.config = config
744
+ if config.use_shared_relative_position_bias:
745
+ self.relative_position_bias = BeitRelativePositionBias3D(config, window_size=window_size)
746
+ else:
747
+ self.relative_position_bias = None
748
+
749
+ self._register_token_order(window_size)
750
+
751
+ # stochastic depth decay rule
752
+ dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
753
+
754
+ # node sparsification
755
+ token_keep_rate = [1] * config.num_hidden_layers
756
+ for loc in config.token_drop_loc:
757
+ token_keep_rate[loc] = config.token_keep_rate
758
+
759
+ self.layer = nn.ModuleList(
760
+ [
761
+ BeitLayer(
762
+ config,
763
+ window_size=window_size if config.use_relative_position_bias else None,
764
+ drop_path_rate=dpr[i], token_keep_rate=token_keep_rate[i]
765
+ )
766
+ for i in range(config.num_hidden_layers)
767
+ ]
768
+ )
769
+
770
+ self.gradient_checkpointing = False
771
+
772
+ def _register_token_order(self, shape):
773
+ if self.config.token_3d_order == 'none':
774
+ order = None
775
+ elif self.config.token_3d_order == 'zcurve':
776
+ nbits = max(shape).bit_length()
777
+ coords = list(np.ndindex(*shape))
778
+ order = zCurve.par_interlace(coords, len(shape), nbits)
779
+ order = torch.tensor(np.argsort(order))
780
+ elif self.config.token_3d_order == 'hilbert':
781
+ nbits = max(shape).bit_length()
782
+ coords = list(np.ndindex(*shape))
783
+ order = hilbert.encode(np.stack(coords), len(shape), nbits)
784
+ order = torch.tensor(np.argsort(order))
785
+ else:
786
+ raise NotImplementedError(f"Token ordering {self.config.token_3d_order} not supported")
787
+
788
+ if order is not None:
789
+ self.register_buffer('token_order', order, persistent=False)
790
+ else:
791
+ self.token_order = None
792
+
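+ # The hilbert branch above, traced for a small (2, 4, 4) token grid (illustrative
+ # shape); the resulting permutation is applied to the patch tokens in forward().
+ #
+ #     import numpy as np
+ #     import hilbert
+ #     shape = (2, 4, 4)                     # (frames, H patches, W patches)
+ #     nbits = max(shape).bit_length()       # 3
+ #     coords = list(np.ndindex(*shape))
+ #     codes = hilbert.encode(np.stack(coords), len(shape), nbits)
+ #     order = np.argsort(codes)             # length-32 permutation of token indices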
793
+ def forward(
794
+ self,
795
+ hidden_states,
796
+ head_mask=None,
797
+ output_attentions=False,
798
+ output_hidden_states=False,
799
+ output_token_idx=False,
800
+ return_dict=True,
801
+ ):
802
+ all_hidden_states = () if output_hidden_states else None
803
+ all_self_attentions = () if output_attentions else None
804
+ all_token_idx = () if output_token_idx else None
805
+
806
+ token_idx = self.token_order
807
+ if token_idx is not None:
808
+ cls_states, local_states = hidden_states[:, :self.config.num_cls_tokens], hidden_states[:, self.config.num_cls_tokens:]
809
+ local_states = torch.index_select(local_states, dim=1, index=token_idx)
810
+ hidden_states = torch.cat((cls_states, local_states), 1)
811
+
812
+ for i, layer_module in enumerate(self.layer):
813
+ if output_hidden_states:
814
+ all_hidden_states = all_hidden_states + (hidden_states,)
815
+
816
+ layer_head_mask = head_mask[i] if head_mask is not None else None
817
+
818
+ if self.gradient_checkpointing and self.training:
819
+
820
+ def create_custom_forward(module):
821
+ def custom_forward(*inputs):
822
+ return module(*inputs, output_attentions)
823
+
824
+ return custom_forward
825
+
826
+ layer_outputs = torch.utils.checkpoint.checkpoint(
827
+ create_custom_forward(layer_module),
828
+ hidden_states,
829
+ layer_head_mask,
830
+ )
831
+ else:
832
+ relative_position_bias = (
833
+ self.relative_position_bias() if self.relative_position_bias is not None else None
834
+ )
835
+ layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions, relative_position_bias, token_idx)
836
+
837
+ hidden_states = layer_outputs[0]
838
+
839
+ if layer_module.token_keep_rate < 1:
840
+ token_idx = layer_outputs[-1]
841
+
842
+ if output_token_idx:
843
+ all_token_idx = all_token_idx + (token_idx,)
844
+
845
+ if output_attentions:
846
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
847
+
848
+ if output_hidden_states:
849
+ all_hidden_states = all_hidden_states + (hidden_states,)
850
+
851
+ if not return_dict:
852
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
853
+ return BeitModelOutput(
854
+ last_hidden_state=hidden_states,
855
+ hidden_states=all_hidden_states,
856
+ attentions=all_self_attentions,
857
+ token_idx=all_token_idx
858
+ )
859
+
860
+
861
+ class BeitPreTrainedModel(PreTrainedModel):
862
+ """
863
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
864
+ models.
865
+ """
866
+
867
+ config_class = BeitConfig
868
+ base_model_prefix = "beit"
869
+ supports_gradient_checkpointing = True
870
+
871
+ def _init_weights(self, module):
872
+ """Initialize the weights"""
873
+ if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
874
+ # Slightly different from the TF version which uses truncated_normal for initialization
875
+ # cf https://github.com/pytorch/pytorch/pull/5617
876
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
877
+ if module.bias is not None:
878
+ module.bias.data.zero_()
879
+ elif isinstance(module, nn.Embedding):
880
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
881
+ if module.padding_idx is not None:
882
+ module.weight.data[module.padding_idx].zero_()
883
+ elif isinstance(module, nn.LayerNorm):
884
+ module.bias.data.zero_()
885
+ module.weight.data.fill_(1.0)
886
+
887
+ def _set_gradient_checkpointing(self, module, value=False):
888
+ if isinstance(module, BeitEncoder):
889
+ module.gradient_checkpointing = value
890
+
891
+
892
+ BEIT_START_DOCSTRING = r"""
893
+ This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ subclass. Use
894
+ it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
895
+ behavior.
896
+
897
+ Parameters:
898
+ config (:class:`~transformers.BeitConfig`): Model configuration class with all the parameters of the model.
899
+ Initializing with a config file does not load the weights associated with the model, only the
900
+ configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
901
+ weights.
902
+ """
903
+
904
+ BEIT_INPUTS_DOCSTRING = r"""
905
+ Args:
906
+ pixel_values (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_channels, height, width)`):
907
+ Pixel values. Pixel values can be obtained using :class:`~transformers.BeitFeatureExtractor`. See
908
+ :meth:`transformers.BeitFeatureExtractor.__call__` for details.
909
+
910
+ head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
911
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
912
+
913
+ - 1 indicates the head is **not masked**,
914
+ - 0 indicates the head is **masked**.
915
+
916
+ output_attentions (:obj:`bool`, `optional`):
917
+ Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
918
+ tensors for more detail.
919
+ output_hidden_states (:obj:`bool`, `optional`):
920
+ Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
921
+ more detail.
922
+ return_dict (:obj:`bool`, `optional`):
923
+ Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
924
+ """
925
+
926
+
927
+ @add_start_docstrings(
928
+ "The bare Beit Model transformer outputting raw hidden-states without any specific head on top.",
929
+ BEIT_START_DOCSTRING,
930
+ )
931
+ class BeitModel(BeitPreTrainedModel):
932
+ def __init__(self, config, add_pooling_layer=True, num_frames=None):
933
+ super().__init__(config)
934
+ self.config = config
935
+
936
+ self.embeddings = BeitEmbeddings(config)
937
+ self.window_size = self.embeddings.patch_embeddings.patch_shape
938
+ if num_frames is not None:
939
+ self.window_size = (num_frames,) + self.window_size
940
+ self.encoder = BeitEncoder(config, window_size=self.window_size)
941
+
942
+ self.layernorm = (
943
+ nn.Identity() if config.use_mean_pooling else nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
944
+ )
945
+ self.pooler = BeitPooler(config) if add_pooling_layer else None
946
+
947
+ # Initialize weights and apply final processing
948
+ self.post_init()
949
+
950
+ def get_input_embeddings(self):
951
+ return self.embeddings.patch_embeddings
952
+
953
+ def _prune_heads(self, heads_to_prune):
954
+ """
955
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
956
+ class PreTrainedModel
957
+ """
958
+ for layer, heads in heads_to_prune.items():
959
+ self.encoder.layer[layer].attention.prune_heads(heads)
960
+
961
+ @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING)
962
+ @replace_return_docstrings(output_type=BeitModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
963
+ def forward(
964
+ self,
965
+ pixel_values=None,
966
+ bool_masked_pos=None,
967
+ head_mask=None,
968
+ output_attentions=None,
969
+ output_hidden_states=None,
970
+ output_token_idx=None,
971
+ return_dict=None,
972
+ ):
973
+ r"""
974
+ Returns:
975
+
976
+ Examples::
977
+
978
+ >>> from transformers import BeitFeatureExtractor, BeitModel
979
+ >>> from PIL import Image
980
+ >>> import requests
981
+
982
+ >>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
983
+ >>> image = Image.open(requests.get(url, stream=True).raw)
984
+
985
+ >>> feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-patch16-224-pt22k-ft22k')
986
+ >>> model = BeitModel.from_pretrained('microsoft/beit-base-patch16-224-pt22k-ft22k')
987
+
988
+ >>> inputs = feature_extractor(images=image, return_tensors="pt")
989
+ >>> outputs = model(**inputs)
990
+ >>> last_hidden_states = outputs.last_hidden_state
991
+ """
992
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
993
+ output_hidden_states = (
994
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
995
+ )
996
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
997
+
998
+ if pixel_values is None:
999
+ raise ValueError("You have to specify pixel_values")
1000
+
1001
+ # Prepare head mask if needed
1002
+ # 1.0 in head_mask indicate we keep the head
1003
+ # attention_probs has shape bsz x n_heads x N x N
1004
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
1005
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
1006
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
1007
+
1008
+ embedding_output = self.embeddings(pixel_values, bool_masked_pos)
1009
+
1010
+ encoder_outputs = self.encoder(
1011
+ embedding_output,
1012
+ head_mask=head_mask,
1013
+ output_attentions=output_attentions,
1014
+ output_hidden_states=output_hidden_states,
1015
+ output_token_idx=output_token_idx,
1016
+ return_dict=return_dict,
1017
+ )
1018
+ sequence_output = encoder_outputs[0]
1019
+ sequence_output = self.layernorm(sequence_output)
1020
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
1021
+
1022
+ if not return_dict:
1023
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
1024
+
1025
+ return BeitModelOutputWithPooling(
1026
+ last_hidden_state=sequence_output,
1027
+ pooler_output=pooled_output,
1028
+ hidden_states=encoder_outputs.hidden_states,
1029
+ attentions=encoder_outputs.attentions,
1030
+ token_idx=encoder_outputs.token_idx,
1031
+ )
1032
+
1033
+
1034
+ class BeitPooler(nn.Module):
1035
+ def __init__(self, config):
1036
+ super().__init__()
1037
+ self.layernorm = (
1038
+ nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) if config.use_mean_pooling else None
1039
+ )
1040
+
1041
+ def forward(self, hidden_states):
1042
+ if self.layernorm is not None:
1043
+ # Mean pool the final hidden states of the patch tokens
1044
+ patch_tokens = hidden_states[:, 1:, :]
1045
+ pooled_output = self.layernorm(patch_tokens.mean(1))
1046
+ else:
1047
+ # Pool by simply taking the final hidden state of the [CLS] token
1048
+ pooled_output = hidden_states[:, 0]
1049
+
1050
+ return pooled_output
1051
+
1052
+
1053
+ @add_start_docstrings(
1054
+ "Beit Model transformer with a 'language' modeling head on top (to predict visual tokens).",
1055
+ BEIT_START_DOCSTRING,
1056
+ )
1057
+ class BeitForMaskedImageModeling(BeitPreTrainedModel):
1058
+ def __init__(self, config):
1059
+ super().__init__(config)
1060
+
1061
+ self.num_labels = config.num_labels
1062
+ self.beit = BeitModel(config, add_pooling_layer=False)
1063
+
1064
+ # Classifier head
1065
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
1066
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
1067
+
1068
+ # Initialize weights and apply final processing
1069
+ self.post_init()
1070
+
1071
+ @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING)
1072
+ @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
1073
+ def forward(
1074
+ self,
1075
+ pixel_values=None,
1076
+ bool_masked_pos=None,
1077
+ head_mask=None,
1078
+ labels=None,
1079
+ output_attentions=None,
1080
+ output_hidden_states=None,
1081
+ return_dict=None,
1082
+ ):
1083
+ r"""
1084
+ bool_masked_pos (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, num_patches)`):
1085
+ Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
1086
+
1087
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1088
+ Labels for computing the image classification/regression loss. Indices should be in :obj:`[0, ...,
1089
+ config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
1090
+ If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1091
+
1092
+ Returns:
1093
+
1094
+ Examples::
1095
+
1096
+ >>> from transformers import BeitFeatureExtractor, BeitForMaskedImageModeling
1097
+ >>> from PIL import Image
1098
+ >>> import requests
1099
+
1100
+ >>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
1101
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1102
+
1103
+ >>> feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-patch16-224-pt22k')
1104
+ >>> model = BeitForMaskedImageModeling.from_pretrained('microsoft/beit-base-patch16-224-pt22k')
1105
+
1106
+ >>> inputs = feature_extractor(images=image, return_tensors="pt")
1107
+ >>> outputs = model(**inputs)
1108
+ >>> logits = outputs.logits
1109
+ """
1110
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1111
+
1112
+ outputs = self.beit(
1113
+ pixel_values,
1114
+ bool_masked_pos=bool_masked_pos,
1115
+ head_mask=head_mask,
1116
+ output_attentions=output_attentions,
1117
+ output_hidden_states=output_hidden_states,
1118
+ return_dict=return_dict,
1119
+ )
1120
+
1121
+ sequence_output = outputs[0]
1122
+ sequence_output = self.layernorm(sequence_output)
1123
+ prediction_scores = self.lm_head(sequence_output[:, 1:])
1124
+
1125
+ masked_lm_loss = None
1126
+ if labels is not None:
1127
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1128
+ masked_lm_loss = loss_fct(prediction_scores[bool_masked_pos], labels)
1129
+
1130
+ if not return_dict:
1131
+ output = (prediction_scores,) + outputs[2:]
1132
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1133
+
1134
+ return MaskedLMOutput(
1135
+ loss=masked_lm_loss,
1136
+ logits=prediction_scores,
1137
+ hidden_states=outputs.hidden_states,
1138
+ attentions=outputs.attentions,
1139
+ )
1140
+
1141
+
1142
+ @add_start_docstrings(
1143
+ """
1144
+ Beit Model transformer with an image classification head on top (a linear layer on top of the average of the final
1145
+ hidden states of the patch tokens) e.g. for ImageNet.
1146
+ """,
1147
+ BEIT_START_DOCSTRING,
1148
+ )
1149
+ class BeitForImageClassification(BeitPreTrainedModel):
1150
+ def __init__(self, config):
1151
+ super().__init__(config)
1152
+
1153
+ self.num_labels = config.num_labels
1154
+ self.beit = BeitModel(config, add_pooling_layer=True)
1155
+
1156
+ # Classifier head
1157
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
1158
+
1159
+ # Initialize weights and apply final processing
1160
+ self.post_init()
1161
+
1162
+ @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING)
1163
+ @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
1164
+ def forward(
1165
+ self,
1166
+ pixel_values=None,
1167
+ head_mask=None,
1168
+ labels=None,
1169
+ output_attentions=None,
1170
+ output_hidden_states=None,
1171
+ return_dict=None,
1172
+ ):
1173
+ r"""
1174
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1175
+ Labels for computing the image classification/regression loss. Indices should be in :obj:`[0, ...,
1176
+ config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
1177
+ If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1178
+
1179
+ Returns:
1180
+
1181
+ Examples::
1182
+
1183
+ >>> from transformers import BeitFeatureExtractor, BeitForImageClassification
1184
+ >>> from PIL import Image
1185
+ >>> import requests
1186
+
1187
+ >>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
1188
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1189
+
1190
+ >>> feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-patch16-224')
1191
+ >>> model = BeitForImageClassification.from_pretrained('microsoft/beit-base-patch16-224')
1192
+
1193
+ >>> inputs = feature_extractor(images=image, return_tensors="pt")
1194
+ >>> outputs = model(**inputs)
1195
+ >>> logits = outputs.logits
1196
+ >>> # model predicts one of the 1000 ImageNet classes
1197
+ >>> predicted_class_idx = logits.argmax(-1).item()
1198
+ >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
1199
+ """
1200
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1201
+
1202
+ outputs = self.beit(
1203
+ pixel_values,
1204
+ head_mask=head_mask,
1205
+ output_attentions=output_attentions,
1206
+ output_hidden_states=output_hidden_states,
1207
+ return_dict=return_dict,
1208
+ )
1209
+
1210
+ pooled_output = outputs.pooler_output if return_dict else outputs[1]
1211
+
1212
+ logits = self.classifier(pooled_output)
1213
+
1214
+ loss = None
1215
+ if labels is not None:
1216
+ if self.num_labels == 1:
1217
+ # We are doing regression
1218
+ loss_fct = MSELoss()
1219
+ loss = loss_fct(logits.view(-1), labels.view(-1))
1220
+ else:
1221
+ loss_fct = CrossEntropyLoss()
1222
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1223
+
1224
+ if not return_dict:
1225
+ output = (logits,) + outputs[2:]
1226
+ return ((loss,) + output) if loss is not None else output
1227
+
1228
+ return SequenceClassifierOutput(
1229
+ loss=loss,
1230
+ logits=logits,
1231
+ hidden_states=outputs.hidden_states,
1232
+ attentions=outputs.attentions,
1233
+ )
1234
+
1235
+
1236
+ class BeitConvModule(nn.Module):
1237
+ """
1238
+ A convolutional block that bundles conv/norm/activation layers. This block simplifies the usage of convolution
1239
+ layers, which are commonly used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
1240
+
1241
+ Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
1242
+ """
1243
+
1244
+ def __init__(self, in_channels, out_channels, kernel_size, padding=0, bias=False, dilation=1):
1245
+ super().__init__()
1246
+ self.conv = nn.Conv2d(
1247
+ in_channels=in_channels,
1248
+ out_channels=out_channels,
1249
+ kernel_size=kernel_size,
1250
+ padding=padding,
1251
+ bias=bias,
1252
+ dilation=dilation,
1253
+ )
1254
+ self.bn = nn.BatchNorm2d(out_channels)
1255
+ self.activation = nn.ReLU()
1256
+
1257
+ def forward(self, input):
1258
+ output = self.conv(input)
1259
+ output = self.bn(output)
1260
+ output = self.activation(output)
1261
+
1262
+ return output
1263
+
1264
+
1265
+ class BeitPyramidPoolingModule(nn.ModuleList):
1266
+ """
1267
+ Pyramid Pooling Module (PPM) used in PSPNet.
1268
+
1269
+ Args:
1270
+ pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
1271
+ Module.
1272
+ in_channels (int): Input channels.
1273
+ channels (int): Channels after modules, before conv_seg.
1274
+ align_corners (bool): align_corners argument of F.interpolate.
1275
+
1276
+ Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
1277
+ """
1278
+
1279
+ def __init__(self, pool_scales, in_channels, channels, align_corners):
1280
+ super().__init__()
1281
+ self.pool_scales = pool_scales
1282
+ self.align_corners = align_corners
1283
+ self.in_channels = in_channels
1284
+ self.channels = channels
1285
+ for pool_scale in pool_scales:
1286
+ self.append(
1287
+ nn.Sequential(
1288
+ nn.AdaptiveAvgPool2d(pool_scale),
1289
+ BeitConvModule(self.in_channels, self.channels, kernel_size=1),
1290
+ )
1291
+ )
1292
+
1293
+ def forward(self, x):
1294
+ ppm_outs = []
1295
+ for ppm in self:
1296
+ ppm_out = ppm(x)
1297
+ upsampled_ppm_out = nn.functional.interpolate(
1298
+ ppm_out, size=x.size()[2:], mode="bilinear", align_corners=self.align_corners
1299
+ )
1300
+ ppm_outs.append(upsampled_ppm_out)
1301
+ return ppm_outs
1302
+
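
A quick aside (not part of the commit): the pyramid pooling loop above boils down to "pool to a coarse grid, project, upsample back". A minimal standalone sketch with dummy tensors and assumed shapes:

```python
# Standalone sketch (assumed shapes, not from the repo): each pyramid scale pools the
# feature map to a coarse grid, then upsamples back to the input resolution.
import torch
import torch.nn as nn

x = torch.randn(1, 768, 14, 14)            # [B, C, H, W] patch features
pool_scales = (1, 2, 3, 6)
ppm_outs = []
for scale in pool_scales:
    pooled = nn.functional.adaptive_avg_pool2d(x, scale)  # [1, 768, scale, scale]
    # the real module applies a 1x1 BeitConvModule here; omitted to keep the sketch short
    up = nn.functional.interpolate(pooled, size=x.shape[2:], mode="bilinear", align_corners=False)
    ppm_outs.append(up)                                    # each is [1, 768, 14, 14]
print(torch.cat([x] + ppm_outs, dim=1).shape)              # concatenated for the PSP bottleneck
```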
1303
+
1304
+ class BeitUperHead(nn.Module):
1305
+ """
1306
+ Unified Perceptual Parsing for Scene Understanding. This head is the implementation of `UPerNet
1307
+ <https://arxiv.org/abs/1807.10221>`_.
1308
+
1309
+ Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
1310
+ """
1311
+
1312
+ def __init__(self, config):
1313
+ super().__init__()
1314
+
1315
+ self.pool_scales = config.pool_scales # e.g. (1, 2, 3, 6)
1316
+ self.in_channels = [config.hidden_size] * 4 # e.g. [768, 768, 768, 768]
1317
+ self.channels = config.hidden_size
1318
+ self.align_corners = False
1319
+ self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)
1320
+
1321
+ # PSP Module
1322
+ self.psp_modules = BeitPyramidPoolingModule(
1323
+ self.pool_scales,
1324
+ self.in_channels[-1],
1325
+ self.channels,
1326
+ align_corners=self.align_corners,
1327
+ )
1328
+ self.bottleneck = BeitConvModule(
1329
+ self.in_channels[-1] + len(self.pool_scales) * self.channels,
1330
+ self.channels,
1331
+ kernel_size=3,
1332
+ padding=1,
1333
+ )
1334
+ # FPN Module
1335
+ self.lateral_convs = nn.ModuleList()
1336
+ self.fpn_convs = nn.ModuleList()
1337
+ for in_channels in self.in_channels[:-1]: # skip the top layer
1338
+ l_conv = BeitConvModule(in_channels, self.channels, kernel_size=1)
1339
+ fpn_conv = BeitConvModule(self.channels, self.channels, kernel_size=3, padding=1)
1340
+ self.lateral_convs.append(l_conv)
1341
+ self.fpn_convs.append(fpn_conv)
1342
+
1343
+ self.fpn_bottleneck = BeitConvModule(
1344
+ len(self.in_channels) * self.channels,
1345
+ self.channels,
1346
+ kernel_size=3,
1347
+ padding=1,
1348
+ )
1349
+
1350
+ def psp_forward(self, inputs):
1351
+ x = inputs[-1]
1352
+ psp_outs = [x]
1353
+ psp_outs.extend(self.psp_modules(x))
1354
+ psp_outs = torch.cat(psp_outs, dim=1)
1355
+ output = self.bottleneck(psp_outs)
1356
+
1357
+ return output
1358
+
1359
+ def forward(self, encoder_hidden_states):
1360
+ # build laterals
1361
+ laterals = [lateral_conv(encoder_hidden_states[i]) for i, lateral_conv in enumerate(self.lateral_convs)]
1362
+
1363
+ laterals.append(self.psp_forward(encoder_hidden_states))
1364
+
1365
+ # build top-down path
1366
+ used_backbone_levels = len(laterals)
1367
+ for i in range(used_backbone_levels - 1, 0, -1):
1368
+ prev_shape = laterals[i - 1].shape[2:]
1369
+ laterals[i - 1] = laterals[i - 1] + nn.functional.interpolate(
1370
+ laterals[i], size=prev_shape, mode="bilinear", align_corners=self.align_corners
1371
+ )
1372
+
1373
+ # build outputs
1374
+ fpn_outs = [self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1)]
1375
+ # append psp feature
1376
+ fpn_outs.append(laterals[-1])
1377
+
1378
+ for i in range(used_backbone_levels - 1, 0, -1):
1379
+ fpn_outs[i] = nn.functional.interpolate(
1380
+ fpn_outs[i], size=fpn_outs[0].shape[2:], mode="bilinear", align_corners=self.align_corners
1381
+ )
1382
+ fpn_outs = torch.cat(fpn_outs, dim=1)
1383
+ output = self.fpn_bottleneck(fpn_outs)
1384
+ output = self.classifier(output)
1385
+
1386
+ return output
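
For orientation (not from the repo): the top-down loop in the forward above just upsamples each coarser lateral and adds it to the next finer one. A toy illustration with assumed shapes:

```python
# Assumed toy shapes; shows only the top-down add-and-upsample step of the FPN fusion.
import torch
import torch.nn.functional as F

laterals = [torch.randn(1, 8, s, s) for s in (32, 16, 8, 4)]   # fine -> coarse
for i in range(len(laterals) - 1, 0, -1):
    laterals[i - 1] = laterals[i - 1] + F.interpolate(
        laterals[i], size=laterals[i - 1].shape[2:], mode="bilinear", align_corners=False
    )
print([l.shape[-1] for l in laterals])  # [32, 16, 8, 4]: spatial sizes unchanged, contents fused
```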
1387
+
1388
+
1389
+ class BeitFCNHead(nn.Module):
1390
+ """
1391
+ Fully Convolutional Networks for Semantic Segmentation. This head is the implementation of `FCNNet
1392
+ <https://arxiv.org/abs/1411.4038>`_.
1393
+
1394
+ Args:
1395
+ config (BeitConfig): Configuration.
1396
+ in_channels (int): Number of input channels; taken from ``config.hidden_size``.
1397
+ kernel_size (int): The kernel size for convs in the head. Default: 3.
1398
+ dilation (int): The dilation rate for convs in the head. Default: 1.
1399
+
1400
+
1401
+ Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
1402
+ """
1403
+
1404
+ def __init__(self, config, in_index=2, kernel_size=3, dilation=1):
1405
+ super().__init__()
1406
+ self.in_channels = config.hidden_size
1407
+ self.channels = config.auxiliary_channels
1408
+ self.num_convs = config.auxiliary_num_convs
1409
+ self.concat_input = config.auxiliary_concat_input
1410
+ self.in_index = in_index
1411
+
1412
+ conv_padding = (kernel_size // 2) * dilation
1413
+ convs = []
1414
+ convs.append(
1415
+ BeitConvModule(
1416
+ self.in_channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation
1417
+ )
1418
+ )
1419
+ for i in range(self.num_convs - 1):
1420
+ convs.append(
1421
+ BeitConvModule(
1422
+ self.channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation
1423
+ )
1424
+ )
1425
+ if self.num_convs == 0:
1426
+ self.convs = nn.Identity()
1427
+ else:
1428
+ self.convs = nn.Sequential(*convs)
1429
+ if self.concat_input:
1430
+ self.conv_cat = BeitConvModule(
1431
+ self.in_channels + self.channels, self.channels, kernel_size=kernel_size, padding=kernel_size // 2
1432
+ )
1433
+
1434
+ self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)
1435
+
1436
+ def forward(self, encoder_hidden_states):
1437
+ # just take the relevant feature maps
1438
+ hidden_states = encoder_hidden_states[self.in_index]
1439
+ output = self.convs(hidden_states)
1440
+ if self.concat_input:
1441
+ output = self.conv_cat(torch.cat([hidden_states, output], dim=1))
1442
+ output = self.classifier(output)
1443
+ return output
1444
+
1445
+
1446
+ @add_start_docstrings(
1447
+ """
1448
+ Beit Model transformer with a semantic segmentation head on top e.g. for ADE20k, CityScapes.
1449
+ """,
1450
+ BEIT_START_DOCSTRING,
1451
+ )
1452
+ class BeitForSemanticSegmentation(BeitPreTrainedModel):
1453
+ def __init__(self, config):
1454
+ super().__init__(config)
1455
+
1456
+ self.num_labels = config.num_labels
1457
+ self.beit = BeitModel(config, add_pooling_layer=False)
1458
+
1459
+ # FPNs
1460
+ self.fpn1 = nn.Sequential(
1461
+ nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
1462
+ nn.BatchNorm2d(config.hidden_size),
1463
+ nn.GELU(),
1464
+ nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
1465
+ )
1466
+ self.fpn2 = nn.Sequential(
1467
+ nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
1468
+ )
1469
+ self.fpn3 = nn.Identity()
1470
+ self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
1471
+
1472
+ # Semantic segmentation head(s)
1473
+ self.decode_head = BeitUperHead(config)
1474
+ self.auxiliary_head = BeitFCNHead(config) if config.use_auxiliary_head else None
1475
+
1476
+ # Initialize weights and apply final processing
1477
+ self.post_init()
1478
+
1479
+ def compute_loss(self, logits, auxiliary_logits, labels):
1480
+ # upsample logits to the images' original size
1481
+ upsampled_logits = nn.functional.interpolate(
1482
+ logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
1483
+ )
1484
+ if auxiliary_logits is not None:
1485
+ upsampled_auxiliary_logits = nn.functional.interpolate(
1486
+ auxiliary_logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
1487
+ )
1488
+ # compute weighted loss (auxiliary term only when an auxiliary head is present)
1489
+ loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
1490
+ loss = loss_fct(upsampled_logits, labels)
1491
+ if auxiliary_logits is not None:
1492
+ loss = loss + self.config.auxiliary_loss_weight * loss_fct(upsampled_auxiliary_logits, labels)
1493
+
1494
+ return loss
1495
+
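
A short example (dummy numbers, not from the repo) of what compute_loss does: upsample the quarter-resolution logits to the label size and apply cross-entropy with an ignore index.

```python
# Toy shapes only; the real values (num_labels, ignore_index) come from BeitConfig.
import torch
import torch.nn.functional as F

logits = torch.randn(2, 150, 64, 64)              # [B, num_labels, H/4, W/4]
labels = torch.randint(0, 150, (2, 256, 256))     # ground-truth maps at full resolution
upsampled = F.interpolate(logits, size=labels.shape[-2:], mode="bilinear", align_corners=False)
loss = F.cross_entropy(upsampled, labels, ignore_index=255)
print(loss.item())
```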
1496
+ @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING)
1497
+ @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
1498
+ def forward(
1499
+ self,
1500
+ pixel_values=None,
1501
+ head_mask=None,
1502
+ labels=None,
1503
+ output_attentions=None,
1504
+ output_hidden_states=None,
1505
+ return_dict=None,
1506
+ ):
1507
+ r"""
1508
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, height, width)`, `optional`):
1509
+ Ground truth semantic segmentation maps for computing the loss. Indices should be in :obj:`[0, ...,
1510
+ config.num_labels - 1]`. If :obj:`config.num_labels > 1`, a classification loss is computed
1511
+ (Cross-Entropy).
1512
+
1513
+ Returns:
1514
+
1515
+ Examples::
1516
+
1517
+ >>> from transformers import BeitFeatureExtractor, BeitForSemanticSegmentation
1518
+ >>> from PIL import Image
1519
+ >>> import requests
1520
+
1521
+ >>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
1522
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1523
+
1524
+ >>> feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-finetuned-ade-640-640')
1525
+ >>> model = BeitForSemanticSegmentation.from_pretrained('microsoft/beit-base-finetuned-ade-640-640')
1526
+
1527
+ >>> inputs = feature_extractor(images=image, return_tensors="pt")
1528
+ >>> outputs = model(**inputs)
1529
+ >>> # logits are of shape (batch_size, num_labels, height/4, width/4)
1530
+ >>> logits = outputs.logits
1531
+ """
1532
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1533
+ output_hidden_states = (
1534
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1535
+ )
1536
+
1537
+ outputs = self.beit(
1538
+ pixel_values,
1539
+ head_mask=head_mask,
1540
+ output_attentions=output_attentions,
1541
+ output_hidden_states=True, # we need the intermediate hidden states
1542
+ return_dict=return_dict,
1543
+ )
1544
+
1545
+ encoder_hidden_states = outputs.hidden_states if return_dict else outputs[2]
1546
+
1547
+ # only keep certain features, and reshape
1548
+ # note that we do +1 as the encoder_hidden_states also includes the initial embeddings
1549
+ features = [feature for idx, feature in enumerate(encoder_hidden_states) if idx + 1 in self.config.out_indices]
1550
+ batch_size = pixel_values.shape[0]
1551
+ patch_resolution = self.config.image_size // self.config.patch_size
1552
+ features = [
1553
+ x[:, 1:, :].permute(0, 2, 1).reshape(batch_size, -1, patch_resolution, patch_resolution) for x in features
1554
+ ]
1555
+
1556
+ # apply FPNs
1557
+ ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
1558
+ for i in range(len(features)):
1559
+ features[i] = ops[i](features[i])
1560
+
1561
+ logits = self.decode_head(features)
1562
+ auxiliary_logits = None
1563
+ if self.auxiliary_head is not None:
1564
+ auxiliary_logits = self.auxiliary_head(features)
1565
+
1566
+ loss = None
1567
+ if labels is not None:
1568
+ if self.config.num_labels == 1:
1569
+ raise ValueError("The number of labels should be greater than one")
1570
+ else:
1571
+ loss = self.compute_loss(logits, auxiliary_logits, labels)
1572
+
1573
+ if not return_dict:
1574
+ if output_hidden_states:
1575
+ output = (logits,) + outputs[2:]
1576
+ else:
1577
+ output = (logits,) + outputs[3:]
1578
+ return ((loss,) + output) if loss is not None else output
1579
+
1580
+ return SequenceClassifierOutput(
1581
+ loss=loss,
1582
+ logits=logits,
1583
+ hidden_states=outputs.hidden_states if output_hidden_states else None,
1584
+ attentions=outputs.attentions,
1585
+ )
svitt/sparse_xbert.py ADDED
@@ -0,0 +1,2039 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch BERT model. """
17
+
18
+ import math
19
+ import os
20
+ import warnings
21
+ from dataclasses import dataclass
22
+ from typing import Optional, Tuple
23
+
24
+ import torch
25
+ from torch import Tensor, device, nn
26
+ import torch.utils.checkpoint
27
+ from torch import nn
28
+ from torch.nn import CrossEntropyLoss, MSELoss
29
+ import torch.nn.functional as F
30
+
31
+ from transformers.activations import ACT2FN
32
+ from transformers.file_utils import (
33
+ ModelOutput,
34
+ add_start_docstrings,
35
+ add_start_docstrings_to_model_forward,
36
+ replace_return_docstrings,
37
+ )
38
+ from transformers.modeling_outputs import (
39
+ BaseModelOutputWithPastAndCrossAttentions,
40
+ BaseModelOutputWithPoolingAndCrossAttentions,
41
+ CausalLMOutputWithCrossAttentions,
42
+ MaskedLMOutput,
43
+ MultipleChoiceModelOutput,
44
+ NextSentencePredictorOutput,
45
+ QuestionAnsweringModelOutput,
46
+ SequenceClassifierOutput,
47
+ TokenClassifierOutput,
48
+ )
49
+ from transformers.modeling_utils import (
50
+ PreTrainedModel,
51
+ apply_chunking_to_forward,
52
+ find_pruneable_heads_and_indices,
53
+ prune_linear_layer,
54
+ )
55
+ from svitt.sparse_config import BertConfig
56
+
57
+ import transformers
58
+ transformers.logging.set_verbosity_error()
59
+
60
+
61
+ _CONFIG_FOR_DOC = "BertConfig"
62
+ _TOKENIZER_FOR_DOC = "BertTokenizer"
63
+
64
+ BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
65
+ "bert-base-uncased",
66
+ "bert-large-uncased",
67
+ "bert-base-cased",
68
+ "bert-large-cased",
69
+ "bert-base-multilingual-uncased",
70
+ "bert-base-multilingual-cased",
71
+ "bert-base-chinese",
72
+ "bert-base-german-cased",
73
+ "bert-large-uncased-whole-word-masking",
74
+ "bert-large-cased-whole-word-masking",
75
+ "bert-large-uncased-whole-word-masking-finetuned-squad",
76
+ "bert-large-cased-whole-word-masking-finetuned-squad",
77
+ "bert-base-cased-finetuned-mrpc",
78
+ "bert-base-german-dbmdz-cased",
79
+ "bert-base-german-dbmdz-uncased",
80
+ "cl-tohoku/bert-base-japanese",
81
+ "cl-tohoku/bert-base-japanese-whole-word-masking",
82
+ "cl-tohoku/bert-base-japanese-char",
83
+ "cl-tohoku/bert-base-japanese-char-whole-word-masking",
84
+ "TurkuNLP/bert-base-finnish-cased-v1",
85
+ "TurkuNLP/bert-base-finnish-uncased-v1",
86
+ "wietsedv/bert-base-dutch-cased",
87
+ # See all BERT models at https://huggingface.co/models?filter=bert
88
+ ]
89
+
90
+
91
+ @dataclass
92
+ class BertModelOutputWithPastAndCrossAttentions(BaseModelOutputWithPastAndCrossAttentions):
93
+ token_idx: Optional[Tuple[torch.LongTensor]] = None
94
+
95
+
96
+ @dataclass
97
+ class BertModelOutputWithPoolingAndCrossAttentions(BaseModelOutputWithPoolingAndCrossAttentions):
98
+ token_idx: Optional[Tuple[torch.LongTensor]] = None
99
+
100
+
101
+ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
102
+ """Load tf checkpoints in a pytorch model."""
103
+ try:
104
+ import re
105
+
106
+ import numpy as np
107
+ import tensorflow as tf
108
+ except ImportError:
109
+ print(
110
+ "Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
111
+ "https://www.tensorflow.org/install/ for installation instructions."
112
+ )
113
+ raise
114
+ tf_path = os.path.abspath(tf_checkpoint_path)
115
+ # Load weights from TF model
116
+ init_vars = tf.train.list_variables(tf_path)
117
+ names = []
118
+ arrays = []
119
+ for name, shape in init_vars:
120
+ array = tf.train.load_variable(tf_path, name)
121
+ names.append(name)
122
+ arrays.append(array)
123
+
124
+ for name, array in zip(names, arrays):
125
+ name = name.split("/")
126
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
127
+ # which are not required when using a pretrained model
128
+ if any(
129
+ n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer",
130
+ "AdamWeightDecayOptimizer_1", "global_step"]
131
+ for n in name
132
+ ):
133
+ continue
134
+ pointer = model
135
+ for m_name in name:
136
+ if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
137
+ scope_names = re.split(r"_(\d+)", m_name)
138
+ else:
139
+ scope_names = [m_name]
140
+ if scope_names[0] == "kernel" or scope_names[0] == "gamma":
141
+ pointer = getattr(pointer, "weight")
142
+ elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
143
+ pointer = getattr(pointer, "bias")
144
+ elif scope_names[0] == "output_weights":
145
+ pointer = getattr(pointer, "weight")
146
+ elif scope_names[0] == "squad":
147
+ pointer = getattr(pointer, "classifier")
148
+ else:
149
+ try:
150
+ pointer = getattr(pointer, scope_names[0])
151
+ except AttributeError:
152
+ continue
153
+ if len(scope_names) >= 2:
154
+ num = int(scope_names[1])
155
+ pointer = pointer[num]
156
+ if m_name[-11:] == "_embeddings":
157
+ pointer = getattr(pointer, "weight")
158
+ elif m_name == "kernel":
159
+ array = np.transpose(array)
160
+ try:
161
+ assert (
162
+ pointer.shape == array.shape
163
+ ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
164
+ except AssertionError as e:
165
+ e.args += (pointer.shape, array.shape)
166
+ raise
167
+ pointer.data = torch.from_numpy(array)
168
+ return model
169
+
170
+
171
+ class BertEmbeddings(nn.Module):
172
+ """Construct the embeddings from word, position and token_type embeddings."""
173
+
174
+ def __init__(self, config):
175
+ super().__init__()
176
+ self.word_embeddings = nn.Embedding(
177
+ config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
178
+ self.position_embeddings = nn.Embedding(
179
+ config.max_position_embeddings, config.hidden_size)
180
+ self.token_type_embeddings = nn.Embedding(
181
+ config.type_vocab_size, config.hidden_size)
182
+
183
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
184
+ # any TensorFlow checkpoint file
185
+ self.LayerNorm = nn.LayerNorm(
186
+ config.hidden_size, eps=config.layer_norm_eps)
187
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
188
+
189
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
190
+ self.register_buffer("position_ids", torch.arange(
191
+ config.max_position_embeddings).expand((1, -1)))
192
+ self.position_embedding_type = getattr(
193
+ config, "position_embedding_type", "absolute")
194
+
195
+ self.config = config
196
+
197
+ def forward(
198
+ self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
199
+ ):
200
+ if input_ids is not None:
201
+ input_shape = input_ids.size()
202
+ else:
203
+ input_shape = inputs_embeds.size()[:-1]
204
+
205
+ seq_length = input_shape[1]
206
+
207
+ if position_ids is None:
208
+ position_ids = self.position_ids[:,
209
+ past_key_values_length: seq_length + past_key_values_length]
210
+
211
+ if token_type_ids is None:
212
+ token_type_ids = torch.zeros(
213
+ input_shape, dtype=torch.long, device=self.position_ids.device)
214
+
215
+ if inputs_embeds is None:
216
+ inputs_embeds = self.word_embeddings(input_ids)
217
+
218
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
219
+
220
+ embeddings = inputs_embeds + token_type_embeddings
221
+ if self.position_embedding_type == "absolute":
222
+ position_embeddings = self.position_embeddings(position_ids)
223
+ embeddings += position_embeddings
224
+ embeddings = self.LayerNorm(embeddings)
225
+ embeddings = self.dropout(embeddings)
226
+ return embeddings
227
+
228
+
229
+ class BertSelfAttention(nn.Module):
230
+ def __init__(self, config, is_cross_attention):
231
+ super().__init__()
232
+ self.config = config
233
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
234
+ raise ValueError(
235
+ "The hidden size (%d) is not a multiple of the number of attention "
236
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
237
+ )
238
+
239
+ self.num_attention_heads = config.num_attention_heads
240
+ self.attention_head_size = int(
241
+ config.hidden_size / config.num_attention_heads)
242
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
243
+
244
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
245
+ if is_cross_attention:
246
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
247
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
248
+ else:
249
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
250
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
251
+
252
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
253
+ self.position_embedding_type = getattr(
254
+ config, "position_embedding_type", "absolute")
255
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
256
+ self.max_position_embeddings = config.max_position_embeddings
257
+ self.distance_embedding = nn.Embedding(
258
+ 2 * config.max_position_embeddings - 1, self.attention_head_size)
259
+ self.save_attention = False
260
+
261
+ def save_attn_gradients(self, attn_gradients):
262
+ self.attn_gradients = attn_gradients
263
+
264
+ def get_attn_gradients(self):
265
+ return self.attn_gradients
266
+
267
+ def save_attention_map(self, attention_map):
268
+ self.attention_map = attention_map
269
+
270
+ def get_attention_map(self):
271
+ return self.attention_map
272
+
273
+ def transpose_for_scores(self, x):
274
+ new_x_shape = x.size()[
275
+ :-1] + (self.num_attention_heads, self.attention_head_size)
276
+ x = x.view(*new_x_shape)
277
+ return x.permute(0, 2, 1, 3)
278
+
279
+ def forward(
280
+ self,
281
+ hidden_states,
282
+ attention_mask=None,
283
+ head_mask=None,
284
+ encoder_hidden_states=None,
285
+ encoder_attention_mask=None,
286
+ past_key_value=None,
287
+ output_attentions=False,
288
+ ):
289
+ mixed_query_layer = self.query(hidden_states)
290
+
291
+ # If this is instantiated as a cross-attention module, the keys
292
+ # and values come from an encoder; the attention mask needs to be
293
+ # such that the encoder's padding tokens are not attended to.
294
+ is_cross_attention = encoder_hidden_states is not None
295
+
296
+ if is_cross_attention:
297
+ key_layer = self.transpose_for_scores(
298
+ self.key(encoder_hidden_states))
299
+ value_layer = self.transpose_for_scores(
300
+ self.value(encoder_hidden_states))
301
+ attention_mask = encoder_attention_mask
302
+ elif past_key_value is not None:
303
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
304
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
305
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
306
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
307
+ else:
308
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
309
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
310
+
311
+ query_layer = self.transpose_for_scores(mixed_query_layer)
312
+
313
+ past_key_value = (key_layer, value_layer)
314
+
315
+ # Take the dot product between "query" and "key" to get the raw attention scores.
316
+ attention_scores = torch.matmul(
317
+ query_layer, key_layer.transpose(-1, -2))
318
+
319
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
320
+ seq_length = hidden_states.size()[1]
321
+ position_ids_l = torch.arange(
322
+ seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
323
+ position_ids_r = torch.arange(
324
+ seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
325
+ distance = position_ids_l - position_ids_r
326
+ positional_embedding = self.distance_embedding(
327
+ distance + self.max_position_embeddings - 1)
328
+ positional_embedding = positional_embedding.to(
329
+ dtype=query_layer.dtype) # fp16 compatibility
330
+
331
+ if self.position_embedding_type == "relative_key":
332
+ relative_position_scores = torch.einsum(
333
+ "bhld,lrd->bhlr", query_layer, positional_embedding)
334
+ attention_scores = attention_scores + relative_position_scores
335
+ elif self.position_embedding_type == "relative_key_query":
336
+ relative_position_scores_query = torch.einsum(
337
+ "bhld,lrd->bhlr", query_layer, positional_embedding)
338
+ relative_position_scores_key = torch.einsum(
339
+ "bhrd,lrd->bhlr", key_layer, positional_embedding)
340
+ attention_scores = attention_scores + \
341
+ relative_position_scores_query + relative_position_scores_key
342
+
343
+ attention_scores = attention_scores / \
344
+ math.sqrt(self.attention_head_size)
345
+ if attention_mask is not None:
346
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
347
+ attention_scores = attention_scores + attention_mask
348
+
349
+ # Normalize the attention scores to probabilities.
350
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
351
+
352
+ if is_cross_attention and self.save_attention:
353
+ self.save_attention_map(attention_probs)
354
+ attention_probs.register_hook(self.save_attn_gradients)
355
+
356
+ # This is actually dropping out entire tokens to attend to, which might
357
+ # seem a bit unusual, but is taken from the original Transformer paper.
358
+ attention_probs_dropped = self.dropout(attention_probs)
359
+
360
+ # Mask heads if we want to
361
+ if head_mask is not None:
362
+ attention_probs_dropped = attention_probs_dropped * head_mask
363
+
364
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
365
+
366
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
367
+ new_context_layer_shape = context_layer.size()[
368
+ :-2] + (self.all_head_size,)
369
+ context_layer = context_layer.view(*new_context_layer_shape)
370
+
371
+ # added `attention_scores` to return tuple
372
+ outputs = (context_layer, attention_probs, attention_scores) if output_attentions else (
373
+ context_layer,)
374
+
375
+ outputs = outputs + (past_key_value,)
376
+ return outputs
377
+
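
Not part of the diff: a small sketch of how the relative_key distance matrix above is built and shifted into a non-negative index range (sizes are assumed).

```python
# Assumed seq_length and max_position_embeddings; shows only the distance-indexing step.
import torch

seq_length, max_position_embeddings = 4, 512
position_ids_l = torch.arange(seq_length).view(-1, 1)
position_ids_r = torch.arange(seq_length).view(1, -1)
distance = position_ids_l - position_ids_r           # values in [-(L-1), L-1]
index = distance + max_position_embeddings - 1       # shifted to [0, 2*max_pos - 2]
print(distance[0].tolist(), index[0].tolist())       # [0, -1, -2, -3] [511, 510, 509, 508]
```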
378
+
379
+ class BertSelfOutput(nn.Module):
380
+ def __init__(self, config):
381
+ super().__init__()
382
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
383
+ self.LayerNorm = nn.LayerNorm(
384
+ config.hidden_size, eps=config.layer_norm_eps)
385
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
386
+
387
+ def forward(self, hidden_states, input_tensor):
388
+ hidden_states = self.dense(hidden_states)
389
+ hidden_states = self.dropout(hidden_states)
390
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
391
+ return hidden_states
392
+
393
+
394
+ class BertAttention(nn.Module):
395
+ def __init__(self, config, is_cross_attention=False):
396
+ super().__init__()
397
+ self.self = BertSelfAttention(config, is_cross_attention)
398
+ self.output = BertSelfOutput(config)
399
+ self.pruned_heads = set()
400
+
401
+ def prune_heads(self, heads):
402
+ if len(heads) == 0:
403
+ return
404
+ heads, index = find_pruneable_heads_and_indices(
405
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
406
+ )
407
+
408
+ # Prune linear layers
409
+ self.self.query = prune_linear_layer(self.self.query, index)
410
+ self.self.key = prune_linear_layer(self.self.key, index)
411
+ self.self.value = prune_linear_layer(self.self.value, index)
412
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
413
+
414
+ # Update hyper params and store pruned heads
415
+ self.self.num_attention_heads = self.self.num_attention_heads - \
416
+ len(heads)
417
+ self.self.all_head_size = self.self.attention_head_size * \
418
+ self.self.num_attention_heads
419
+ self.pruned_heads = self.pruned_heads.union(heads)
420
+
421
+ def forward(
422
+ self,
423
+ hidden_states,
424
+ attention_mask=None,
425
+ head_mask=None,
426
+ encoder_hidden_states=None,
427
+ encoder_attention_mask=None,
428
+ past_key_value=None,
429
+ output_attentions=False,
430
+ ):
431
+ self_outputs = self.self(
432
+ hidden_states,
433
+ attention_mask,
434
+ head_mask,
435
+ encoder_hidden_states,
436
+ encoder_attention_mask,
437
+ past_key_value,
438
+ output_attentions,
439
+ )
440
+ attention_output = self.output(self_outputs[0], hidden_states)
441
+ # add attentions if we output them
442
+ outputs = (attention_output,) + self_outputs[1:]
443
+ return outputs # (context_layer, attention_probs, attention_scores, past_key_value,)
444
+
445
+
446
+ class BertIntermediate(nn.Module):
447
+ def __init__(self, config):
448
+ super().__init__()
449
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
450
+ if isinstance(config.hidden_act, str):
451
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
452
+ else:
453
+ self.intermediate_act_fn = config.hidden_act
454
+
455
+ def forward(self, hidden_states):
456
+ hidden_states = self.dense(hidden_states)
457
+ hidden_states = self.intermediate_act_fn(hidden_states)
458
+ return hidden_states
459
+
460
+
461
+ class BertOutput(nn.Module):
462
+ def __init__(self, config):
463
+ super().__init__()
464
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
465
+ self.LayerNorm = nn.LayerNorm(
466
+ config.hidden_size, eps=config.layer_norm_eps)
467
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
468
+
469
+ def forward(self, hidden_states, input_tensor):
470
+ hidden_states = self.dense(hidden_states)
471
+ hidden_states = self.dropout(hidden_states)
472
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
473
+ return hidden_states
474
+
475
+
476
+ class BertLayer(nn.Module):
477
+ def __init__(self, config, layer_num, token_keep_rate=1.0):
478
+ super().__init__()
479
+ self.config = config
480
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
481
+ self.seq_len_dim = 1
482
+ self.attention = BertAttention(config)
483
+
484
+ self.has_cross_attention = (layer_num >= config.fusion_layer)
485
+ if self.has_cross_attention:
486
+ self.layer_num = layer_num
487
+ self.crossattention = BertAttention(
488
+ config, is_cross_attention=True)
489
+
490
+ # sparse params
491
+ self.token_keep_rate = token_keep_rate
492
+ self.token_keep_strategy = config.token_keep_strategy
493
+ self.encoder_num_cls_tokens = 1 # multiple cls tokens
494
+
495
+ self.intermediate = BertIntermediate(config)
496
+ self.output = BertOutput(config)
497
+
498
+ def sparsify(self, x, attn, mask=None):
499
+ x_cls, x_ = x[:, :self.encoder_num_cls_tokens], x[:, self.encoder_num_cls_tokens:]
500
+ assert 0 < self.token_keep_rate <= 1, "Expected keep rate in range (0, 1]"
501
+ left_tokens = math.ceil(self.token_keep_rate * x_.size(1))
502
+ if len(attn.shape) == 4:
503
+ attn = attn.mean(1) # pool over attention heads
504
+
505
+ if self.token_keep_strategy == 'cls_attn':
506
+ cls_attn = attn[:, 0, self.encoder_num_cls_tokens:]
507
+ _, idx = torch.topk(cls_attn, left_tokens, dim=1) # [B, left_tokens]
508
+
509
+ elif self.token_keep_strategy == 'avg_attn':
510
+ avg_attn = attn.mean(1)[:, self.encoder_num_cls_tokens:]
511
+ _, idx = torch.topk(avg_attn, left_tokens, dim=1) # [B, left_tokens]
512
+
513
+ elif self.token_keep_strategy == 'random':
514
+ rand = torch.rand(x_.shape[:2], device=x_.device)
515
+ _, idx = torch.topk(rand, left_tokens, dim=1) # [B, left_tokens]
516
+
517
+ else:
518
+ raise NotImplementedError(f"Sparse strategy {self.token_keep_strategy} is not implemented")
519
+
520
+ idx, _ = torch.sort(idx, dim=1)
521
+ index = idx.unsqueeze(-1).expand(-1, -1, x_.size(-1)) # [B, left_tokens, C]
522
+ outputs = torch.cat((x_cls, x_.gather(1, index)), dim=1).contiguous()
523
+ if mask is not None:
524
+ mask_cls, mask_ = mask[..., :self.encoder_num_cls_tokens], mask[..., self.encoder_num_cls_tokens:]
525
+ index = idx.unsqueeze(1).unsqueeze(1) # [B, 1, 1, left_tokens]
526
+ mask = torch.cat((mask_cls, mask_.gather(-1, index)), dim=-1).contiguous()
527
+ return outputs, mask, idx
528
+
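
Since token sparsification is the SViTT-specific piece here, a self-contained sketch of the `cls_attn` keep strategy may help; the shapes and the single-CLS assumption below are illustrative only.

```python
# Standalone sketch of BertLayer.sparsify's cls_attn branch on dummy tensors.
import math
import torch

B, N, C, num_cls, keep_rate = 2, 10, 16, 1, 0.5
x = torch.randn(B, num_cls + N, C)                    # CLS token(s) + N visual tokens
attn = torch.rand(B, 12, num_cls + N, num_cls + N)    # per-head cross-attention maps

x_cls, x_ = x[:, :num_cls], x[:, num_cls:]
left_tokens = math.ceil(keep_rate * x_.size(1))       # how many visual tokens survive
cls_attn = attn.mean(1)[:, 0, num_cls:]               # head-averaged CLS -> token attention
_, idx = torch.topk(cls_attn, left_tokens, dim=1)     # most-attended token indices
idx, _ = torch.sort(idx, dim=1)                       # keep the original token order
index = idx.unsqueeze(-1).expand(-1, -1, C)
x_sparse = torch.cat((x_cls, x_.gather(1, index)), dim=1)
print(x_sparse.shape)                                 # torch.Size([2, 6, 16]) for these numbers
```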
529
+ def forward(
530
+ self,
531
+ hidden_states,
532
+ attention_mask=None,
533
+ head_mask=None,
534
+ encoder_hidden_states=None,
535
+ encoder_attention_mask=None,
536
+ past_key_value=None,
537
+ output_attentions=False,
538
+ ):
539
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
540
+ self_attn_past_key_value = past_key_value[:
541
+ 2] if past_key_value is not None else None
542
+ self_attention_outputs = self.attention(
543
+ hidden_states,
544
+ attention_mask,
545
+ head_mask,
546
+ output_attentions=output_attentions,
547
+ past_key_value=self_attn_past_key_value,
548
+ ) # (context_layer, attention_probs, attention_scores, past_key_value,)
549
+ attention_output = self_attention_outputs[0]
550
+
551
+ outputs = self_attention_outputs[1:-1]
552
+ present_key_value = self_attention_outputs[-1]
553
+
554
+ if self.has_cross_attention:
555
+ assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
556
+ output_attentions = (output_attentions or self.token_keep_rate < 1)
557
+
558
+ if type(encoder_hidden_states) == list:
559
+ cross_attention_outputs = self.crossattention(
560
+ attention_output,
561
+ attention_mask,
562
+ head_mask,
563
+ encoder_hidden_states[(
564
+ self.layer_num-self.config.fusion_layer) % len(encoder_hidden_states)],
565
+ encoder_attention_mask[(
566
+ self.layer_num-self.config.fusion_layer) % len(encoder_hidden_states)],
567
+ output_attentions=output_attentions,
568
+ )
569
+ attention_output = cross_attention_outputs[0]
570
+ outputs = outputs + cross_attention_outputs[1:-1]
571
+
572
+ else:
573
+ cross_attention_outputs = self.crossattention(
574
+ attention_output,
575
+ attention_mask,
576
+ head_mask,
577
+ encoder_hidden_states,
578
+ encoder_attention_mask,
579
+ output_attentions=output_attentions,
580
+ ) # (context_layer, attention_probs, attention_scores, past_key_value,)
581
+ attention_output = cross_attention_outputs[0]
582
+
583
+ # add cross attentions if we output attention weights
584
+ outputs = outputs + cross_attention_outputs[1:-1]
585
+
586
+ # node sparsification
587
+ if self.token_keep_rate < 1:
588
+ encoder_hidden_states, encoder_attention_mask, token_keep_idx = self.sparsify(
589
+ encoder_hidden_states, cross_attention_outputs[1], encoder_attention_mask)
590
+ outputs = outputs + (encoder_hidden_states, encoder_attention_mask, token_keep_idx)
591
+
592
+ layer_output = apply_chunking_to_forward(
593
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
594
+ )
595
+ outputs = (layer_output,) + outputs
596
+
597
+ outputs = outputs + (present_key_value,)
598
+
599
+ return outputs
600
+
601
+ def feed_forward_chunk(self, attention_output):
602
+ intermediate_output = self.intermediate(attention_output)
603
+ layer_output = self.output(intermediate_output, attention_output)
604
+ return layer_output
605
+
606
+
607
+ class BertEncoder(nn.Module):
608
+ def __init__(self, config):
609
+ super().__init__()
610
+ self.config = config
611
+
612
+ # node sparsification
613
+ token_keep_rate = [1] * config.num_hidden_layers
614
+ for loc in config.token_drop_loc:
615
+ token_keep_rate[loc] = config.token_keep_rate
616
+
617
+ self.layer = nn.ModuleList([BertLayer(config, i, token_keep_rate[i])
618
+ for i in range(config.num_hidden_layers)])
619
+
620
+ def forward(
621
+ self,
622
+ hidden_states,
623
+ attention_mask=None,
624
+ head_mask=None,
625
+ encoder_hidden_states=None,
626
+ encoder_attention_mask=None,
627
+ past_key_values=None,
628
+ use_cache=None,
629
+ output_attentions=False,
630
+ output_hidden_states=False,
631
+ output_token_idx=False,
632
+ return_dict=True,
633
+ mode='multi_modal',
634
+ normalize_attention=True
635
+ ):
636
+ all_hidden_states = () if output_hidden_states else None
637
+ all_self_attentions = () if output_attentions else None
638
+ all_cross_attentions = () if output_attentions else None
639
+ all_token_idx = () if output_token_idx else None
640
+
641
+ next_decoder_cache = () if use_cache else None
642
+
643
+ if mode == 'text':
644
+ start_layer = 0
645
+ output_layer = self.config.fusion_layer
646
+
647
+ elif mode == 'fusion':
648
+ start_layer = self.config.fusion_layer
649
+ output_layer = self.config.num_hidden_layers
650
+
651
+ elif mode == 'multi_modal':
652
+ start_layer = 0
653
+ output_layer = self.config.num_hidden_layers
654
+
655
+ for i in range(start_layer, output_layer):
656
+ layer_module = self.layer[i]
657
+ if output_hidden_states:
658
+ all_hidden_states = all_hidden_states + (hidden_states,)
659
+
660
+ layer_head_mask = head_mask[i] if head_mask is not None else None
661
+ past_key_value = past_key_values[i] if past_key_values is not None else None
662
+
663
+ if getattr(self.config, "gradient_checkpointing", False) and self.training:
664
+
665
+ if use_cache:
666
+ use_cache = False
667
+
668
+ def create_custom_forward(module):
669
+ def custom_forward(*inputs):
670
+ return module(*inputs, past_key_value, output_attentions)
671
+
672
+ return custom_forward
673
+
674
+ layer_outputs = torch.utils.checkpoint.checkpoint(
675
+ create_custom_forward(layer_module),
676
+ hidden_states,
677
+ attention_mask,
678
+ layer_head_mask,
679
+ encoder_hidden_states,
680
+ encoder_attention_mask,
681
+ )
682
+ else:
683
+ layer_outputs = layer_module(
684
+ hidden_states,
685
+ attention_mask,
686
+ layer_head_mask,
687
+ encoder_hidden_states,
688
+ encoder_attention_mask,
689
+ past_key_value,
690
+ output_attentions,
691
+ ) # (context_layer, attention_probs, attention_scores, past_key_value,)
692
+ hidden_states = layer_outputs[0]
693
+ # update visual sequence
694
+ if mode == 'fusion' and layer_module.token_keep_rate < 1:
695
+ encoder_hidden_states, encoder_attention_mask, token_idx = layer_outputs[-4:-1]
696
+
697
+ if output_token_idx:
698
+ all_token_idx = all_token_idx + (token_idx,)
699
+
700
+ if use_cache:
701
+ next_decoder_cache += (layer_outputs[-1],)
702
+ if output_attentions:
703
+ # whether to output normalized attention;
704
+ # note that the unnormalized attention scores still include the additive attention mask
705
+ offset = int(normalize_attention)
706
+ all_self_attentions = all_self_attentions + (layer_outputs[2-offset], )
707
+ if hasattr(layer_module, "crossattention"):
708
+ all_cross_attentions = all_cross_attentions + (layer_outputs[4-offset], )
709
+
710
+ if output_hidden_states:
711
+ all_hidden_states = all_hidden_states + (hidden_states,)
712
+
713
+ if not return_dict:
714
+ return tuple(
715
+ v
716
+ for v in [
717
+ hidden_states,
718
+ next_decoder_cache,
719
+ all_hidden_states,
720
+ all_self_attentions,
721
+ all_cross_attentions,
722
+ ]
723
+ if v is not None
724
+ )
725
+ return BertModelOutputWithPastAndCrossAttentions(
726
+ last_hidden_state=hidden_states,
727
+ past_key_values=next_decoder_cache,
728
+ hidden_states=all_hidden_states,
729
+ attentions=all_self_attentions,
730
+ cross_attentions=all_cross_attentions,
731
+ token_idx=all_token_idx
732
+ )
733
+
734
+
735
+ class BertPooler(nn.Module):
736
+ def __init__(self, config):
737
+ super().__init__()
738
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
739
+ self.activation = nn.Tanh()
740
+
741
+ def forward(self, hidden_states):
742
+ # We "pool" the model by simply taking the hidden state corresponding
743
+ # to the first token.
744
+ first_token_tensor = hidden_states[:, 0]
745
+ pooled_output = self.dense(first_token_tensor)
746
+ pooled_output = self.activation(pooled_output)
747
+ return pooled_output
748
+
749
+
750
+ class BertPredictionHeadTransform(nn.Module):
751
+ def __init__(self, config):
752
+ super().__init__()
753
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
754
+ if isinstance(config.hidden_act, str):
755
+ self.transform_act_fn = ACT2FN[config.hidden_act]
756
+ else:
757
+ self.transform_act_fn = config.hidden_act
758
+ self.LayerNorm = nn.LayerNorm(
759
+ config.hidden_size, eps=config.layer_norm_eps)
760
+
761
+ def forward(self, hidden_states):
762
+ hidden_states = self.dense(hidden_states)
763
+ hidden_states = self.transform_act_fn(hidden_states)
764
+ hidden_states = self.LayerNorm(hidden_states)
765
+ return hidden_states
766
+
767
+
768
+ class BertLMPredictionHead(nn.Module):
769
+ def __init__(self, config):
770
+ super().__init__()
771
+ self.transform = BertPredictionHeadTransform(config)
772
+
773
+ # The output weights are the same as the input embeddings, but there is
774
+ # an output-only bias for each token.
775
+ self.decoder = nn.Linear(
776
+ config.hidden_size, config.vocab_size, bias=False)
777
+
778
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
779
+
780
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
781
+ self.decoder.bias = self.bias
782
+
783
+ def forward(self, hidden_states):
784
+ hidden_states = self.transform(hidden_states)
785
+ hidden_states = self.decoder(hidden_states)
786
+ return hidden_states
787
+
788
+
789
+ class BertOnlyMLMHead(nn.Module):
790
+ def __init__(self, config):
791
+ super().__init__()
792
+ self.predictions = BertLMPredictionHead(config)
793
+
794
+ def forward(self, sequence_output):
795
+ prediction_scores = self.predictions(sequence_output)
796
+ return prediction_scores
797
+
798
+
799
+ class BertOnlyNSPHead(nn.Module):
800
+ def __init__(self, config):
801
+ super().__init__()
802
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
803
+
804
+ def forward(self, pooled_output):
805
+ seq_relationship_score = self.seq_relationship(pooled_output)
806
+ return seq_relationship_score
807
+
808
+
809
+ class BertPreTrainingHeads(nn.Module):
810
+ def __init__(self, config):
811
+ super().__init__()
812
+ self.predictions = BertLMPredictionHead(config)
813
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
814
+
815
+ def forward(self, sequence_output, pooled_output):
816
+ prediction_scores = self.predictions(sequence_output)
817
+ seq_relationship_score = self.seq_relationship(pooled_output)
818
+ return prediction_scores, seq_relationship_score
819
+
820
+
821
+ class BertPreTrainedModel(PreTrainedModel):
822
+ """
823
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
824
+ models.
825
+ """
826
+
827
+ config_class = BertConfig
828
+ load_tf_weights = load_tf_weights_in_bert
829
+ base_model_prefix = "bert"
830
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
831
+
832
+ def _init_weights(self, module):
833
+ """ Initialize the weights """
834
+ if isinstance(module, (nn.Linear, nn.Embedding)):
835
+ # Slightly different from the TF version which uses truncated_normal for initialization
836
+ # cf https://github.com/pytorch/pytorch/pull/5617
837
+ module.weight.data.normal_(
838
+ mean=0.0, std=self.config.initializer_range)
839
+ elif isinstance(module, nn.LayerNorm):
840
+ module.bias.data.zero_()
841
+ module.weight.data.fill_(1.0)
842
+ if isinstance(module, nn.Linear) and module.bias is not None:
843
+ module.bias.data.zero_()
844
+
845
+
846
+ @dataclass
847
+ class BertForPreTrainingOutput(ModelOutput):
848
+ """
849
+ Output type of :class:`~transformers.BertForPreTraining`.
850
+ Args:
851
+ loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`):
852
+ Total loss as the sum of the masked language modeling loss and the next sequence prediction
853
+ (classification) loss.
854
+ prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
855
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
856
+ seq_relationship_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
857
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
858
+ before SoftMax).
859
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
860
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
861
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
862
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
863
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
864
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
865
+ sequence_length, sequence_length)`.
866
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
867
+ heads.
868
+ """
869
+
870
+ loss: Optional[torch.FloatTensor] = None
871
+ prediction_logits: torch.FloatTensor = None
872
+ seq_relationship_logits: torch.FloatTensor = None
873
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
874
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
875
+
876
+
877
+ BERT_START_DOCSTRING = r"""
878
+ This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
879
+ methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
880
+ pruning heads etc.)
881
+ This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
882
+ subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to
883
+ general usage and behavior.
884
+ Parameters:
885
+ config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
886
+ Initializing with a config file does not load the weights associated with the model, only the
887
+ configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
888
+ weights.
889
+ """
890
+
891
+ BERT_INPUTS_DOCSTRING = r"""
892
+ Args:
893
+ input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
894
+ Indices of input sequence tokens in the vocabulary.
895
+ Indices can be obtained using :class:`~transformers.BertTokenizer`. See
896
+ :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
897
+ details.
898
+ `What are input IDs? <../glossary.html#input-ids>`__
899
+ attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
900
+ Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
901
+ - 1 for tokens that are **not masked**,
902
+ - 0 for tokens that are **masked**.
903
+ `What are attention masks? <../glossary.html#attention-mask>`__
904
+ token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
905
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
906
+ 1]``:
907
+ - 0 corresponds to a `sentence A` token,
908
+ - 1 corresponds to a `sentence B` token.
909
+ `What are token type IDs? <../glossary.html#token-type-ids>`_
910
+ position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
911
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
912
+ config.max_position_embeddings - 1]``.
913
+ `What are position IDs? <../glossary.html#position-ids>`_
914
+ head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
915
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
916
+ - 1 indicates the head is **not masked**,
917
+ - 0 indicates the head is **masked**.
918
+ inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
919
+ Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
920
+ This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
921
+ vectors than the model's internal embedding lookup matrix.
922
+ output_attentions (:obj:`bool`, `optional`):
923
+ Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
924
+ tensors for more detail.
925
+ output_hidden_states (:obj:`bool`, `optional`):
926
+ Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
927
+ more detail.
928
+ return_dict (:obj:`bool`, `optional`):
929
+ Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
930
+ """
931
+
932
+
933
+ @add_start_docstrings(
934
+ "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
935
+ BERT_START_DOCSTRING,
936
+ )
937
+ class BertModel(BertPreTrainedModel):
938
+ """
939
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
940
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
941
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
942
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
943
+ argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
944
+ input to the forward pass.
945
+ """
946
+
947
+ def __init__(self, config, add_pooling_layer=True):
948
+ super().__init__(config)
949
+ self.config = config
950
+
951
+ self.embeddings = BertEmbeddings(config)
952
+
953
+ self.encoder = BertEncoder(config)
954
+
955
+ self.pooler = BertPooler(config) if add_pooling_layer else None
956
+
957
+ self.init_weights()
958
+
959
+ def get_input_embeddings(self):
960
+ return self.embeddings.word_embeddings
961
+
962
+ def set_input_embeddings(self, value):
963
+ self.embeddings.word_embeddings = value
964
+
965
+ def _prune_heads(self, heads_to_prune):
966
+ """
967
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
968
+ class PreTrainedModel
969
+ """
970
+ for layer, heads in heads_to_prune.items():
971
+ self.encoder.layer[layer].attention.prune_heads(heads)
972
+
973
+ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
974
+ """
975
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
976
+
977
+ Arguments:
978
+ attention_mask (:obj:`torch.Tensor`):
979
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
980
+ input_shape (:obj:`Tuple[int]`):
981
+ The shape of the input to the model.
982
+ device: (:obj:`torch.device`):
983
+ The device of the input to the model.
984
+
985
+ Returns:
986
+ :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
987
+ """
988
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
989
+ # ourselves in which case we just need to make it broadcastable to all heads.
990
+ if attention_mask.dim() == 3:
991
+ extended_attention_mask = attention_mask[:, None, :, :]
992
+ elif attention_mask.dim() == 2:
993
+ # Provided a padding mask of dimensions [batch_size, seq_length]
994
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
995
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
996
+ if is_decoder:
997
+ batch_size, seq_length = input_shape
998
+ seq_ids = torch.arange(seq_length, device=device)
999
+ causal_mask = seq_ids[None, None, :].repeat(
1000
+ batch_size, seq_length, 1) <= seq_ids[None, :, None]
1001
+ # in case past_key_values are used we need to add a prefix ones mask to the causal mask
1002
+ # causal and attention masks must have same type with pytorch version < 1.3
1003
+ causal_mask = causal_mask.to(attention_mask.dtype)
1004
+
1005
+ if causal_mask.shape[1] < attention_mask.shape[1]:
1006
+ prefix_seq_len = attention_mask.shape[1] - \
1007
+ causal_mask.shape[1]
1008
+ causal_mask = torch.cat(
1009
+ [
1010
+ torch.ones(
1011
+ (batch_size, seq_length,
1012
+ prefix_seq_len), device=device, dtype=causal_mask.dtype
1013
+ ),
1014
+ causal_mask,
1015
+ ],
1016
+ axis=-1,
1017
+ )
1018
+
1019
+ extended_attention_mask = causal_mask[:, None,
1020
+ :, :] * attention_mask[:, None, None, :]
1021
+ else:
1022
+ extended_attention_mask = attention_mask[:, None, None, :]
1023
+ else:
1024
+ raise ValueError(
1025
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
1026
+ input_shape, attention_mask.shape
1027
+ )
1028
+ )
1029
+
1030
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
1031
+ # masked positions, this operation will create a tensor which is 0.0 for
1032
+ # positions we want to attend and -10000.0 for masked positions.
1033
+ # Since we are adding it to the raw scores before the softmax, this is
1034
+ # effectively the same as removing these entirely.
1035
+ extended_attention_mask = extended_attention_mask.to(
1036
+ dtype=self.dtype) # fp16 compatibility
1037
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
1038
+ return extended_attention_mask
1039
+
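# A standalone sketch of the mask arithmetic performed by
# `get_extended_attention_mask` above: a 2-D padding mask is broadcast to
# [batch, 1, 1, seq] and turned into an additive bias that is zero where
# attention is allowed and -10000 at padding, so masked positions vanish
# after the softmax. For a decoder, a lower-triangular causal mask is
# multiplied in first.
import torch

attention_mask = torch.tensor([[1., 1., 1., 0.]])           # last token is padding
extended = attention_mask[:, None, None, :]                 # shape [1, 1, 1, 4]
extended = (1.0 - extended) * -10000.0                      # 0 where attended, -10000 at padding

seq_ids = torch.arange(4)
causal = (seq_ids[None, None, :] <= seq_ids[None, :, None]).float()   # [1, 4, 4] lower-triangular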
1040
+ def forward(
1041
+ self,
1042
+ input_ids=None,
1043
+ attention_mask=None,
1044
+ token_type_ids=None,
1045
+ position_ids=None,
1046
+ head_mask=None,
1047
+ inputs_embeds=None,
1048
+ encoder_embeds=None,
1049
+ encoder_hidden_states=None,
1050
+ encoder_attention_mask=None,
1051
+ past_key_values=None,
1052
+ use_cache=None,
1053
+ output_attentions=None,
1054
+ output_hidden_states=None,
1055
+ output_token_idx=None,
1056
+ return_dict=None,
1057
+ is_decoder=False,
1058
+ mode='multi_modal',
1059
+ normalize_attention=True,
1060
+ ):
1061
+ r"""
1062
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
1063
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
1064
+ the model is configured as a decoder.
1065
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1066
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1067
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
1068
+ - 1 for tokens that are **not masked**,
1069
+ - 0 for tokens that are **masked**.
1070
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1071
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1072
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
1073
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
1074
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
1075
+ use_cache (:obj:`bool`, `optional`):
1076
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
1077
+ decoding (see :obj:`past_key_values`).
1078
+ """
1079
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1080
+ output_hidden_states = (
1081
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1082
+ )
1083
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1084
+
1085
+ if is_decoder:
1086
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1087
+ else:
1088
+ use_cache = False
1089
+
1090
+ if input_ids is not None and inputs_embeds is not None:
1091
+ raise ValueError(
1092
+ "You cannot specify both input_ids and inputs_embeds at the same time")
1093
+ elif input_ids is not None:
1094
+ input_shape = input_ids.size()
1095
+ batch_size, seq_length = input_shape
1096
+ device = input_ids.device
1097
+ elif inputs_embeds is not None:
1098
+ input_shape = inputs_embeds.size()[:-1]
1099
+ batch_size, seq_length = input_shape
1100
+ device = inputs_embeds.device
1101
+ elif encoder_embeds is not None:
1102
+ input_shape = encoder_embeds.size()[:-1]
1103
+ batch_size, seq_length = input_shape
1104
+ device = encoder_embeds.device
1105
+ else:
1106
+ raise ValueError(
1107
+ "You have to specify either input_ids or inputs_embeds or encoder_embeds")
1108
+
1109
+ # past_key_values_length
1110
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
1111
+
1112
+ if attention_mask is None:
1113
+ attention_mask = torch.ones(
1114
+ ((batch_size, seq_length + past_key_values_length)), device=device)
1115
+ if token_type_ids is None:
1116
+ token_type_ids = torch.zeros(
1117
+ input_shape, dtype=torch.long, device=device)
1118
+
1119
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
1120
+ # ourselves in which case we just need to make it broadcastable to all heads.
1121
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
1122
+ device, is_decoder)
1123
+
1124
+ # If a 2D or 3D attention mask is provided for the cross-attention
1125
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
1126
+ if encoder_hidden_states is not None:
1127
+ if type(encoder_hidden_states) == list:
1128
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size(
1129
+ )
1130
+ else:
1131
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
1132
+ encoder_hidden_shape = (
1133
+ encoder_batch_size, encoder_sequence_length)
1134
+
1135
+ if type(encoder_attention_mask) == list:
1136
+ encoder_extended_attention_mask = [
1137
+ self.invert_attention_mask(mask) for mask in encoder_attention_mask]
1138
+ elif encoder_attention_mask is None:
1139
+ encoder_attention_mask = torch.ones(
1140
+ encoder_hidden_shape, device=device)
1141
+ encoder_extended_attention_mask = self.invert_attention_mask(
1142
+ encoder_attention_mask)
1143
+ else:
1144
+ encoder_extended_attention_mask = self.invert_attention_mask(
1145
+ encoder_attention_mask)
1146
+ else:
1147
+ encoder_extended_attention_mask = None
1148
+
1149
+ # Prepare head mask if needed
1150
+ # 1.0 in head_mask indicate we keep the head
1151
+ # attention_probs has shape bsz x n_heads x N x N
1152
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
1153
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
1154
+ head_mask = self.get_head_mask(
1155
+ head_mask, self.config.num_hidden_layers)
1156
+
1157
+ if encoder_embeds is None:
1158
+ embedding_output = self.embeddings(
1159
+ input_ids=input_ids,
1160
+ position_ids=position_ids,
1161
+ token_type_ids=token_type_ids,
1162
+ inputs_embeds=inputs_embeds,
1163
+ past_key_values_length=past_key_values_length,
1164
+ )
1165
+ else:
1166
+ embedding_output = encoder_embeds
1167
+
1168
+ encoder_outputs = self.encoder(
1169
+ embedding_output,
1170
+ attention_mask=extended_attention_mask,
1171
+ head_mask=head_mask,
1172
+ encoder_hidden_states=encoder_hidden_states,
1173
+ encoder_attention_mask=encoder_extended_attention_mask,
1174
+ past_key_values=past_key_values,
1175
+ use_cache=use_cache,
1176
+ output_attentions=output_attentions,
1177
+ output_hidden_states=output_hidden_states,
1178
+ output_token_idx=output_token_idx,
1179
+ return_dict=return_dict,
1180
+ mode=mode,
1181
+ normalize_attention=normalize_attention,
1182
+
1183
+ )
1184
+ sequence_output = encoder_outputs[0]
1185
+ pooled_output = self.pooler(
1186
+ sequence_output) if self.pooler is not None else None
1187
+
1188
+ if not return_dict:
1189
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
1190
+
1191
+ return BertModelOutputWithPoolingAndCrossAttentions(
1192
+ last_hidden_state=sequence_output,
1193
+ pooler_output=pooled_output,
1194
+ past_key_values=encoder_outputs.past_key_values,
1195
+ hidden_states=encoder_outputs.hidden_states,
1196
+ attentions=encoder_outputs.attentions,
1197
+ cross_attentions=encoder_outputs.cross_attentions,
1198
+ token_idx=encoder_outputs.token_idx,
1199
+ )
1200
+
1201
+
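# A hedged usage sketch of this BertModel variant, which accepts
# `encoder_hidden_states` / `encoder_attention_mask` for cross-attention and a
# `mode` switch (default 'multi_modal'). The config path, feature shapes and
# commented-out construction below are assumptions for illustration, not taken
# from this repository.
import torch
# from svitt.sparse_xbert import BertModel        # assumed import path
# text_encoder = BertModel(config, add_pooling_layer=False)

batch, text_len, vis_len, hidden = 2, 12, 197, 768
text_ids = torch.randint(0, 30522, (batch, text_len))
text_mask = torch.ones(batch, text_len, dtype=torch.long)
video_feats = torch.randn(batch, vis_len, hidden)   # e.g. patch tokens from a video encoder
video_mask = torch.ones(batch, vis_len, dtype=torch.long)

# out = text_encoder(text_ids, attention_mask=text_mask,
#                    encoder_hidden_states=video_feats,
#                    encoder_attention_mask=video_mask,
#                    return_dict=True, mode='multi_modal')
# fused = out.last_hidden_state                    # [batch, text_len, hidden]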
1202
+ @add_start_docstrings(
1203
+ """
1204
+ Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
1205
+ sentence prediction (classification)` head.
1206
+ """,
1207
+ BERT_START_DOCSTRING,
1208
+ )
1209
+ class BertForPreTraining(BertPreTrainedModel):
1210
+ def __init__(self, config):
1211
+ super().__init__(config)
1212
+
1213
+ self.bert = BertModel(config)
1214
+ self.cls = BertPreTrainingHeads(config)
1215
+
1216
+ self.init_weights()
1217
+
1218
+ def get_output_embeddings(self):
1219
+ return self.cls.predictions.decoder
1220
+
1221
+ def set_output_embeddings(self, new_embeddings):
1222
+ self.cls.predictions.decoder = new_embeddings
1223
+
1224
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1225
+ @replace_return_docstrings(output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
1226
+ def forward(
1227
+ self,
1228
+ input_ids=None,
1229
+ attention_mask=None,
1230
+ token_type_ids=None,
1231
+ position_ids=None,
1232
+ head_mask=None,
1233
+ inputs_embeds=None,
1234
+ labels=None,
1235
+ next_sentence_label=None,
1236
+ output_attentions=None,
1237
+ output_hidden_states=None,
1238
+ return_dict=None,
1239
+ ):
1240
+ r"""
1241
+ labels (:obj:`torch.LongTensor` of shape ``(batch_size, sequence_length)``, `optional`):
1242
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
1243
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
1244
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1245
+ next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`):
1246
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
1247
+ (see :obj:`input_ids` docstring) Indices should be in ``[0, 1]``:
1248
+ - 0 indicates sequence B is a continuation of sequence A,
1249
+ - 1 indicates sequence B is a random sequence.
1250
+ kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
1251
+ Used to hide legacy arguments that have been deprecated.
1252
+ Returns:
1253
+ Example::
1254
+ >>> from transformers import BertTokenizer, BertForPreTraining
1255
+ >>> import torch
1256
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
1257
+ >>> model = BertForPreTraining.from_pretrained('bert-base-uncased')
1258
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1259
+ >>> outputs = model(**inputs)
1260
+ >>> prediction_logits = outputs.prediction_logits
1261
+ >>> seq_relationship_logits = outputs.seq_relationship_logits
1262
+ """
1263
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1264
+
1265
+ outputs = self.bert(
1266
+ input_ids,
1267
+ attention_mask=attention_mask,
1268
+ token_type_ids=token_type_ids,
1269
+ position_ids=position_ids,
1270
+ head_mask=head_mask,
1271
+ inputs_embeds=inputs_embeds,
1272
+ output_attentions=output_attentions,
1273
+ output_hidden_states=output_hidden_states,
1274
+ return_dict=return_dict,
1275
+ )
1276
+
1277
+ sequence_output, pooled_output = outputs[:2]
1278
+ prediction_scores, seq_relationship_score = self.cls(
1279
+ sequence_output, pooled_output)
1280
+
1281
+ total_loss = None
1282
+ if labels is not None and next_sentence_label is not None:
1283
+ loss_fct = CrossEntropyLoss()
1284
+ masked_lm_loss = loss_fct(
1285
+ prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1286
+ next_sentence_loss = loss_fct(
1287
+ seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
1288
+ total_loss = masked_lm_loss + next_sentence_loss
1289
+
1290
+ if not return_dict:
1291
+ output = (prediction_scores, seq_relationship_score) + outputs[2:]
1292
+ return ((total_loss,) + output) if total_loss is not None else output
1293
+
1294
+ return BertForPreTrainingOutput(
1295
+ loss=total_loss,
1296
+ prediction_logits=prediction_scores,
1297
+ seq_relationship_logits=seq_relationship_score,
1298
+ hidden_states=outputs.hidden_states,
1299
+ attentions=outputs.attentions,
1300
+ )
1301
+
1302
+
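# A small sketch of how the two pre-training losses above are combined. MLM
# labels use -100 for positions that should not contribute to the loss (the
# CrossEntropyLoss default ignore_index), and `next_sentence_label` is 0/1 per
# pair; the forward pass then sums the two losses.
import torch

labels = torch.tensor([[-100, 2054, -100, -100]])    # only position 1 is a masked token
next_sentence_label = torch.tensor([0])              # 0: sentence B follows sentence A
# total_loss = masked_lm_loss + next_sentence_loss   # as computed in forward() above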
1303
+ @add_start_docstrings(
1304
+ """Bert Model with a `language modeling` head on top for CLM fine-tuning. """, BERT_START_DOCSTRING
1305
+ )
1306
+ class BertLMHeadModel(BertPreTrainedModel):
1307
+
1308
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1309
+ _keys_to_ignore_on_load_missing = [
1310
+ r"position_ids", r"predictions.decoder.bias"]
1311
+
1312
+ def __init__(self, config):
1313
+ super().__init__(config)
1314
+
1315
+ self.bert = BertModel(config, add_pooling_layer=False)
1316
+ self.cls = BertOnlyMLMHead(config)
1317
+
1318
+ self.init_weights()
1319
+
1320
+ def get_output_embeddings(self):
1321
+ return self.cls.predictions.decoder
1322
+
1323
+ def set_output_embeddings(self, new_embeddings):
1324
+ self.cls.predictions.decoder = new_embeddings
1325
+
1326
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1327
+ @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
1328
+ def forward(
1329
+ self,
1330
+ input_ids=None,
1331
+ attention_mask=None,
1332
+ token_type_ids=None,
1333
+ position_ids=None,
1334
+ head_mask=None,
1335
+ inputs_embeds=None,
1336
+ encoder_hidden_states=None,
1337
+ encoder_attention_mask=None,
1338
+ labels=None,
1339
+ past_key_values=None,
1340
+ use_cache=None,
1341
+ output_attentions=None,
1342
+ output_hidden_states=None,
1343
+ return_dict=None,
1344
+ is_decoder=True,
1345
+ reduction='mean',
1346
+ mode='multi_modal',
1347
+ normalize_attention=True,
1348
+ soft_labels=None,
1349
+ alpha=0,
1350
+ return_logits=False,
1351
+ ):
1352
+ r"""
1353
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
1354
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
1355
+ the model is configured as a decoder.
1356
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1357
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1358
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
1359
+ - 1 for tokens that are **not masked**,
1360
+ - 0 for tokens that are **masked**.
1361
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1362
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
1363
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
1364
+ ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1365
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1366
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1367
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
1368
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
1369
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
1370
+ use_cache (:obj:`bool`, `optional`):
1371
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
1372
+ decoding (see :obj:`past_key_values`).
1373
+ Returns:
1374
+ Example::
1375
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
1376
+ >>> import torch
1377
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
1378
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
1379
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
1380
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1381
+ >>> outputs = model(**inputs)
1382
+ >>> prediction_logits = outputs.logits
1383
+ """
1384
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1385
+ if labels is not None:
1386
+ use_cache = False
1387
+
1388
+ outputs = self.bert(
1389
+ input_ids,
1390
+ attention_mask=attention_mask,
1391
+ token_type_ids=token_type_ids,
1392
+ position_ids=position_ids,
1393
+ head_mask=head_mask,
1394
+ inputs_embeds=inputs_embeds,
1395
+ encoder_hidden_states=encoder_hidden_states,
1396
+ encoder_attention_mask=encoder_attention_mask,
1397
+ past_key_values=past_key_values,
1398
+ use_cache=use_cache,
1399
+ output_attentions=output_attentions,
1400
+ output_hidden_states=output_hidden_states,
1401
+ return_dict=return_dict,
1402
+ is_decoder=is_decoder,
1403
+ mode=mode,
1404
+ normalize_attention=normalize_attention,
1405
+ )
1406
+
1407
+ sequence_output = outputs[0]
1408
+ prediction_scores = self.cls(sequence_output)
1409
+
1410
+ if return_logits:
1411
+ return prediction_scores[:, :-1, :].contiguous()
1412
+
1413
+ lm_loss = None
1414
+ if labels is not None:
1415
+ # we are doing next-token prediction; shift prediction scores and input ids by one
1416
+ shifted_prediction_scores = prediction_scores[:,
1417
+ :-1, :].contiguous()
1418
+ labels = labels[:, 1:].contiguous()
1419
+ loss_fct = CrossEntropyLoss(reduction=reduction)
1420
+ lm_loss = loss_fct(
1421
+ shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1422
+ lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
1423
+
1424
+ if soft_labels is not None:
1425
+ loss_distill = - \
1426
+ torch.sum(F.log_softmax(shifted_prediction_scores,
1427
+ dim=1)*soft_labels, dim=-1)
1428
+ loss_distill = (loss_distill * (labels != -100)).sum(1)
1429
+ lm_loss = (1-alpha)*lm_loss + alpha*loss_distill
1430
+
1431
+ if not return_dict:
1432
+ output = (prediction_scores,) + outputs[2:]
1433
+ return ((lm_loss,) + output) if lm_loss is not None else output
1434
+
1435
+ return CausalLMOutputWithCrossAttentions(
1436
+ loss=lm_loss,
1437
+ logits=prediction_scores,
1438
+ past_key_values=outputs.past_key_values,
1439
+ hidden_states=outputs.hidden_states,
1440
+ attentions=outputs.attentions,
1441
+ cross_attentions=outputs.cross_attentions,
1442
+ )
1443
+
1444
+ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
1445
+ input_shape = input_ids.shape
1446
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1447
+ if attention_mask is None:
1448
+ attention_mask = input_ids.new_ones(input_shape)
1449
+
1450
+ # cut decoder_input_ids if past is used
1451
+ if past is not None:
1452
+ input_ids = input_ids[:, -1:]
1453
+
1454
+ return {
1455
+ "input_ids": input_ids,
1456
+ "attention_mask": attention_mask,
1457
+ "past_key_values": past,
1458
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
1459
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
1460
+ "is_decoder": True,
1461
+ }
1462
+
1463
+ def _reorder_cache(self, past, beam_idx):
1464
+ reordered_past = ()
1465
+ for layer_past in past:
1466
+ reordered_past += (tuple(past_state.index_select(0, beam_idx)
1467
+ for past_state in layer_past),)
1468
+ return reordered_past
1469
+
1470
+
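# A minimal sketch of the causal shift used in BertLMHeadModel.forward above:
# logits at position t are scored against the token at position t+1, so both
# tensors are shifted by one before the cross-entropy.
import torch
from torch.nn import CrossEntropyLoss

vocab_size = 30522
prediction_scores = torch.randn(1, 5, vocab_size)    # [batch, seq, vocab]
labels = torch.randint(0, vocab_size, (1, 5))

shifted_scores = prediction_scores[:, :-1, :].contiguous()
shifted_labels = labels[:, 1:].contiguous()
loss = CrossEntropyLoss(reduction='none')(
    shifted_scores.view(-1, vocab_size), shifted_labels.view(-1))
per_sample_loss = loss.view(prediction_scores.size(0), -1).sum(1)   # matches forward() with reduction='none'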
1471
+ @dataclass
1472
+ class MaskedLMOutputWithDistill(MaskedLMOutput):
1473
+ loss_aux: Optional[torch.FloatTensor] = None
1474
+ loss_distill: Optional[torch.FloatTensor] = None
1475
+
1476
+
1477
+ @add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
1478
+ class BertForMaskedLM(BertPreTrainedModel):
1479
+
1480
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1481
+ _keys_to_ignore_on_load_missing = [
1482
+ r"position_ids", r"predictions.decoder.bias"]
1483
+
1484
+ def __init__(self, config):
1485
+ super().__init__(config)
1486
+
1487
+ self.bert = BertModel(config, add_pooling_layer=False)
1488
+ self.cls = BertOnlyMLMHead(config)
1489
+
1490
+ self.init_weights()
1491
+
1492
+ def tie_aux_decoder_weights(self, module, aux_modules):
1493
+ """Tie decoder weights of all `aux_modules` to `module`, (not bias)"""
1494
+ for m in aux_modules:
1495
+ m.predictions.decoder.weight = module.predictions.decoder.weight
1496
+
1497
+ def get_output_embeddings(self):
1498
+ return self.cls.predictions.decoder
1499
+
1500
+ def set_output_embeddings(self, new_embeddings):
1501
+ self.cls.predictions.decoder = new_embeddings
1502
+
1503
+ def forward(
1504
+ self,
1505
+ input_ids=None,
1506
+ attention_mask=None,
1507
+ token_type_ids=None,
1508
+ position_ids=None,
1509
+ head_mask=None,
1510
+ inputs_embeds=None,
1511
+ encoder_embeds=None,
1512
+ encoder_hidden_states=None,
1513
+ encoder_attention_mask=None,
1514
+ labels=None,
1515
+ output_attentions=None,
1516
+ output_hidden_states=None,
1517
+ return_dict=None,
1518
+ is_decoder=False,
1519
+ mode='multi_modal',
1520
+ normalize_attention=True,
1521
+ soft_labels=None,
1522
+ alpha=0,
1523
+ return_logits=False,
1524
+ ):
1525
+ r"""
1526
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1527
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
1528
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
1529
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1530
+ """
1531
+
1532
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1533
+
1534
+ outputs = self.bert(
1535
+ input_ids,
1536
+ attention_mask=attention_mask,
1537
+ token_type_ids=token_type_ids,
1538
+ position_ids=position_ids,
1539
+ head_mask=head_mask,
1540
+ inputs_embeds=inputs_embeds,
1541
+ encoder_embeds=encoder_embeds,
1542
+ encoder_hidden_states=encoder_hidden_states,
1543
+ encoder_attention_mask=encoder_attention_mask,
1544
+ output_attentions=output_attentions,
1545
+ output_hidden_states=output_hidden_states,
1546
+ return_dict=return_dict,
1547
+ is_decoder=is_decoder,
1548
+ mode=mode,
1549
+ normalize_attention=normalize_attention
1550
+ )
1551
+
1552
+ sequence_output = outputs[0]
1553
+ prediction_scores = self.cls(sequence_output)
1554
+
1555
+ if return_logits:
1556
+ return prediction_scores
1557
+
1558
+ masked_lm_loss = None
1559
+ masked_lm_loss_aux = 0.
1560
+ if labels is not None:
1561
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1562
+ masked_lm_loss = loss_fct(
1563
+ prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1564
+
1565
+ if soft_labels is not None:
1566
+ loss_distill = - \
1567
+ torch.sum(F.log_softmax(prediction_scores, dim=1)
1568
+ * soft_labels, dim=-1)
1569
+ loss_distill = loss_distill[labels != -100].mean()
1570
+ masked_lm_loss = (1-alpha)*masked_lm_loss + alpha*loss_distill
1571
+
1572
+ if not return_dict:
1573
+ output = (prediction_scores,) + outputs[2:]
1574
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1575
+
1576
+ # changed from MaskedLMOutput to MaskedLMOutputWithDistill
1577
+ return MaskedLMOutputWithDistill(
1578
+ loss=masked_lm_loss,
1579
+ loss_aux=masked_lm_loss_aux,
1580
+ logits=prediction_scores,
1581
+ hidden_states=outputs.hidden_states,
1582
+ attentions=outputs.attentions,
1583
+ )
1584
+
1585
+ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
1586
+ input_shape = input_ids.shape
1587
+ effective_batch_size = input_shape[0]
1588
+
1589
+ # add a dummy token
1590
+ assert self.config.pad_token_id is not None, "The PAD token should be defined for generation"
1591
+ attention_mask = torch.cat(
1592
+ [attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
1593
+ dummy_token = torch.full(
1594
+ (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
1595
+ )
1596
+ input_ids = torch.cat([input_ids, dummy_token], dim=1)
1597
+
1598
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
1599
+
1600
+
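# A hedged sketch of the distillation path in BertForMaskedLM.forward above:
# when `soft_labels` (a teacher distribution over the vocabulary) and `alpha`
# are given, the hard cross-entropy loss is blended with a soft cross-entropy
# term over the masked positions. Shapes and values are illustrative; the
# softmax here is taken over the vocabulary dimension.
import torch
import torch.nn.functional as F

vocab_size = 30522
prediction_scores = torch.randn(2, 4, vocab_size)
soft_labels = F.softmax(torch.randn(2, 4, vocab_size), dim=-1)    # teacher outputs
labels = torch.full((2, 4), -100, dtype=torch.long)
labels[0, 1] = 2054                                               # one masked position

loss_distill = -torch.sum(F.log_softmax(prediction_scores, dim=-1) * soft_labels, dim=-1)
loss_distill = loss_distill[labels != -100].mean()
# masked_lm_loss = (1 - alpha) * masked_lm_loss + alpha * loss_distill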
1601
+ @add_start_docstrings(
1602
+ """Bert Model with a `next sentence prediction (classification)` head on top. """,
1603
+ BERT_START_DOCSTRING,
1604
+ )
1605
+ class BertForNextSentencePrediction(BertPreTrainedModel):
1606
+ def __init__(self, config):
1607
+ super().__init__(config)
1608
+
1609
+ self.bert = BertModel(config)
1610
+ self.cls = BertOnlyNSPHead(config)
1611
+
1612
+ self.init_weights()
1613
+
1614
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1615
+ @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
1616
+ def forward(
1617
+ self,
1618
+ input_ids=None,
1619
+ attention_mask=None,
1620
+ token_type_ids=None,
1621
+ position_ids=None,
1622
+ head_mask=None,
1623
+ inputs_embeds=None,
1624
+ labels=None,
1625
+ output_attentions=None,
1626
+ output_hidden_states=None,
1627
+ return_dict=None,
1628
+ **kwargs
1629
+ ):
1630
+ r"""
1631
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1632
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
1633
+ (see ``input_ids`` docstring). Indices should be in ``[0, 1]``:
1634
+ - 0 indicates sequence B is a continuation of sequence A,
1635
+ - 1 indicates sequence B is a random sequence.
1636
+ Returns:
1637
+ Example::
1638
+ >>> from transformers import BertTokenizer, BertForNextSentencePrediction
1639
+ >>> import torch
1640
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
1641
+ >>> model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
1642
+ >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
1643
+ >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
1644
+ >>> encoding = tokenizer(prompt, next_sentence, return_tensors='pt')
1645
+ >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
1646
+ >>> logits = outputs.logits
1647
+ >>> assert logits[0, 0] < logits[0, 1] # next sentence was random
1648
+ """
1649
+
1650
+ if "next_sentence_label" in kwargs:
1651
+ warnings.warn(
1652
+ "The `next_sentence_label` argument is deprecated and will be removed in a future version, use `labels` instead.",
1653
+ FutureWarning,
1654
+ )
1655
+ labels = kwargs.pop("next_sentence_label")
1656
+
1657
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1658
+
1659
+ outputs = self.bert(
1660
+ input_ids,
1661
+ attention_mask=attention_mask,
1662
+ token_type_ids=token_type_ids,
1663
+ position_ids=position_ids,
1664
+ head_mask=head_mask,
1665
+ inputs_embeds=inputs_embeds,
1666
+ output_attentions=output_attentions,
1667
+ output_hidden_states=output_hidden_states,
1668
+ return_dict=return_dict,
1669
+ )
1670
+
1671
+ pooled_output = outputs[1]
1672
+
1673
+ seq_relationship_scores = self.cls(pooled_output)
1674
+
1675
+ next_sentence_loss = None
1676
+ if labels is not None:
1677
+ loss_fct = CrossEntropyLoss()
1678
+ next_sentence_loss = loss_fct(
1679
+ seq_relationship_scores.view(-1, 2), labels.view(-1))
1680
+
1681
+ if not return_dict:
1682
+ output = (seq_relationship_scores,) + outputs[2:]
1683
+ return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
1684
+
1685
+ return NextSentencePredictorOutput(
1686
+ loss=next_sentence_loss,
1687
+ logits=seq_relationship_scores,
1688
+ hidden_states=outputs.hidden_states,
1689
+ attentions=outputs.attentions,
1690
+ )
1691
+
1692
+
1693
+ @add_start_docstrings(
1694
+ """
1695
+ Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
1696
+ output) e.g. for GLUE tasks.
1697
+ """,
1698
+ BERT_START_DOCSTRING,
1699
+ )
1700
+ class BertForSequenceClassification(BertPreTrainedModel):
1701
+ def __init__(self, config):
1702
+ super().__init__(config)
1703
+ self.num_labels = config.num_labels
1704
+
1705
+ self.bert = BertModel(config)
1706
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1707
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1708
+
1709
+ self.init_weights()
1710
+
1711
+ def forward(
1712
+ self,
1713
+ input_ids=None,
1714
+ attention_mask=None,
1715
+ token_type_ids=None,
1716
+ position_ids=None,
1717
+ head_mask=None,
1718
+ inputs_embeds=None,
1719
+ labels=None,
1720
+ output_attentions=None,
1721
+ output_hidden_states=None,
1722
+ return_dict=None,
1723
+ ):
1724
+ r"""
1725
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1726
+ Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
1727
+ config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
1728
+ If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1729
+ """
1730
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1731
+
1732
+ outputs = self.bert(
1733
+ input_ids,
1734
+ attention_mask=attention_mask,
1735
+ token_type_ids=token_type_ids,
1736
+ position_ids=position_ids,
1737
+ head_mask=head_mask,
1738
+ inputs_embeds=inputs_embeds,
1739
+ output_attentions=output_attentions,
1740
+ output_hidden_states=output_hidden_states,
1741
+ return_dict=return_dict,
1742
+ )
1743
+
1744
+ pooled_output = outputs[1]
1745
+
1746
+ pooled_output = self.dropout(pooled_output)
1747
+ logits = self.classifier(pooled_output)
1748
+
1749
+ loss = None
1750
+ if labels is not None:
1751
+ if self.num_labels == 1:
1752
+ # We are doing regression
1753
+ loss_fct = MSELoss()
1754
+ loss = loss_fct(logits.view(-1), labels.view(-1))
1755
+ else:
1756
+ loss_fct = CrossEntropyLoss()
1757
+ loss = loss_fct(
1758
+ logits.view(-1, self.num_labels), labels.view(-1))
1759
+
1760
+ if not return_dict:
1761
+ output = (logits,) + outputs[2:]
1762
+ return ((loss,) + output) if loss is not None else output
1763
+
1764
+ return SequenceClassifierOutput(
1765
+ loss=loss,
1766
+ logits=logits,
1767
+ hidden_states=outputs.hidden_states,
1768
+ attentions=outputs.attentions,
1769
+ )
1770
+
1771
+
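# A small sketch of the loss switch in BertForSequenceClassification above:
# with num_labels == 1 the head is treated as a regressor (MSE), otherwise as
# a classifier (cross-entropy).
import torch
from torch.nn import CrossEntropyLoss, MSELoss

logits_cls = torch.randn(4, 3)                        # num_labels = 3
labels_cls = torch.tensor([0, 2, 1, 2])
cls_loss = CrossEntropyLoss()(logits_cls.view(-1, 3), labels_cls.view(-1))

logits_reg = torch.randn(4, 1)                        # num_labels = 1
labels_reg = torch.randn(4)
reg_loss = MSELoss()(logits_reg.view(-1), labels_reg.view(-1))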
1772
+ @add_start_docstrings(
1773
+ """
1774
+ Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
1775
+ softmax) e.g. for RocStories/SWAG tasks.
1776
+ """,
1777
+ BERT_START_DOCSTRING,
1778
+ )
1779
+ class BertForMultipleChoice(BertPreTrainedModel):
1780
+ def __init__(self, config):
1781
+ super().__init__(config)
1782
+
1783
+ self.bert = BertModel(config)
1784
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1785
+ self.classifier = nn.Linear(config.hidden_size, 1)
1786
+
1787
+ self.init_weights()
1788
+
1789
+ def forward(
1790
+ self,
1791
+ input_ids=None,
1792
+ attention_mask=None,
1793
+ token_type_ids=None,
1794
+ position_ids=None,
1795
+ head_mask=None,
1796
+ inputs_embeds=None,
1797
+ labels=None,
1798
+ output_attentions=None,
1799
+ output_hidden_states=None,
1800
+ return_dict=None,
1801
+ ):
1802
+ r"""
1803
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1804
+ Labels for computing the multiple choice classification loss. Indices should be in ``[0, ...,
1805
+ num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
1806
+ :obj:`input_ids` above)
1807
+ """
1808
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1809
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
1810
+
1811
+ input_ids = input_ids.view(-1, input_ids.size(-1)
1812
+ ) if input_ids is not None else None
1813
+ attention_mask = attention_mask.view(
1814
+ -1, attention_mask.size(-1)) if attention_mask is not None else None
1815
+ token_type_ids = token_type_ids.view(
1816
+ -1, token_type_ids.size(-1)) if token_type_ids is not None else None
1817
+ position_ids = position_ids.view(-1, position_ids.size(-1)
1818
+ ) if position_ids is not None else None
1819
+ inputs_embeds = (
1820
+ inputs_embeds.view(-1, inputs_embeds.size(-2),
1821
+ inputs_embeds.size(-1))
1822
+ if inputs_embeds is not None
1823
+ else None
1824
+ )
1825
+
1826
+ outputs = self.bert(
1827
+ input_ids,
1828
+ attention_mask=attention_mask,
1829
+ token_type_ids=token_type_ids,
1830
+ position_ids=position_ids,
1831
+ head_mask=head_mask,
1832
+ inputs_embeds=inputs_embeds,
1833
+ output_attentions=output_attentions,
1834
+ output_hidden_states=output_hidden_states,
1835
+ return_dict=return_dict,
1836
+ )
1837
+
1838
+ pooled_output = outputs[1]
1839
+
1840
+ pooled_output = self.dropout(pooled_output)
1841
+ logits = self.classifier(pooled_output)
1842
+ reshaped_logits = logits.view(-1, num_choices)
1843
+
1844
+ loss = None
1845
+ if labels is not None:
1846
+ loss_fct = CrossEntropyLoss()
1847
+ loss = loss_fct(reshaped_logits, labels)
1848
+
1849
+ if not return_dict:
1850
+ output = (reshaped_logits,) + outputs[2:]
1851
+ return ((loss,) + output) if loss is not None else output
1852
+
1853
+ return MultipleChoiceModelOutput(
1854
+ loss=loss,
1855
+ logits=reshaped_logits,
1856
+ hidden_states=outputs.hidden_states,
1857
+ attentions=outputs.attentions,
1858
+ )
1859
+
1860
+
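# A minimal sketch of the tensor reshaping in BertForMultipleChoice above.
# Inputs arrive as [batch, num_choices, seq_len], are flattened to
# [batch * num_choices, seq_len] for the encoder, and the per-choice scores
# are folded back so each row holds the logits over one question's choices.
import torch

batch, num_choices, seq_len = 2, 5, 16
input_ids = torch.randint(0, 30522, (batch, num_choices, seq_len))
flat_ids = input_ids.view(-1, input_ids.size(-1))     # [10, 16]

logits = torch.randn(batch * num_choices, 1)          # classifier output, one score per choice
reshaped_logits = logits.view(-1, num_choices)        # [2, 5]
pred = reshaped_logits.argmax(dim=-1)                 # index of the chosen answer per question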
1861
+ @add_start_docstrings(
1862
+ """
1863
+ Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1864
+ Named-Entity-Recognition (NER) tasks.
1865
+ """,
1866
+ BERT_START_DOCSTRING,
1867
+ )
1868
+ class BertForTokenClassification(BertPreTrainedModel):
1869
+
1870
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1871
+
1872
+ def __init__(self, config):
1873
+ super().__init__(config)
1874
+ self.num_labels = config.num_labels
1875
+
1876
+ self.bert = BertModel(config, add_pooling_layer=False)
1877
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1878
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1879
+
1880
+ self.init_weights()
1881
+
1882
+ def forward(
1883
+ self,
1884
+ input_ids=None,
1885
+ attention_mask=None,
1886
+ token_type_ids=None,
1887
+ position_ids=None,
1888
+ head_mask=None,
1889
+ inputs_embeds=None,
1890
+ labels=None,
1891
+ output_attentions=None,
1892
+ output_hidden_states=None,
1893
+ return_dict=None,
1894
+ ):
1895
+ r"""
1896
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1897
+ Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1898
+ 1]``.
1899
+ """
1900
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1901
+
1902
+ outputs = self.bert(
1903
+ input_ids,
1904
+ attention_mask=attention_mask,
1905
+ token_type_ids=token_type_ids,
1906
+ position_ids=position_ids,
1907
+ head_mask=head_mask,
1908
+ inputs_embeds=inputs_embeds,
1909
+ output_attentions=output_attentions,
1910
+ output_hidden_states=output_hidden_states,
1911
+ return_dict=return_dict,
1912
+ )
1913
+
1914
+ sequence_output = outputs[0]
1915
+
1916
+ sequence_output = self.dropout(sequence_output)
1917
+ logits = self.classifier(sequence_output)
1918
+
1919
+ loss = None
1920
+ if labels is not None:
1921
+ loss_fct = CrossEntropyLoss()
1922
+ # Only keep active parts of the loss
1923
+ if attention_mask is not None:
1924
+ active_loss = attention_mask.view(-1) == 1
1925
+ active_logits = logits.view(-1, self.num_labels)
1926
+ active_labels = torch.where(
1927
+ active_loss, labels.view(-1), torch.tensor(
1928
+ loss_fct.ignore_index).type_as(labels)
1929
+ )
1930
+ loss = loss_fct(active_logits, active_labels)
1931
+ else:
1932
+ loss = loss_fct(
1933
+ logits.view(-1, self.num_labels), labels.view(-1))
1934
+
1935
+ if not return_dict:
1936
+ output = (logits,) + outputs[2:]
1937
+ return ((loss,) + output) if loss is not None else output
1938
+
1939
+ return TokenClassifierOutput(
1940
+ loss=loss,
1941
+ logits=logits,
1942
+ hidden_states=outputs.hidden_states,
1943
+ attentions=outputs.attentions,
1944
+ )
1945
+
1946
+
1947
+ @add_start_docstrings(
1948
+ """
1949
+ Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
1950
+ layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
1951
+ """,
1952
+ BERT_START_DOCSTRING,
1953
+ )
1954
+ class BertForQuestionAnswering(BertPreTrainedModel):
1955
+
1956
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1957
+
1958
+ def __init__(self, config):
1959
+ super().__init__(config)
1960
+ self.num_labels = config.num_labels
1961
+
1962
+ self.bert = BertModel(config, add_pooling_layer=False)
1963
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1964
+
1965
+ self.init_weights()
1966
+
1967
+ def forward(
1968
+ self,
1969
+ input_ids=None,
1970
+ attention_mask=None,
1971
+ token_type_ids=None,
1972
+ position_ids=None,
1973
+ head_mask=None,
1974
+ inputs_embeds=None,
1975
+ start_positions=None,
1976
+ end_positions=None,
1977
+ output_attentions=None,
1978
+ output_hidden_states=None,
1979
+ return_dict=None,
1980
+ ):
1981
+ r"""
1982
+ start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1983
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1984
+ Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
1985
+ sequence are not taken into account for computing the loss.
1986
+ end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1987
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1988
+ Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
1989
+ sequence are not taken into account for computing the loss.
1990
+ """
1991
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1992
+
1993
+ outputs = self.bert(
1994
+ input_ids,
1995
+ attention_mask=attention_mask,
1996
+ token_type_ids=token_type_ids,
1997
+ position_ids=position_ids,
1998
+ head_mask=head_mask,
1999
+ inputs_embeds=inputs_embeds,
2000
+ output_attentions=output_attentions,
2001
+ output_hidden_states=output_hidden_states,
2002
+ return_dict=return_dict,
2003
+ )
2004
+
2005
+ sequence_output = outputs[0]
2006
+
2007
+ logits = self.qa_outputs(sequence_output)
2008
+ start_logits, end_logits = logits.split(1, dim=-1)
2009
+ start_logits = start_logits.squeeze(-1)
2010
+ end_logits = end_logits.squeeze(-1)
2011
+
2012
+ total_loss = None
2013
+ if start_positions is not None and end_positions is not None:
2014
+ # If we are on multi-GPU, split add a dimension
2015
+ if len(start_positions.size()) > 1:
2016
+ start_positions = start_positions.squeeze(-1)
2017
+ if len(end_positions.size()) > 1:
2018
+ end_positions = end_positions.squeeze(-1)
2019
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
2020
+ ignored_index = start_logits.size(1)
2021
+ start_positions.clamp_(0, ignored_index)
2022
+ end_positions.clamp_(0, ignored_index)
2023
+
2024
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
2025
+ start_loss = loss_fct(start_logits, start_positions)
2026
+ end_loss = loss_fct(end_logits, end_positions)
2027
+ total_loss = (start_loss + end_loss) / 2
2028
+
2029
+ if not return_dict:
2030
+ output = (start_logits, end_logits) + outputs[2:]
2031
+ return ((total_loss,) + output) if total_loss is not None else output
2032
+
2033
+ return QuestionAnsweringModelOutput(
2034
+ loss=total_loss,
2035
+ start_logits=start_logits,
2036
+ end_logits=end_logits,
2037
+ hidden_states=outputs.hidden_states,
2038
+ attentions=outputs.attentions,
2039
+ )
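# A brief sketch of turning the start/end logits produced by
# BertForQuestionAnswering above into a predicted answer span. This is a greedy
# argmax for illustration; real decoding usually also constrains end >= start.
import torch

start_logits = torch.randn(1, 20)
end_logits = torch.randn(1, 20)
start_idx = int(start_logits.argmax(dim=-1))
end_idx = int(end_logits.argmax(dim=-1))
answer_token_positions = list(range(start_idx, max(start_idx, end_idx) + 1))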
svitt/tokenization_bert.py ADDED
@@ -0,0 +1,546 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for Bert."""
16
+
17
+
18
+ import collections
19
+ import os
20
+ import unicodedata
21
+ from typing import List, Optional, Tuple
22
+
23
+ from transformers.tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
24
+ from transformers.utils import logging
25
+
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
30
+
31
+ PRETRAINED_VOCAB_FILES_MAP = {
32
+ "vocab_file": {
33
+ "bert-base-uncased": "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt",
34
+ "bert-large-uncased": "https://huggingface.co/bert-large-uncased/resolve/main/vocab.txt",
35
+ "bert-base-cased": "https://huggingface.co/bert-base-cased/resolve/main/vocab.txt",
36
+ "bert-large-cased": "https://huggingface.co/bert-large-cased/resolve/main/vocab.txt",
37
+ "bert-base-multilingual-uncased": "https://huggingface.co/bert-base-multilingual-uncased/resolve/main/vocab.txt",
38
+ "bert-base-multilingual-cased": "https://huggingface.co/bert-base-multilingual-cased/resolve/main/vocab.txt",
39
+ "bert-base-chinese": "https://huggingface.co/bert-base-chinese/resolve/main/vocab.txt",
40
+ "bert-base-german-cased": "https://huggingface.co/bert-base-german-cased/resolve/main/vocab.txt",
41
+ "bert-large-uncased-whole-word-masking": "https://huggingface.co/bert-large-uncased-whole-word-masking/resolve/main/vocab.txt",
42
+ "bert-large-cased-whole-word-masking": "https://huggingface.co/bert-large-cased-whole-word-masking/resolve/main/vocab.txt",
43
+ "bert-large-uncased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt",
44
+ "bert-large-cased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt",
45
+ "bert-base-cased-finetuned-mrpc": "https://huggingface.co/bert-base-cased-finetuned-mrpc/resolve/main/vocab.txt",
46
+ "bert-base-german-dbmdz-cased": "https://huggingface.co/bert-base-german-dbmdz-cased/resolve/main/vocab.txt",
47
+ "bert-base-german-dbmdz-uncased": "https://huggingface.co/bert-base-german-dbmdz-uncased/resolve/main/vocab.txt",
48
+ "TurkuNLP/bert-base-finnish-cased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-cased-v1/resolve/main/vocab.txt",
49
+ "TurkuNLP/bert-base-finnish-uncased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-uncased-v1/resolve/main/vocab.txt",
50
+ "wietsedv/bert-base-dutch-cased": "https://huggingface.co/wietsedv/bert-base-dutch-cased/resolve/main/vocab.txt",
51
+ }
52
+ }
53
+
54
+ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
55
+ "bert-base-uncased": 512,
56
+ "bert-large-uncased": 512,
57
+ "bert-base-cased": 512,
58
+ "bert-large-cased": 512,
59
+ "bert-base-multilingual-uncased": 512,
60
+ "bert-base-multilingual-cased": 512,
61
+ "bert-base-chinese": 512,
62
+ "bert-base-german-cased": 512,
63
+ "bert-large-uncased-whole-word-masking": 512,
64
+ "bert-large-cased-whole-word-masking": 512,
65
+ "bert-large-uncased-whole-word-masking-finetuned-squad": 512,
66
+ "bert-large-cased-whole-word-masking-finetuned-squad": 512,
67
+ "bert-base-cased-finetuned-mrpc": 512,
68
+ "bert-base-german-dbmdz-cased": 512,
69
+ "bert-base-german-dbmdz-uncased": 512,
70
+ "TurkuNLP/bert-base-finnish-cased-v1": 512,
71
+ "TurkuNLP/bert-base-finnish-uncased-v1": 512,
72
+ "wietsedv/bert-base-dutch-cased": 512,
73
+ }
74
+
75
+ PRETRAINED_INIT_CONFIGURATION = {
76
+ "bert-base-uncased": {"do_lower_case": True},
77
+ "bert-large-uncased": {"do_lower_case": True},
78
+ "bert-base-cased": {"do_lower_case": False},
79
+ "bert-large-cased": {"do_lower_case": False},
80
+ "bert-base-multilingual-uncased": {"do_lower_case": True},
81
+ "bert-base-multilingual-cased": {"do_lower_case": False},
82
+ "bert-base-chinese": {"do_lower_case": False},
83
+ "bert-base-german-cased": {"do_lower_case": False},
84
+ "bert-large-uncased-whole-word-masking": {"do_lower_case": True},
85
+ "bert-large-cased-whole-word-masking": {"do_lower_case": False},
86
+ "bert-large-uncased-whole-word-masking-finetuned-squad": {"do_lower_case": True},
87
+ "bert-large-cased-whole-word-masking-finetuned-squad": {"do_lower_case": False},
88
+ "bert-base-cased-finetuned-mrpc": {"do_lower_case": False},
89
+ "bert-base-german-dbmdz-cased": {"do_lower_case": False},
90
+ "bert-base-german-dbmdz-uncased": {"do_lower_case": True},
91
+ "TurkuNLP/bert-base-finnish-cased-v1": {"do_lower_case": False},
92
+ "TurkuNLP/bert-base-finnish-uncased-v1": {"do_lower_case": True},
93
+ "wietsedv/bert-base-dutch-cased": {"do_lower_case": False},
94
+ }
95
+
96
+
97
+ def load_vocab(vocab_file):
98
+ """Loads a vocabulary file into a dictionary."""
99
+ vocab = collections.OrderedDict()
100
+ with open(vocab_file, "r", encoding="utf-8") as reader:
101
+ tokens = reader.readlines()
102
+ for index, token in enumerate(tokens):
103
+ token = token.rstrip("\n")
104
+ vocab[token] = index
105
+ return vocab
106
+
107
+
108
+ def whitespace_tokenize(text):
109
+ """Runs basic whitespace cleaning and splitting on a piece of text."""
110
+ text = text.strip()
111
+ if not text:
112
+ return []
113
+ tokens = text.split()
114
+ return tokens
115
+
116
+
117
+ class BertTokenizer(PreTrainedTokenizer):
118
+ r"""
119
+ Construct a BERT tokenizer. Based on WordPiece.
120
+ This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods.
121
+ Users should refer to this superclass for more information regarding those methods.
122
+ Args:
123
+ vocab_file (:obj:`str`):
124
+ File containing the vocabulary.
125
+ do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
126
+ Whether or not to lowercase the input when tokenizing.
127
+ do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`):
128
+ Whether or not to do basic tokenization before WordPiece.
129
+ never_split (:obj:`Iterable`, `optional`):
130
+ Collection of tokens which will never be split during tokenization. Only has an effect when
131
+ :obj:`do_basic_tokenize=True`
132
+ unk_token (:obj:`str`, `optional`, defaults to :obj:`"[UNK]"`):
133
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
134
+ token instead.
135
+ sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`):
136
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
137
+ sequence classification or for a text and a question for question answering. It is also used as the last
138
+ token of a sequence built with special tokens.
139
+ pad_token (:obj:`str`, `optional`, defaults to :obj:`"[PAD]"`):
140
+ The token used for padding, for example when batching sequences of different lengths.
141
+ cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`):
142
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
143
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
144
+ mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`):
145
+ The token used for masking values. This is the token used when training this model with masked language
146
+ modeling. This is the token which the model will try to predict.
147
+ tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
148
+ Whether or not to tokenize Chinese characters.
149
+ This should likely be deactivated for Japanese (see this `issue
150
+ <https://github.com/huggingface/transformers/issues/328>`__).
151
+ strip_accents: (:obj:`bool`, `optional`):
152
+ Whether or not to strip all accents. If this option is not specified, then it will be determined by the
153
+ value for :obj:`lowercase` (as in the original BERT).
154
+ """
155
+
156
+ vocab_files_names = VOCAB_FILES_NAMES
157
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
158
+ pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
159
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
160
+
161
+ def __init__(
162
+ self,
163
+ vocab_file,
164
+ do_lower_case=True,
165
+ do_basic_tokenize=True,
166
+ never_split=None,
167
+ unk_token="[UNK]",
168
+ sep_token="[SEP]",
169
+ pad_token="[PAD]",
170
+ cls_token="[CLS]",
171
+ mask_token="[MASK]",
172
+ tokenize_chinese_chars=True,
173
+ strip_accents=None,
174
+ **kwargs
175
+ ):
176
+ super().__init__(
177
+ do_lower_case=do_lower_case,
178
+ do_basic_tokenize=do_basic_tokenize,
179
+ never_split=never_split,
180
+ unk_token=unk_token,
181
+ sep_token=sep_token,
182
+ pad_token=pad_token,
183
+ cls_token=cls_token,
184
+ mask_token=mask_token,
185
+ tokenize_chinese_chars=tokenize_chinese_chars,
186
+ strip_accents=strip_accents,
187
+ **kwargs,
188
+ )
189
+
190
+ if not os.path.isfile(vocab_file):
191
+ raise ValueError(
192
+ "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
193
+ "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
194
+ vocab_file)
195
+ )
196
+ self.vocab = load_vocab(vocab_file)
197
+ self.ids_to_tokens = collections.OrderedDict(
198
+ [(ids, tok) for tok, ids in self.vocab.items()])
199
+ self.do_basic_tokenize = do_basic_tokenize
200
+ if do_basic_tokenize:
201
+ self.basic_tokenizer = BasicTokenizer(
202
+ do_lower_case=do_lower_case,
203
+ never_split=never_split,
204
+ tokenize_chinese_chars=tokenize_chinese_chars,
205
+ strip_accents=strip_accents,
206
+ )
207
+ self.wordpiece_tokenizer = WordpieceTokenizer(
208
+ vocab=self.vocab, unk_token=self.unk_token)
209
+
210
+ @property
211
+ def do_lower_case(self):
212
+ return self.basic_tokenizer.do_lower_case
213
+
214
+ @property
215
+ def vocab_size(self):
216
+ return len(self.vocab)
217
+
218
+ def get_vocab(self):
219
+ return dict(self.vocab, **self.added_tokens_encoder)
220
+
221
+ def _tokenize(self, text):
222
+ split_tokens = []
223
+ if self.do_basic_tokenize:
224
+ for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
225
+
226
+ # If the token is part of the never_split set
227
+ if token in self.basic_tokenizer.never_split:
228
+ split_tokens.append(token)
229
+ else:
230
+ split_tokens += self.wordpiece_tokenizer.tokenize(token)
231
+ else:
232
+ split_tokens = self.wordpiece_tokenizer.tokenize(text)
233
+ return split_tokens
234
+
235
+ def _convert_token_to_id(self, token):
236
+ """ Converts a token (str) in an id using the vocab. """
237
+ return self.vocab.get(token, self.vocab.get(self.unk_token))
238
+
239
+ def _convert_id_to_token(self, index):
240
+ """Converts an index (integer) in a token (str) using the vocab."""
241
+ return self.ids_to_tokens.get(index, self.unk_token)
242
+
243
+ def convert_tokens_to_string(self, tokens):
244
+ """ Converts a sequence of tokens (string) in a single string. """
245
+ out_string = " ".join(tokens).replace(" ##", "").strip()
246
+ return out_string
247
+
248
+ def build_inputs_with_special_tokens(
249
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
250
+ ) -> List[int]:
251
+ """
252
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
253
+ adding special tokens. A BERT sequence has the following format:
254
+ - single sequence: ``[CLS] X``
255
+ - pair of sequences: ``[CLS] A [SEP] B [SEP]``
256
+ Args:
257
+ token_ids_0 (:obj:`List[int]`):
258
+ List of IDs to which the special tokens will be added.
259
+ token_ids_1 (:obj:`List[int]`, `optional`):
260
+ Optional second list of IDs for sequence pairs.
261
+ Returns:
262
+ :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
263
+ """
264
+ if token_ids_1 is None:
265
+ return [self.cls_token_id] + token_ids_0
266
+ cls = [self.cls_token_id]
267
+ sep = [self.sep_token_id]
268
+ return cls + token_ids_0 + sep + token_ids_1 + sep
269
+
270
+ def get_special_tokens_mask(
271
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
272
+ ) -> List[int]:
273
+ """
274
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
275
+ special tokens using the tokenizer ``prepare_for_model`` method.
276
+ Args:
277
+ token_ids_0 (:obj:`List[int]`):
278
+ List of IDs.
279
+ token_ids_1 (:obj:`List[int]`, `optional`):
280
+ Optional second list of IDs for sequence pairs.
281
+ already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
282
+ Whether or not the token list is already formatted with special tokens for the model.
283
+ Returns:
284
+ :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
285
+ """
286
+
287
+ if already_has_special_tokens:
288
+ if token_ids_1 is not None:
289
+ raise ValueError(
290
+ "You should not supply a second sequence if the provided sequence of "
291
+ "ids is already formatted with special tokens for the model."
292
+ )
293
+ return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
294
+
295
+ if token_ids_1 is not None:
296
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
297
+ return [1] + ([0] * len(token_ids_0)) + [1]
298
+
299
+ def create_token_type_ids_from_sequences(
300
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
301
+ ) -> List[int]:
302
+ """
303
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence
304
+ pair mask has the following format:
305
+ ::
306
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
307
+ | first sequence | second sequence |
308
+ If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s).
309
+ Args:
310
+ token_ids_0 (:obj:`List[int]`):
311
+ List of IDs.
312
+ token_ids_1 (:obj:`List[int]`, `optional`):
313
+ Optional second list of IDs for sequence pairs.
314
+ Returns:
315
+ :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
316
+ sequence(s).
317
+ """
318
+ sep = [self.sep_token_id]
319
+ cls = [self.cls_token_id]
320
+ if token_ids_1 is None:
321
+ return len(cls + token_ids_0 + sep) * [0]
322
+ return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
323
+
324
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
325
+ index = 0
326
+ if os.path.isdir(save_directory):
327
+ vocab_file = os.path.join(
328
+ save_directory, (filename_prefix + "-" if filename_prefix else "") +
329
+ VOCAB_FILES_NAMES["vocab_file"]
330
+ )
331
+ else:
332
+ vocab_file = (filename_prefix +
333
+ "-" if filename_prefix else "") + save_directory
334
+ with open(vocab_file, "w", encoding="utf-8") as writer:
335
+ for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
336
+ if index != token_index:
337
+ logger.warning(
338
+ "Saving vocabulary to {}: vocabulary indices are not consecutive."
339
+ " Please check that the vocabulary is not corrupted!".format(
340
+ vocab_file)
341
+ )
342
+ index = token_index
343
+ writer.write(token + "\n")
344
+ index += 1
345
+ return (vocab_file,)
346
+
347
+
348
+ class BasicTokenizer(object):
349
+ """
350
+ Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).
351
+ Args:
352
+ do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
353
+ Whether or not to lowercase the input when tokenizing.
354
+ never_split (:obj:`Iterable`, `optional`):
355
+ Collection of tokens which will never be split during tokenization. Only has an effect when
356
+ :obj:`do_basic_tokenize=True`
357
+ tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
358
+ Whether or not to tokenize Chinese characters.
359
+ This should likely be deactivated for Japanese (see this `issue
360
+ <https://github.com/huggingface/transformers/issues/328>`__).
361
+ strip_accents: (:obj:`bool`, `optional`):
362
+ Whether or not to strip all accents. If this option is not specified, then it will be determined by the
363
+ value for :obj:`lowercase` (as in the original BERT).
364
+ """
365
+
366
+ def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True, strip_accents=None):
367
+ if never_split is None:
368
+ never_split = []
369
+ self.do_lower_case = do_lower_case
370
+ self.never_split = set(never_split)
371
+ self.tokenize_chinese_chars = tokenize_chinese_chars
372
+ self.strip_accents = strip_accents
373
+
374
+ def tokenize(self, text, never_split=None):
375
+ """
376
+ Basic tokenization of a piece of text. Splits on "white spaces" only; for sub-word tokenization, see
377
+ WordpieceTokenizer.
378
+ Args:
379
+ **never_split**: (`optional`) list of str
380
+ Kept for backward compatibility purposes. Now implemented directly at the base class level (see
381
+ :func:`PreTrainedTokenizer.tokenize`). List of tokens not to split.
382
+ """
383
+ # union() returns a new set by concatenating the two sets.
384
+ never_split = self.never_split.union(
385
+ set(never_split)) if never_split else self.never_split
386
+ text = self._clean_text(text)
387
+
388
+ # This was added on November 1st, 2018 for the multilingual and Chinese
389
+ # models. This is also applied to the English models now, but it doesn't
390
+ # matter since the English models were not trained on any Chinese data
391
+ # and generally don't have any Chinese data in them (there are Chinese
392
+ # characters in the vocabulary because Wikipedia does have some Chinese
393
+ # words in the English Wikipedia).
394
+ if self.tokenize_chinese_chars:
395
+ text = self._tokenize_chinese_chars(text)
396
+ orig_tokens = whitespace_tokenize(text)
397
+ split_tokens = []
398
+ for token in orig_tokens:
399
+ if token not in never_split:
400
+ if self.do_lower_case:
401
+ token = token.lower()
402
+ if self.strip_accents is not False:
403
+ token = self._run_strip_accents(token)
404
+ elif self.strip_accents:
405
+ token = self._run_strip_accents(token)
406
+ split_tokens.extend(self._run_split_on_punc(token, never_split))
407
+
408
+ output_tokens = whitespace_tokenize(" ".join(split_tokens))
409
+ return output_tokens
410
+
411
+ def _run_strip_accents(self, text):
412
+ """Strips accents from a piece of text."""
413
+ text = unicodedata.normalize("NFD", text)
414
+ output = []
415
+ for char in text:
416
+ cat = unicodedata.category(char)
417
+ if cat == "Mn":
418
+ continue
419
+ output.append(char)
420
+ return "".join(output)
421
+
422
+ def _run_split_on_punc(self, text, never_split=None):
423
+ """Splits punctuation on a piece of text."""
424
+ if never_split is not None and text in never_split:
425
+ return [text]
426
+ chars = list(text)
427
+ i = 0
428
+ start_new_word = True
429
+ output = []
430
+ while i < len(chars):
431
+ char = chars[i]
432
+ if _is_punctuation(char):
433
+ output.append([char])
434
+ start_new_word = True
435
+ else:
436
+ if start_new_word:
437
+ output.append([])
438
+ start_new_word = False
439
+ output[-1].append(char)
440
+ i += 1
441
+
442
+ return ["".join(x) for x in output]
443
+
444
+ def _tokenize_chinese_chars(self, text):
445
+ """Adds whitespace around any CJK character."""
446
+ output = []
447
+ for char in text:
448
+ cp = ord(char)
449
+ if self._is_chinese_char(cp):
450
+ output.append(" ")
451
+ output.append(char)
452
+ output.append(" ")
453
+ else:
454
+ output.append(char)
455
+ return "".join(output)
456
+
457
+ def _is_chinese_char(self, cp):
458
+ """Checks whether CP is the codepoint of a CJK character."""
459
+ # This defines a "chinese character" as anything in the CJK Unicode block:
460
+ # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
461
+ #
462
+ # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
463
+ # despite its name. The modern Korean Hangul alphabet is a different block,
464
+ # as is Japanese Hiragana and Katakana. Those alphabets are used to write
465
+ # space-separated words, so they are not treated specially and handled
466
+ # like all of the other languages.
467
+ if (
468
+ (cp >= 0x4E00 and cp <= 0x9FFF)
469
+ or (cp >= 0x3400 and cp <= 0x4DBF) #
470
+ or (cp >= 0x20000 and cp <= 0x2A6DF) #
471
+ or (cp >= 0x2A700 and cp <= 0x2B73F) #
472
+ or (cp >= 0x2B740 and cp <= 0x2B81F) #
473
+ or (cp >= 0x2B820 and cp <= 0x2CEAF) #
474
+ or (cp >= 0xF900 and cp <= 0xFAFF)
475
+ or (cp >= 0x2F800 and cp <= 0x2FA1F) #
476
+ ): #
477
+ return True
478
+
479
+ return False
480
+
481
+ def _clean_text(self, text):
482
+ """Performs invalid character removal and whitespace cleanup on text."""
483
+ output = []
484
+ for char in text:
485
+ cp = ord(char)
486
+ if cp == 0 or cp == 0xFFFD or _is_control(char):
487
+ continue
488
+ if _is_whitespace(char):
489
+ output.append(" ")
490
+ else:
491
+ output.append(char)
492
+ return "".join(output)
493
+
494
+
495
+ class WordpieceTokenizer(object):
496
+ """Runs WordPiece tokenization."""
497
+
498
+ def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
499
+ self.vocab = vocab
500
+ self.unk_token = unk_token
501
+ self.max_input_chars_per_word = max_input_chars_per_word
502
+
503
+ def tokenize(self, text):
504
+ """
505
+ Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
506
+ tokenization using the given vocabulary.
507
+ For example, :obj:`input = "unaffable"` will return as output :obj:`["un", "##aff", "##able"]`.
508
+ Args:
509
+ text: A single token or whitespace separated tokens. This should have
510
+ already been passed through `BasicTokenizer`.
511
+ Returns:
512
+ A list of wordpiece tokens.
513
+ """
514
+
515
+ output_tokens = []
516
+ for token in whitespace_tokenize(text):
517
+ chars = list(token)
518
+ if len(chars) > self.max_input_chars_per_word:
519
+ output_tokens.append(self.unk_token)
520
+ continue
521
+
522
+ is_bad = False
523
+ start = 0
524
+ sub_tokens = []
525
+ while start < len(chars):
526
+ end = len(chars)
527
+ cur_substr = None
528
+ while start < end:
529
+ substr = "".join(chars[start:end])
530
+ if start > 0:
531
+ substr = "##" + substr
532
+ if substr in self.vocab:
533
+ cur_substr = substr
534
+ break
535
+ end -= 1
536
+ if cur_substr is None:
537
+ is_bad = True
538
+ break
539
+ sub_tokens.append(cur_substr)
540
+ start = end
541
+
542
+ if is_bad:
543
+ output_tokens.append(self.unk_token)
544
+ else:
545
+ output_tokens.extend(sub_tokens)
546
+ return output_tokens
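
A note on the tokenizer file above: BasicTokenizer splits on whitespace and punctuation, and WordpieceTokenizer then applies a greedy longest-match-first lookup against the vocabulary, falling back to the unknown token when no prefix matches. The snippet below is an illustrative sketch only, using a hypothetical toy vocabulary; it assumes the classes are importable from svitt.tokenization_bert as added in this commit.

    import collections

    from svitt.tokenization_bert import BasicTokenizer, WordpieceTokenizer

    # Hypothetical toy vocabulary, just large enough to split "unaffable".
    toy_vocab = collections.OrderedDict(
        (tok, i) for i, tok in enumerate(
            ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "un", "##aff", "##able", "the", "cat"]
        )
    )

    basic = BasicTokenizer(do_lower_case=True)
    wordpiece = WordpieceTokenizer(vocab=toy_vocab, unk_token="[UNK]")

    # Basic split + lowercasing first, then greedy longest-match WordPiece lookup.
    tokens = []
    for word in basic.tokenize("The unaffable cat"):
        tokens.extend(wordpiece.tokenize(word))

    print(tokens)  # expected: ['the', 'un', '##aff', '##able', 'cat']
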
svitt/utils.py ADDED
@@ -0,0 +1,235 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.nn as nn
4
+ from scipy import interpolate
5
+ import numpy as np
6
+ from einops import rearrange, repeat
7
+
8
+
9
+ def _init_transformer_weights(module, initializer_range=0.02):
10
+ """Initialize the weights. Copied from transformers ViT/Bert model init"""
11
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
12
+ # Slightly different from the TF version which uses truncated_normal for initialization
13
+ # cf https://github.com/pytorch/pytorch/pull/5617
14
+ module.weight.data.normal_(mean=0.0, std=initializer_range)
15
+ if module.bias is not None:
16
+ module.bias.data.zero_()
17
+ elif isinstance(module, nn.Embedding):
18
+ module.weight.data.normal_(mean=0.0, std=initializer_range)
19
+ if module.padding_idx is not None:
20
+ module.weight.data[module.padding_idx].zero_()
21
+ elif isinstance(module, nn.LayerNorm):
22
+ module.bias.data.zero_()
23
+ module.weight.data.fill_(1.0)
24
+
25
+
26
+ def interpolate_pos_embed(pos_embed_old, pos_embed_new, num_patches_new):
27
+ """
28
+ Args:
29
+ pos_embed_old: (1, L_old, d), pre-trained
30
+ pos_embed_new: (1, L_new, d), newly initialized, to be replaced by interpolated weights
31
+ num_patches_new: number of patches in the new model (excluding extra tokens such as the [CLS] token)
32
+ """
33
+ # interpolate position embedding
34
+ embedding_size = pos_embed_old.shape[-1]
35
+ num_extra_tokens = pos_embed_new.shape[-2] - num_patches_new
36
+ # height (== width) for the checkpoint position embedding
37
+ orig_size = int((pos_embed_old.shape[-2] - num_extra_tokens) ** 0.5)
38
+ # height (== width) for the new position embedding
39
+ new_size = int(num_patches_new ** 0.5)
40
+
41
+ if orig_size != new_size:
42
+ # class_token and dist_token are kept unchanged
43
+ # the extra tokens seems always at the beginning of the position embedding
44
+ extra_tokens = pos_embed_old[:, :num_extra_tokens]
45
+ # only the position tokens are interpolated
46
+ pos_tokens = pos_embed_old[:, num_extra_tokens:]
47
+ pos_tokens = pos_tokens.reshape(
48
+ -1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
49
+ pos_tokens = torch.nn.functional.interpolate(
50
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
51
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
52
+ interpolated_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
53
+ return interpolated_pos_embed
54
+ else:
55
+ return pos_embed_old
56
+
57
+
58
+ def interpolate_pos_relative_bias_beit(state_dict_old, state_dict_new, patch_shape_new):
59
+ """
60
+ Args:
61
+ state_dict_old: loaded state dict
62
+ state_dict_new: state dict for model with new image size
63
+ patch_shape_new: new model patch_shape
64
+ ref: https://github.com/microsoft/unilm/blob/master/beit/run_class_finetuning.py
65
+ """
66
+ all_keys = list(state_dict_old.keys())
67
+ for key in all_keys:
68
+ if "relative_position_index" in key:
69
+ state_dict_old.pop(key)
70
+
71
+ if "relative_position_bias_table" in key:
72
+ rel_pos_bias = state_dict_old[key]
73
+ src_num_pos, num_attn_heads = rel_pos_bias.size()
74
+ dst_num_pos, _ = state_dict_new[key].size()
75
+ dst_patch_shape = patch_shape_new
76
+ if dst_patch_shape[0] != dst_patch_shape[1]:
77
+ raise NotImplementedError()
78
+ num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)
79
+ src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
80
+ dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
81
+ if src_size != dst_size:
82
+ extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
83
+ rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
84
+
85
+ def geometric_progression(a, r, n):
86
+ return a * (1.0 - r ** n) / (1.0 - r)
87
+
88
+ left, right = 1.01, 1.5
89
+ while right - left > 1e-6:
90
+ q = (left + right) / 2.0
91
+ gp = geometric_progression(1, q, src_size // 2)
92
+ if gp > dst_size // 2:
93
+ right = q
94
+ else:
95
+ left = q
96
+
97
+ # if q > 1.090307:
98
+ # q = 1.090307
99
+
100
+ dis = []
101
+ cur = 1
102
+ for i in range(src_size // 2):
103
+ dis.append(cur)
104
+ cur += q ** (i + 1)
105
+
106
+ r_ids = [-_ for _ in reversed(dis)]
107
+
108
+ x = r_ids + [0] + dis
109
+ y = r_ids + [0] + dis
110
+
111
+ t = dst_size // 2.0
112
+ dx = np.arange(-t, t + 0.1, 1.0)
113
+ dy = np.arange(-t, t + 0.1, 1.0)
114
+
115
+ all_rel_pos_bias = []
116
+
117
+ for i in range(num_attn_heads):
118
+ z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
119
+ f = interpolate.interp2d(x, y, z, kind='cubic')
120
+ all_rel_pos_bias.append(
121
+ torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device))
122
+
123
+ rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
124
+
125
+ new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
126
+ state_dict_old[key] = new_rel_pos_bias
127
+ return state_dict_old
128
+
129
+
130
+ def interpolate_pos_relative_bias_beit_3d(state_dict_old, state_dict_new, patch_shape_new, src_t_size=1):
131
+ """
132
+ Args:
133
+ state_dict_old: loaded state dict
134
+ state_dict_new: state dict for model with new image size
135
+ patch_shape_new: new model patch_shape
136
+ ref: https://github.com/microsoft/unilm/blob/master/beit/run_class_finetuning.py
137
+ """
138
+ all_keys = list(state_dict_old.keys())
139
+ for key in all_keys:
140
+ if "relative_position_index" in key:
141
+ state_dict_old.pop(key)
142
+
143
+ if "relative_position_bias_table" in key:
144
+ src_num_pos, num_attn_heads = state_dict_old[key].size()
145
+ dst_num_pos, _ = state_dict_new[key].size()
146
+ if src_num_pos == dst_num_pos:
147
+ continue
148
+
149
+ num_extra_tokens = dst_num_pos - np.prod([w * 2 - 1 for w in patch_shape_new])
150
+
151
+ src_s_size = int((src_num_pos - num_extra_tokens) / src_t_size)
152
+ src_size = int(src_s_size ** 0.5)
153
+ dst_size = patch_shape_new[-1] * 2 - 1
154
+
155
+ if src_size != dst_size:
156
+ # Spatial interpolation
157
+ rel_pos_bias = state_dict_old[key]
158
+ extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
159
+ rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
160
+
161
+ def geometric_progression(a, r, n):
162
+ return a * (1.0 - r ** n) / (1.0 - r)
163
+
164
+ left, right = 1.01, 1.5
165
+ while right - left > 1e-6:
166
+ q = (left + right) / 2.0
167
+ gp = geometric_progression(1, q, src_size // 2)
168
+ if gp > dst_size // 2:
169
+ right = q
170
+ else:
171
+ left = q
172
+
173
+ # if q > 1.090307:
174
+ # q = 1.090307
175
+
176
+ dis = []
177
+ cur = 1
178
+ for i in range(src_size // 2):
179
+ dis.append(cur)
180
+ cur += q ** (i + 1)
181
+
182
+ r_ids = [-_ for _ in reversed(dis)]
183
+
184
+ x = r_ids + [0] + dis
185
+ y = r_ids + [0] + dis
186
+
187
+ t = dst_size // 2.0
188
+ dx = np.arange(-t, t + 0.1, 1.0)
189
+ dy = np.arange(-t, t + 0.1, 1.0)
190
+
191
+ all_rel_pos_bias = []
192
+
193
+ for i in range(num_attn_heads):
194
+ z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
195
+ f = interpolate.interp2d(x, y, z, kind='cubic')
196
+ all_rel_pos_bias.append(
197
+ torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device))
198
+
199
+ rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
200
+
201
+ new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
202
+ state_dict_old[key] = new_rel_pos_bias
203
+
204
+ dst_t_size = patch_shape_new[0] * 2 - 1
205
+ if src_t_size != dst_t_size:
206
+ # Temporal interpolation
207
+ rel_pos_bias = state_dict_old[key]
208
+ extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
209
+ rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
210
+
211
+ if src_t_size == 1:
212
+ rel_pos_bias = repeat(rel_pos_bias, 's d -> (t s) d', t=dst_t_size)
213
+ else:
214
+ rel_pos_bias = rearrange(rel_pos_bias, '(t s) d -> s d t', t=src_t_size)
215
+ rel_pos_bias = F.interpolate(rel_pos_bias, dst_t_size, mode='nearest')
216
+ rel_pos_bias = rearrange(rel_pos_bias, 's d t -> (t s) d')
217
+ new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
218
+ state_dict_old[key] = new_rel_pos_bias
219
+
220
+ return state_dict_old
221
+
222
+
223
+ def tile(x, dim, n_tile):
224
+ init_dim = x.size(dim)
225
+ repeat_idx = [1] * x.dim()
226
+ repeat_idx[dim] = n_tile
227
+ x = x.repeat(*repeat_idx)
228
+ order_index = torch.LongTensor(np.concatenate(
229
+ [init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
230
+ return torch.index_select(x, dim, order_index.to(x.device))
231
+
232
+
233
+ def mask_logits(target, mask):
234
+ return target * mask + (1 - mask) * (-1e10)
235
+
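
The helpers in svitt/utils.py are mostly shape manipulations. The snippet below is an illustrative sketch only of how three of them behave, using hypothetical ViT-B/16-style shapes (768-d embeddings, a 14x14 patch grid resized to a 16x16 grid); it assumes the functions are importable from svitt.utils as added in this commit.

    import torch

    from svitt.utils import interpolate_pos_embed, tile, mask_logits

    # Resize a pre-trained positional embedding ([CLS] + 14*14 patches)
    # for a model that expects 16*16 patches.
    pos_embed_old = torch.randn(1, 1 + 14 * 14, 768)
    pos_embed_new = torch.zeros(1, 1 + 16 * 16, 768)
    resized = interpolate_pos_embed(pos_embed_old, pos_embed_new, num_patches_new=16 * 16)
    print(resized.shape)  # torch.Size([1, 257, 768])

    # tile() repeats each slice along a dimension consecutively,
    # similar to torch.repeat_interleave.
    x = torch.tensor([[1, 2], [3, 4]])
    print(tile(x, dim=0, n_tile=2))  # rows: [1, 2], [1, 2], [3, 4], [3, 4]

    # mask_logits() pushes masked-out positions to a large negative value
    # so a subsequent softmax effectively ignores them.
    logits = torch.tensor([1.0, 2.0, 3.0])
    mask = torch.tensor([1.0, 0.0, 1.0])
    print(mask_logits(logits, mask))  # approximately [1.0, -1e10, 3.0]
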